Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
8548d97306 handle contig on after 2026-04-30 21:13:54 +00:00
2 changed files with 37 additions and 8 deletions

View file

@ -321,9 +321,32 @@ class TestCustomKernel(unittest.TestCase):
self.assertEqual(GlobalCounters.kernel_count, 2)
self.assertEqual(z.tolist(), x.add(2).tolist())
@unittest.expectedFailure
def test_custom_kernel_sched_copy(self): self.test_custom_kernel_sched(use_custom=True)
def test_custom_kernel_sched_contiguous_view(self):
x = Tensor.arange(64).realize()
y = Tensor.empty_like(x)
y = Tensor.custom_kernel(y, x, fxn=custom_add_one_kernel)[0]
ys = y.shrink(((4, 36),))
z = Tensor.empty_like(ys)
z = Tensor.custom_kernel(z, ys, fxn=custom_add_one_kernel)[0]
GlobalCounters.reset()
z.realize()
self.assertEqual(GlobalCounters.kernel_count, 2)
self.assertEqual(z.tolist(), list(range(6, 38)))
def test_custom_kernel_sched_reshape_view(self):
x = Tensor.arange(64).realize()
y = Tensor.empty_like(x)
y = Tensor.custom_kernel(y, x, fxn=custom_add_one_kernel)[0]
yr = y.reshape(8, 8)
z = Tensor.empty_like(yr)
z = Tensor.custom_kernel(z, yr, fxn=custom_add_one_kernel)[0]
GlobalCounters.reset()
z.realize()
self.assertEqual(GlobalCounters.kernel_count, 2)
self.assertEqual(z.tolist(), x.add(2).reshape(8, 8).tolist())
class TestUOpReduce(unittest.TestCase):
def test_uop_sum(self):
a = Tensor([1.0, 2, 3, 4, 5])

View file

@ -63,10 +63,16 @@ def _make_buffer_view(src:UOp) -> UOp|None:
return UOp(Ops.BUFFER_VIEW, src.dtype, (buf,), (src.numel(), offset)).reshape(src.shape)
def contiguous_mops_to_view(c:UOp, src:UOp):
"""CONTIGUOUS(MOPS(BUFFER)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to a contiguous range."""
buf = src.base
if buf.op not in {Ops.BUFFER, Ops.BUFFER_VIEW}: return None
if src.op is Ops.RESHAPE and src.src[0].op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return None
"""CONTIGUOUS(MOPS(BUFFER/AFTER)) -> BUFFER_VIEW when movement ops collapse to a contiguous range."""
# src.base only locates AFTER deps; _make_buffer_view still validates the full movement chain.
base = src.base
deps = ()
while base.op is Ops.AFTER:
deps = base.src[1:] + deps
src = src.substitute({base:base.src[0]})
base = src.base
if base.op not in {Ops.BUFFER, Ops.BUFFER_VIEW}: return None
if not deps and src.op is Ops.RESHAPE and src.src[0].op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return None
# no symbolic shape
if not all_int(c.shape): return None
@ -83,11 +89,11 @@ def contiguous_mops_to_view(c:UOp, src:UOp):
resolved = graph_rewrite(src, multi_pm, name="multi_buffer_view")
if resolved.op is not Ops.MULTI: return None
if (view := _make_buffer_view(resolved.src[0])) is None: return None
return view.multi(resolved.arg).contiguous(tag=c.tag)
ret = view.multi(resolved.arg)
elif (ret := _make_buffer_view(src)) is None: return None
# NOTE: this contiguous is removed because this BUFFER_VIEW/RESHAPE has_buffer_identity
if (view := _make_buffer_view(src)) is None: return None
return view.contiguous(tag=c.tag)
return ret.after(*deps, tag=c.tag) if deps else ret.contiguous(tag=c.tag)
def transform_precompiled_call(c:UOp) -> UOp|None:
if not c.arg.precompile: return None