mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix symbolic contiguous_view_offset (#15749)
* fix symbolic contiguous_view_offset * flatten
This commit is contained in:
parent
164495678c
commit
507c02cecb
2 changed files with 14 additions and 2 deletions
|
|
@ -97,5 +97,16 @@ class TestSymbolicShrink(unittest.TestCase):
|
|||
t = Tensor.rand(3, 5).shrink(((0, 2), (vi, vi+1)))
|
||||
assert t.shape == (2, 1)
|
||||
|
||||
class TestSymbolicContiguousViewOffset(unittest.TestCase):
|
||||
def test_shrink_from_start(self):
|
||||
v = Variable("v", 1, 10).bind(5)
|
||||
t = Tensor.rand(10).realize().shrink(((0, v),))
|
||||
self.assertEqual(t.uop.contiguous_view_offset(), 0)
|
||||
|
||||
def test_shrink_with_offset(self):
|
||||
v = Variable("v", 1, 7).bind(4)
|
||||
t = Tensor.rand(10).realize().shrink(((3, 3+v),))
|
||||
self.assertEqual(t.uop.contiguous_view_offset(), 3)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -690,9 +690,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
"""If movement ops on a BUFFER collapse to a contiguous range, return `offset` in elements. Otherwise None."""
|
||||
from tinygrad.schedule.rangeify import pm_mops
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
out = graph_rewrite(self._mop(Ops.RESHAPE, (self.size,)).index(UOp.range(self.size, 0)), pm_mops+symbolic, name="contiguous_view_offset")
|
||||
numel = self.numel()
|
||||
out = graph_rewrite(self.flatten().index(UOp.range(numel, 0)), pm_mops+symbolic, name="contiguous_view_offset")
|
||||
if out.op is not Ops.INDEX: return None
|
||||
if out.src[1].op is Ops.CONST and self.size == 1:
|
||||
if out.src[1].op is Ops.CONST and resolve(numel == 1, False):
|
||||
if not isinstance(out.src[1].arg, int): return None # masked/padded regions produce InvalidType
|
||||
return out.src[1].arg
|
||||
if out.src[1].op is Ops.RANGE: return 0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue