realized PYTHON copies (#14934)

* realized PYTHON copies

* comment that out

* fix that test

* append afters

* contig

* disk copies

* should be 124

* 332
This commit is contained in:
George Hotz 2026-02-21 20:29:31 +08:00 committed by GitHub
commit 8ef5544e4a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 35 additions and 21 deletions

View file

@ -23,7 +23,8 @@ if __name__ == "__main__":
kernel_count = GlobalCounters.kernel_count
assert kernel_count > 0, "No kernels, test failed"
expected_kernels = 228
# NOTE: this is 124 on torch 2.10.0
expected_kernels = 332
expectation = f"ResNet18 kernels are {kernel_count} vs {expected_kernels} expected."
if kernel_count < expected_kernels: warnings.warn(f"{expectation} Expectation can be lowered.", UserWarning)
assert kernel_count <= expected_kernels, f"{expectation}"

View file

@ -704,7 +704,7 @@ class TestMultiTensor(unittest.TestCase):
# test no left join
with self.assertRaises((AssertionError, ValueError)):
t0.reshape((26*15,7)).schedule()
t0.reshape((26*15,7)).contiguous().schedule()
# it doesn't work like this anymore
# NOTE: this never failed in assign_multi, it failed tensor spec because MULTI was never pushed in the graph
@ -897,18 +897,18 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
with self.assertRaises(AssertionError):
# sharded axis shrink on non-device boundry is not allowed
a = t.shrink(((0, 3), (0, 8)))
a = t.shrink(((0, 3), (0, 8))).contiguous()
a.schedule()
a = t.shrink(((0, 2), (2, 4)))
assert a.shape == (2, 2)
ref = Tensor.arange(64).reshape(8, 8).shrink(((0, 2), (2, 4)))
np.testing.assert_equal(a.numpy(), ref.numpy())
a = t.shrink(((0, 2), (0, 8)))
a = t.shrink(((0, 2), (0, 8))).contiguous()
a.schedule()
assert a.shape == (2, 8)
p = a.pad(((0, 6), (0, 0)))
p = a.pad(((0, 6), (0, 0))).contiguous()
p.schedule()
assert p.shape == (8, 8)

View file

@ -757,11 +757,11 @@ class TestOps(unittest.TestCase):
helper_test_op(None, lambda x: x.rsqrt(), vals=[[0.0]])
helper_test_op([()], lambda x: x.rsqrt())
@unittest.skipIf(TINY_BACKEND, "broken on tiny backend, not sure why")
def test_xor(self):
data = [[1,-8,1],[32,1,6]]
tor = torch.tensor(data, dtype=torch.int)
ten = Tensor(data, dtype=dtypes.int32)
# NOTE: this breaks assigns because it's folded to 0!
helper_test_op([], lambda: tor^tor, lambda: ten^ten, forward_only=True)
helper_test_op([], lambda: tor^0x1337, lambda: ten^0x1337, forward_only=True)
helper_test_op([], lambda: 0x1337^tor, lambda: 0x1337^ten, forward_only=True)

View file

@ -1235,7 +1235,7 @@ class TestView(unittest.TestCase):
bv = b.pad(((0, 2),))[-2:]
# this becomes a late a*0
late_mul = a*bv
run_schedule(check_schedule(late_mul, 1))
run_schedule(check_schedule(late_mul, 2))
# the arange doesn't realize
#self.assertIsNone(b.uop.base.realized)
# mul doesn't realize

View file

@ -1,6 +1,6 @@
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, _remove_all_tags, identity_element
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, identity_element
from tinygrad.dtype import ImageDType
from tinygrad.helpers import prod, DEBUG, argsort
from tinygrad.helpers import prod, DEBUG, argsort, VIZ
def tag_uop(ctx:tuple[list[UOp], dict[UOp, UOp], set[UOp]], x:UOp):
if x.tag is not None: return None
@ -12,7 +12,7 @@ def disk_copy_is_buffer(ctx, u):
to_disk = isinstance(u._device, str) and u._device.startswith("DISK")
if to_disk: ctx[1][u] = UOp.new_buffer(u.device, u.shard_size, u.dtype).reshape(u.max_shard_shape)
# all copies from disk/numpy are realized into a real buffer
from_creation = isinstance(u.src[0]._device, str) and any(u.src[0]._device.startswith(x) for x in ["NPY", "DISK"])
from_creation = isinstance(u.src[0]._device, str) and any(u.src[0]._device.startswith(x) for x in ["NPY", "DISK", "PYTHON"])
if from_creation: return tag_uop(ctx, u)
def apply_after(ctx, u):
@ -82,7 +82,25 @@ pm_early_transform_tensor_graph = PatternMatcher([
(UPat(Ops.COPY, src=(UPat.var("s"), UPat()), name="c"), lambda c,s: c.const_like(ss.arg) if (ss:=s.base).op is Ops.CONST else None),
])
pm_remove_unique_consts = PatternMatcher([
def untag_and_append(ctx:tuple[list[UOp], dict[UOp, UOp], list[UOp]], x:UOp):
if x.tag is None: return None
uop_list, buffer_map, assigns = ctx
ret = x.replace(tag=None)
for t in x.tag:
original_uop: UOp = uop_list[t]
replace_uop = ret
while replace_uop.op is Ops.ASSIGN: replace_uop = replace_uop.src[0]
buffer_map[original_uop] = replace_uop.shrink_to(original_uop.shape)
assigns.append(ret)
return ret
def append_after(ctx:tuple[list[UOp], dict[UOp, UOp], list[UOp]], x:UOp):
ctx[2].append(x)
pm_finalize_call = PatternMatcher([
(UPat(Ops.ASSIGN, name="x"), untag_and_append),
(UPat(Ops.AFTER, name="x"), append_after),
(UPat(Ops.COPY, name="x"), lambda ctx,x: append_after(ctx,x) if isinstance(x.device, str) and x.device.startswith("DISK") else None),
# replace UNIQUE with LUNIQUE for CONST cache key normalization
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE, name="d")), name="b"), lambda b,d: b.replace(src=(d,))),
])
@ -104,13 +122,8 @@ def allocate_global_buffers(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
big_sink = graph_rewrite(big_sink, pm_early_transform_tensor_graph, ctx={}, name="early transform tensor graph")
# here we construct the final buffer_map. this is everything that will go into the tensor map
for s in big_sink.toposort():
if s.tag is not None:
assert s.op is Ops.ASSIGN
for t in s.tag:
original_uop = uop_list[t]
replace_uop = s
while replace_uop.op is Ops.ASSIGN: replace_uop = replace_uop.src[0]
buffer_map[original_uop] = replace_uop.shrink_to(original_uop.shape)
big_sink = graph_rewrite(big_sink, _remove_all_tags+pm_remove_unique_consts, name="remove tags")
return big_sink, buffer_map
assigns: list[UOp] = []
graph_rewrite(big_sink, pm_finalize_call, ctx=(uop_list, buffer_map, assigns), name="finalize call")
ret = UOp.sink(*assigns)
if VIZ: graph_rewrite(ret, PatternMatcher([]), name="*** Call")
return ret, buffer_map