fix symbolic contiguous_view_offset (#15749)

* fix symbolic contiguous_view_offset

* flatten
This commit is contained in:
chenyu 2026-04-15 16:54:38 -04:00 committed by GitHub
commit 507c02cecb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 14 additions and 2 deletions

View file

@ -97,5 +97,16 @@ class TestSymbolicShrink(unittest.TestCase):
t = Tensor.rand(3, 5).shrink(((0, 2), (vi, vi+1)))
assert t.shape == (2, 1)
class TestSymbolicContiguousViewOffset(unittest.TestCase):
def test_shrink_from_start(self):
v = Variable("v", 1, 10).bind(5)
t = Tensor.rand(10).realize().shrink(((0, v),))
self.assertEqual(t.uop.contiguous_view_offset(), 0)
def test_shrink_with_offset(self):
v = Variable("v", 1, 7).bind(4)
t = Tensor.rand(10).realize().shrink(((3, 3+v),))
self.assertEqual(t.uop.contiguous_view_offset(), 3)
if __name__ == '__main__':
unittest.main()

View file

@ -690,9 +690,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
"""If movement ops on a BUFFER collapse to a contiguous range, return `offset` in elements. Otherwise None."""
from tinygrad.schedule.rangeify import pm_mops
from tinygrad.uop.symbolic import symbolic
out = graph_rewrite(self._mop(Ops.RESHAPE, (self.size,)).index(UOp.range(self.size, 0)), pm_mops+symbolic, name="contiguous_view_offset")
numel = self.numel()
out = graph_rewrite(self.flatten().index(UOp.range(numel, 0)), pm_mops+symbolic, name="contiguous_view_offset")
if out.op is not Ops.INDEX: return None
if out.src[1].op is Ops.CONST and self.size == 1:
if out.src[1].op is Ops.CONST and resolve(numel == 1, False):
if not isinstance(out.src[1].arg, int): return None # masked/padded regions produce InvalidType
return out.src[1].arg
if out.src[1].op is Ops.RANGE: return 0