mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
contig_on_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8548d97306 |
2 changed files with 37 additions and 8 deletions
|
|
@ -321,9 +321,32 @@ class TestCustomKernel(unittest.TestCase):
|
|||
self.assertEqual(GlobalCounters.kernel_count, 2)
|
||||
self.assertEqual(z.tolist(), x.add(2).tolist())
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_custom_kernel_sched_copy(self): self.test_custom_kernel_sched(use_custom=True)
|
||||
|
||||
def test_custom_kernel_sched_contiguous_view(self):
|
||||
x = Tensor.arange(64).realize()
|
||||
y = Tensor.empty_like(x)
|
||||
y = Tensor.custom_kernel(y, x, fxn=custom_add_one_kernel)[0]
|
||||
ys = y.shrink(((4, 36),))
|
||||
z = Tensor.empty_like(ys)
|
||||
z = Tensor.custom_kernel(z, ys, fxn=custom_add_one_kernel)[0]
|
||||
GlobalCounters.reset()
|
||||
z.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 2)
|
||||
self.assertEqual(z.tolist(), list(range(6, 38)))
|
||||
|
||||
def test_custom_kernel_sched_reshape_view(self):
|
||||
x = Tensor.arange(64).realize()
|
||||
y = Tensor.empty_like(x)
|
||||
y = Tensor.custom_kernel(y, x, fxn=custom_add_one_kernel)[0]
|
||||
yr = y.reshape(8, 8)
|
||||
z = Tensor.empty_like(yr)
|
||||
z = Tensor.custom_kernel(z, yr, fxn=custom_add_one_kernel)[0]
|
||||
GlobalCounters.reset()
|
||||
z.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 2)
|
||||
self.assertEqual(z.tolist(), x.add(2).reshape(8, 8).tolist())
|
||||
|
||||
class TestUOpReduce(unittest.TestCase):
|
||||
def test_uop_sum(self):
|
||||
a = Tensor([1.0, 2, 3, 4, 5])
|
||||
|
|
|
|||
|
|
@ -63,10 +63,16 @@ def _make_buffer_view(src:UOp) -> UOp|None:
|
|||
return UOp(Ops.BUFFER_VIEW, src.dtype, (buf,), (src.numel(), offset)).reshape(src.shape)
|
||||
|
||||
def contiguous_mops_to_view(c:UOp, src:UOp):
|
||||
"""CONTIGUOUS(MOPS(BUFFER)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to a contiguous range."""
|
||||
buf = src.base
|
||||
if buf.op not in {Ops.BUFFER, Ops.BUFFER_VIEW}: return None
|
||||
if src.op is Ops.RESHAPE and src.src[0].op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return None
|
||||
"""CONTIGUOUS(MOPS(BUFFER/AFTER)) -> BUFFER_VIEW when movement ops collapse to a contiguous range."""
|
||||
# src.base only locates AFTER deps; _make_buffer_view still validates the full movement chain.
|
||||
base = src.base
|
||||
deps = ()
|
||||
while base.op is Ops.AFTER:
|
||||
deps = base.src[1:] + deps
|
||||
src = src.substitute({base:base.src[0]})
|
||||
base = src.base
|
||||
if base.op not in {Ops.BUFFER, Ops.BUFFER_VIEW}: return None
|
||||
if not deps and src.op is Ops.RESHAPE and src.src[0].op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return None
|
||||
|
||||
# no symbolic shape
|
||||
if not all_int(c.shape): return None
|
||||
|
|
@ -83,11 +89,11 @@ def contiguous_mops_to_view(c:UOp, src:UOp):
|
|||
resolved = graph_rewrite(src, multi_pm, name="multi_buffer_view")
|
||||
if resolved.op is not Ops.MULTI: return None
|
||||
if (view := _make_buffer_view(resolved.src[0])) is None: return None
|
||||
return view.multi(resolved.arg).contiguous(tag=c.tag)
|
||||
ret = view.multi(resolved.arg)
|
||||
elif (ret := _make_buffer_view(src)) is None: return None
|
||||
|
||||
# NOTE: this contiguous is removed because this BUFFER_VIEW/RESHAPE has_buffer_identity
|
||||
if (view := _make_buffer_view(src)) is None: return None
|
||||
return view.contiguous(tag=c.tag)
|
||||
return ret.after(*deps, tag=c.tag) if deps else ret.contiguous(tag=c.tag)
|
||||
|
||||
def transform_precompiled_call(c:UOp) -> UOp|None:
|
||||
if not c.arg.precompile: return None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue