mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
realized PYTHON copies (#14934)
* realized PYTHON copies * comment that out * fix that test * append afters * contig * disk copies * should be 124 * 332
This commit is contained in:
parent
cf23c2eee7
commit
8ef5544e4a
5 changed files with 35 additions and 21 deletions
|
|
@ -23,7 +23,8 @@ if __name__ == "__main__":
|
|||
|
||||
kernel_count = GlobalCounters.kernel_count
|
||||
assert kernel_count > 0, "No kernels, test failed"
|
||||
expected_kernels = 228
|
||||
# NOTE: this is 124 on torch 2.10.0
|
||||
expected_kernels = 332
|
||||
expectation = f"ResNet18 kernels are {kernel_count} vs {expected_kernels} expected."
|
||||
if kernel_count < expected_kernels: warnings.warn(f"{expectation} Expectation can be lowered.", UserWarning)
|
||||
assert kernel_count <= expected_kernels, f"{expectation}"
|
||||
|
|
@ -704,7 +704,7 @@ class TestMultiTensor(unittest.TestCase):
|
|||
|
||||
# test no left join
|
||||
with self.assertRaises((AssertionError, ValueError)):
|
||||
t0.reshape((26*15,7)).schedule()
|
||||
t0.reshape((26*15,7)).contiguous().schedule()
|
||||
|
||||
# it doesn't work like this anymore
|
||||
# NOTE: this never failed in assign_multi, it failed tensor spec because MULTI was never pushed in the graph
|
||||
|
|
@ -897,18 +897,18 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
|
|||
|
||||
with self.assertRaises(AssertionError):
|
||||
# sharded axis shrink on non-device boundary is not allowed
|
||||
a = t.shrink(((0, 3), (0, 8)))
|
||||
a = t.shrink(((0, 3), (0, 8))).contiguous()
|
||||
a.schedule()
|
||||
a = t.shrink(((0, 2), (2, 4)))
|
||||
assert a.shape == (2, 2)
|
||||
ref = Tensor.arange(64).reshape(8, 8).shrink(((0, 2), (2, 4)))
|
||||
np.testing.assert_equal(a.numpy(), ref.numpy())
|
||||
|
||||
a = t.shrink(((0, 2), (0, 8)))
|
||||
a = t.shrink(((0, 2), (0, 8))).contiguous()
|
||||
a.schedule()
|
||||
assert a.shape == (2, 8)
|
||||
|
||||
p = a.pad(((0, 6), (0, 0)))
|
||||
p = a.pad(((0, 6), (0, 0))).contiguous()
|
||||
p.schedule()
|
||||
assert p.shape == (8, 8)
|
||||
|
||||
|
|
|
|||
|
|
@ -757,11 +757,11 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op(None, lambda x: x.rsqrt(), vals=[[0.0]])
|
||||
helper_test_op([()], lambda x: x.rsqrt())
|
||||
|
||||
@unittest.skipIf(TINY_BACKEND, "broken on tiny backend, not sure why")
|
||||
def test_xor(self):
|
||||
data = [[1,-8,1],[32,1,6]]
|
||||
tor = torch.tensor(data, dtype=torch.int)
|
||||
ten = Tensor(data, dtype=dtypes.int32)
|
||||
# NOTE: this breaks assigns because it's folded to 0!
|
||||
helper_test_op([], lambda: tor^tor, lambda: ten^ten, forward_only=True)
|
||||
helper_test_op([], lambda: tor^0x1337, lambda: ten^0x1337, forward_only=True)
|
||||
helper_test_op([], lambda: 0x1337^tor, lambda: 0x1337^ten, forward_only=True)
|
||||
|
|
|
|||
|
|
@ -1235,7 +1235,7 @@ class TestView(unittest.TestCase):
|
|||
bv = b.pad(((0, 2),))[-2:]
|
||||
# this becomes a late a*0
|
||||
late_mul = a*bv
|
||||
run_schedule(check_schedule(late_mul, 1))
|
||||
run_schedule(check_schedule(late_mul, 2))
|
||||
# the arange doesn't realize
|
||||
#self.assertIsNone(b.uop.base.realized)
|
||||
# mul doesn't realize
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, _remove_all_tags, identity_element
|
||||
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, identity_element
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.helpers import prod, DEBUG, argsort
|
||||
from tinygrad.helpers import prod, DEBUG, argsort, VIZ
|
||||
|
||||
def tag_uop(ctx:tuple[list[UOp], dict[UOp, UOp], set[UOp]], x:UOp):
|
||||
if x.tag is not None: return None
|
||||
|
|
@ -12,7 +12,7 @@ def disk_copy_is_buffer(ctx, u):
|
|||
to_disk = isinstance(u._device, str) and u._device.startswith("DISK")
|
||||
if to_disk: ctx[1][u] = UOp.new_buffer(u.device, u.shard_size, u.dtype).reshape(u.max_shard_shape)
|
||||
# all copies from disk/numpy are realized into a real buffer
|
||||
from_creation = isinstance(u.src[0]._device, str) and any(u.src[0]._device.startswith(x) for x in ["NPY", "DISK"])
|
||||
from_creation = isinstance(u.src[0]._device, str) and any(u.src[0]._device.startswith(x) for x in ["NPY", "DISK", "PYTHON"])
|
||||
if from_creation: return tag_uop(ctx, u)
|
||||
|
||||
def apply_after(ctx, u):
|
||||
|
|
@ -82,7 +82,25 @@ pm_early_transform_tensor_graph = PatternMatcher([
|
|||
(UPat(Ops.COPY, src=(UPat.var("s"), UPat()), name="c"), lambda c,s: c.const_like(ss.arg) if (ss:=s.base).op is Ops.CONST else None),
|
||||
])
|
||||
|
||||
pm_remove_unique_consts = PatternMatcher([
|
||||
def untag_and_append(ctx:tuple[list[UOp], dict[UOp, UOp], list[UOp]], x:UOp):
  """Strip the tag from `x`, record its buffer mappings, and collect the untagged node.

  `ctx` is unpacked as (uop_list, buffer_map, assigns). For every index in `x.tag`,
  `uop_list[index]` is mapped in `buffer_map` to the first non-ASSIGN node reached by
  following `src[0]` from the untagged node, shrunk to the shape of `uop_list[index]`.
  The untagged node is appended to `assigns` and returned.

  Returns None without touching `ctx` when `x.tag` is None.
  """
  tags = x.tag
  if tags is None:
    return None
  tracked_uops, mapping, collected = ctx
  untagged = x.replace(tag=None)
  for idx in tags:
    source_uop: UOp = tracked_uops[idx]
    # follow src[0] through any chain of ASSIGNs to reach the node underneath
    target = untagged
    while target.op is Ops.ASSIGN:
      target = target.src[0]
    mapping[source_uop] = target.shrink_to(source_uop.shape)
  collected.append(untagged)
  return untagged
|
||||
|
||||
def append_after(ctx:tuple[list[UOp], dict[UOp, UOp], list[UOp]], x:UOp):
  """Append `x` to the list held at index 2 of `ctx`; returns None."""
  collected = ctx[2]
  collected.append(x)
|
||||
|
||||
# Rules for the "finalize call" rewrite; the handlers below receive a ctx that
# untag_and_append unpacks as (uop_list, buffer_map, assigns).
pm_finalize_call = PatternMatcher([
  # every ASSIGN is handed to untag_and_append, which returns None for untagged nodes
  (UPat(Ops.ASSIGN, name="x"), untag_and_append),
  # every AFTER is appended to ctx[2] unchanged
  (UPat(Ops.AFTER, name="x"), append_after),
  # a COPY is appended to ctx[2] only when its device is a string starting with "DISK"
  (UPat(Ops.COPY, name="x"), lambda ctx,x: append_after(ctx,x) if isinstance(x.device, str) and x.device.startswith("DISK") else None),
  # replace UNIQUE with LUNIQUE for CONST cache key normalization
  # NOTE(review): the lambda below rebuilds the CONST with only the DEVICE source (the UNIQUE
  # source is dropped); no LUNIQUE is created here -- confirm the comment above is still accurate
  (UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE, name="d")), name="b"), lambda b,d: b.replace(src=(d,))),
])
|
||||
|
|
@ -104,13 +122,8 @@ def allocate_global_buffers(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
|
|||
big_sink = graph_rewrite(big_sink, pm_early_transform_tensor_graph, ctx={}, name="early transform tensor graph")
|
||||
|
||||
# here we construct the final buffer_map. this is everything that will go into the tensor map
|
||||
for s in big_sink.toposort():
|
||||
if s.tag is not None:
|
||||
assert s.op is Ops.ASSIGN
|
||||
for t in s.tag:
|
||||
original_uop = uop_list[t]
|
||||
replace_uop = s
|
||||
while replace_uop.op is Ops.ASSIGN: replace_uop = replace_uop.src[0]
|
||||
buffer_map[original_uop] = replace_uop.shrink_to(original_uop.shape)
|
||||
big_sink = graph_rewrite(big_sink, _remove_all_tags+pm_remove_unique_consts, name="remove tags")
|
||||
return big_sink, buffer_map
|
||||
assigns: list[UOp] = []
|
||||
graph_rewrite(big_sink, pm_finalize_call, ctx=(uop_list, buffer_map, assigns), name="finalize call")
|
||||
ret = UOp.sink(*assigns)
|
||||
if VIZ: graph_rewrite(ret, PatternMatcher([]), name="*** Call")
|
||||
return ret, buffer_map
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue