# tinygrad/test/backend/test_schedule.py
#
# 476 lines
# 19 KiB
# Python

# this will be the new test_ops for the next level
# schedule confirms the right things are capable of fusing
# NOTE: this has overlap with external_test_opt.py
import unittest
import numpy as np
from tinygrad import nn, dtypes, Device, Tensor, Variable
from tinygrad.uop.ops import UOp, Ops, UPat
from tinygrad.helpers import DEBUG, DEV, GlobalCounters, Context, all_same, temp
from tinygrad.engine.realize import compile_linear, run_linear
# Evaluated once at import from the default device's renderer; not referenced in the visible part of this file.
supported_dtypes = Device[Device.DEFAULT].renderer.supported_dtypes()

class KernelCountException(Exception):
  """Raised by check_schedule when the counted number of kernels differs from the expected count."""
def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True):
  """Linearize `t`, require exactly `allowed` counted kernels, compile the result and return `(linear, var_vals)`.

  t: a Tensor, a list whose first element is a Tensor, or a UOp (wrapped in a Tensor before linearizing).
  allowed: expected kernel count; a call whose `device` is a tuple contributes one per tuple entry.
  to_prerealize: tensors realized up front (with DEBUG and match-stat tracking off) so their work is not counted.
  filter_sink: when truthy only calls whose first source is an Ops.SINK are counted, otherwise every call is.
  Raises KernelCountException on a mismatch; with DEBUG >= 3 each call's first source is printed first.
  """
  if to_prerealize:
    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
      Tensor.realize(*to_prerealize)
  if isinstance(t, Tensor):
    linear, var_vals = t.linear_with_vars()
  elif isinstance(t, list) and isinstance(t[0], Tensor):
    linear, var_vals = Tensor.linear_with_vars(*t)
  else:
    assert isinstance(t, UOp), f"can't schedule {t}"
    linear, var_vals = Tensor(t).linear_with_vars()
  kernel_cnt = 0
  for call in linear.src:
    # skip non-SINK calls only when filtering is requested (the SINK test is evaluated first, as before)
    if call.src[0].op is not Ops.SINK and filter_sink: continue
    kernel_cnt += len(call.device) if isinstance(call.device, tuple) else 1
  if kernel_cnt != allowed:
    print(f"SCHEDULE ISSUE, expecting {allowed} got {kernel_cnt}")
    if DEBUG >= 3:
      for kernel_no, call in enumerate(linear.src, start=1):
        print("kernel", kernel_no)
        print(call.src[0])
    raise KernelCountException(f"{kernel_cnt} != {allowed}")
  # compile even when the count matches, so codegen/compile errors surface in the test
  compile_linear(linear)
  return linear, var_vals
def _realize_weights(m):
  """Call .realize() on each tensor returned by nn.state.get_parameters(m)."""
  params = nn.state.get_parameters(m)
  for param in params:
    param.realize()
class TestSchedule(unittest.TestCase):
  """Kernel-count and value checks for scheduling; every test in this class runs with SPLIT_REDUCEOP=0."""
  def setUp(self):
    # entered here and exited in tearDown so the setting spans the whole test method
    self.ctx = Context(SPLIT_REDUCEOP=0)
    self.ctx.__enter__()
  def tearDown(self):
    self.ctx.__exit__(None, None, None)
  @unittest.skip("no longer supported")
  def test_double_from(self):
    x = Tensor([1,2,3,4])
    out = x.to('python')
    check_schedule(out, 0, filter_sink=False)
  def test_example_matmul_same(self):
    # gradient of sum(x @ x) w.r.t. x is expected to be a single SINK kernel
    x = Tensor.eye(64).clone().realize()
    z = x.matmul(x).sum()
    z.backward()
    out = x.grad.contiguous()
    run_linear(*check_schedule(out, 1))
    # NOTE: the gradient flows twice
    np.testing.assert_allclose(out.numpy(), 2*np.ones((64,64)))
  def test_pad_reduce_scope_collision(self):
    # two reduces over differently padded views of the same buffer, expected to fuse into one kernel;
    # pad (1,2) then shrink to 6 rows gives the same rows as pad (1,1), hence the 2x in the reference
    b = Tensor.rand(4, 3).realize()
    s1 = b.pad(((1, 1), (0, 0))).sum(axis=1)
    s2 = b.pad(((1, 2), (0, 0))).shrink(((0, 6), (0, 3))).sum(axis=1)
    out = s1 + s2
    run_linear(*check_schedule(out, 1))
    np.testing.assert_allclose(out.numpy(), 2*np.pad(b.numpy(), ((1, 1), (0, 0))).sum(axis=1), rtol=1e-6)
  def test_cumsum_parallel_reduce_fused(self):
    # two-stage cumsum + ops triggers parallel REDUCEs in one kernel that must share an END (same nesting context = should merge)
    # only the values are checked here; no kernel count is asserted
    step, num_steps = 513, 10
    t = Tensor.arange(step).float().realize()
    phase = t.cumsum()
    tiled = phase.repeat((num_steps,)).reshape(num_steps, step)
    pattern = Tensor([1,0,0,1,0,0,0,0,1,0]).reshape(num_steps, 1)
    out = (tiled * pattern).flatten()
    expected = np.tile(np.arange(step).astype(np.float32).cumsum(), num_steps).reshape(num_steps, step)
    expected = (expected * np.array([1,0,0,1,0,0,0,0,1,0]).reshape(num_steps, 1)).flatten()
    np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4)
  @unittest.skipIf(Device.DEFAULT == "CL", "TODO: fails on CI CL")
  def test_reduce_different_nesting_depth(self):
    # two REDUCEs sharing the same RANGE at different nesting depths must NOT merge
    # only the values are checked here; no kernel count is asserted
    x = Tensor.arange(768).reshape(3, 256).float()
    np.testing.assert_allclose((x.sum(axis=1) + x.sum(axis=1).sum()).numpy(), x.numpy().sum(axis=1) + x.numpy().sum(axis=1).sum())
  def test_fuse_assign_contiguous(self):
    # assign into the first two columns of a realized 4x4 buffer; expects 2 SINK kernels
    x = Tensor.zeros(4, 4, dtype=dtypes.int).contiguous().realize()
    a = Tensor.arange(8).reshape(4, 2)
    run_linear(*check_schedule(x.shrink((None, (0, 2))).assign(a.clone()), 2))
    np.testing.assert_equal(x.numpy(), [[0, 1, 0, 0], [2, 3, 0, 0], [4, 5, 0, 0], [6, 7, 0, 0]])
  def test_assign_non_contiguous_alt(self): self.test_assign_non_contiguous(alt=True)
  def test_assign_non_contiguous(self, alt=False):
    # assign into a view of x: the first two rows when alt, otherwise the first two columns;
    # checks both the target buffer and the tensor returned by assign
    x = (Tensor.arange(16)-100).reshape(4,4).clone().realize()
    xref = x.numpy()
    if alt:
      y = Tensor.randint(2, 4).contiguous().realize()
      a = Tensor.arange(8).reshape(2, 4)+y
      tst = x.shrink(((0, 2), None)).assign(a).realize()
      xref[:2, :] = np.arange(8).reshape(2, 4)+y.numpy()
    else:
      y = Tensor.randint(4, 2).contiguous().realize()
      a = Tensor.arange(8).reshape(4, 2)+y
      tst = x.shrink((None, (0, 2))).assign(a).realize()
      xref[:, :2] = np.arange(8).reshape(4, 2)+y.numpy()
    np.testing.assert_equal(x.numpy(), xref)
    np.testing.assert_equal(tst.numpy(), a.numpy())
  def test_setitem_sched(self, mop=lambda x:x, expected_kcount=1):
    # a.assign(a + mop(a)) on a CPU buffer; every call in the linear (not only SINKs) is counted against expected_kcount
    a = Tensor.arange(16).reshape(4, 4).clone(device="CPU").realize()
    a2 = mop(a)
    expected = (a+a2).tolist()
    a.assign(a+a2)
    linear, var_vals = a.linear_with_vars()
    kcount = len(linear.src)
    run_linear(linear, var_vals)
    self.assertListEqual(a.tolist(), expected)
    self.assertEqual(kcount, expected_kcount)
  def test_setitem_permuted_sched(self): self.test_setitem_sched(lambda x: x.T, 2)
  def test_setitem_paddded_sched(self): self.test_setitem_sched(lambda x: x.shrink_to(4, 1).pad_to(4, 4), 1)
  def test_setitem_const_fused(self):
    # https://github.com/tinygrad/tinygrad/issues/10690
    # the setitem itself runs no kernel; realizing it afterwards runs exactly one
    a = Tensor.arange(16).clone().realize()
    GlobalCounters.reset()
    a[4] = 3
    self.assertEqual(GlobalCounters.kernel_count, 0)
    a.realize()
    self.assertEqual(GlobalCounters.kernel_count, 1)
    self.assertListEqual(a.tolist(), [0, 1, 2, 3, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
  def test_no_extra_contiguous_on_setitem_assign_back(self):
    # pattern: contiguous copy, advanced setitem, assign back (e.g. torch backend _view_write)
    # expects 4 SINK kernels in total
    base = Tensor.arange(16).reshape(4, 4).clone()
    flat_base = base.reshape(16).contiguous()
    idx = Tensor([1,2,5,6], dtype=dtypes.int32)
    flat_base[idx] = Tensor([99,99,99,99])
    base.assign(flat_base.reshape(4, 4))
    sched = check_schedule(base, 4)
    run_linear(*sched)
    expected = list(range(16))
    for i, v in zip([1,2,5,6], [99,99,99,99]): expected[i] = v
    np.testing.assert_equal(base.reshape(16).numpy(), expected)
  def test_const_folding_alt(self):
    # t is all ones, so lt is all False: a becomes all 0. and b all False
    t = Tensor.full((2,), 1.)
    lt = (t < 0.)
    a = Tensor.empty(2).assign(t*lt.where(-1., 0.))
    b = Tensor.empty(2, dtype=dtypes.bool).assign(lt)
    Tensor.realize(a, b)
    self.assertEqual(a.tolist(), [0., 0.])
    self.assertEqual(b.tolist(), [False, False])
  def test_self_assign_no_empty_kernel(self):
    # a.assign(a / 1) is expected to produce no calls at all (filter_sink=False counts every call)
    for shape in [(3, 3), (4, 4)]:
      a = Tensor.ones(*shape).contiguous().realize()
      a.assign(a / 1)
      run_linear(*check_schedule(a, 0, filter_sink=False))
      self.assertListEqual(a.tolist(), [[1.]*shape[1]]*shape[0])
  def test_deviceless_materialize_localizes_to_target(self):
    # arange over a bound symbolic length (s=64), cloned onto a device other than the default one
    dev = "CPU" if Device.DEFAULT != "CPU" else "CPU:1"
    t = Tensor.arange(Variable("s", 1, 128).bind(64)).cumsum().clone(dev)
    self.assertEqual(t.device, dev)
    np.testing.assert_equal(t[:64].numpy(), np.arange(64).cumsum())
class TestLimitBufs(unittest.TestCase):
  """Expressions that read from many buffers must still schedule and produce correct values."""
  @unittest.skipIf(DEV.interface.startswith("MOCK") and Device.DEFAULT == "NV", "crashes in ocelot")
  def test_limit_bufs_with_var(self):
    # 2*N symbolic-index loads spread over N distinct buffers of ten ones
    N = 31
    with Context(TRACK_MATCH_STATS=0, DEBUG=0):
      # the comprehension index is not needed; each buffer is an independent realized copy
      bufs = [Tensor([1]*10).contiguous().realize() for _ in range(N)]
    vi = Variable("i", 0, 9).bind(1)
    vj = Variable("j", 0, 9).bind(2)
    root = bufs[0][vi] + bufs[0][vj]
    for X in range(1,N): root = root + bufs[X][vi] + bufs[X][vj]
    # every loaded element is 1 and each buffer is read twice
    self.assertEqual(root.item(), N * 2)
  def test_limit_bufs_arange_condition(self):
    # WHERE with arange-based condition (pure index math, no device) and many buffer loads should not crash limit_bufs
    # the whole computation, including the final tolist() that realizes it, runs under MAX_KERNEL_BUFFERS=8
    with Context(MAX_KERNEL_BUFFERS=8):
      N = 8
      idx = Tensor.arange(N)
      base = Tensor.zeros(N)
      for i in range(4):
        a, b = Tensor.rand(N).realize(), Tensor.rand(N).realize()
        base = (idx >= i).where(a + b, base)
      # i=0 selects every position, so each element ends up as some a+b of two Tensor.rand draws
      assert all(x > 0 for x in base.tolist())
class TestSwizzle(unittest.TestCase):
  """Reduces combined with reshape/expand/permute/pad, checked for kernel count and value."""
  def test_swizzle_simple(self):
    Tensor.manual_seed(0)
    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
      a = Tensor.randint(32, 32).realize()
    r = (a+a).sum(1).sum(0)
    # double reduce collapses to a single reduce
    run_linear(*check_schedule(r, 1))
    self.assertEqual(r.numpy(), (a.numpy()+a.numpy()).sum(1).sum(0))
  def test_single_swizzle(self):
    Tensor.manual_seed(0)
    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
      a = Tensor.randint(4, 1).realize()
      b = Tensor.ones((1, 1), dtype=a.dtype).contiguous().realize()
    # ADD(REDUCE(RESHAPE(LOAD)), LOAD) to ADD(REDUCE(RESHAPE(LOAD))), RESHAPE(LOAD)
    r = a.sum(0)+b
    run_linear(*check_schedule(r, 1))
    self.assertEqual(r.numpy(), a.numpy().sum(0)+1)
  def test_double_swizzle_possible(self):
    Tensor.manual_seed(0)
    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
      a = Tensor.randint(4,).realize()
      b = Tensor.randint(4,).realize()
    # parallel reduce!
    add = a.sum(0)+b.sum(0)
    run_linear(*check_schedule(add, 1))
    self.assertEqual(add.numpy(), a.numpy().sum(0)+b.numpy().sum(0))
  def test_swizzle_reduceop(self):
    # reduce over an expanded view, then broadcast-add; currently expected to take 2 kernels
    Tensor.manual_seed(0)
    x = Tensor.randn(4,4).realize()
    y = Tensor.randn(4,4,4).realize()
    out = x.reshape(4,4,1).expand(4,4,4).sum(axis=(1,))+y
    run_linear(*check_schedule(out, 2)) # TODO: 1?
    np.testing.assert_allclose(out.numpy(), np.tile(x.numpy().reshape(4,4,1), (1,1,4)).sum(axis=1)+y.numpy())
  def test_permute_rewrite(self):
    # reduce output reshaped and permuted before the add; currently expected to take 2 kernels
    x = Tensor.randn(4, 4, 16).realize()
    y = Tensor.randn(4, 1, 16).realize()
    z = Tensor.randn(4, 4, 1).realize()
    t = (x*y).sum(axis=(0, 2)).reshape(1, 4, 1).permute(0, 2, 1)+z
    run_linear(*check_schedule(t, 2)) # TODO: 1?
    t_np = (x.numpy()*y.numpy()).sum(axis=(0, 2)).reshape(1, 4, 1).transpose(0, 2, 1)+z.numpy()
    np.testing.assert_allclose(t.numpy(), t_np, atol=1e-6, rtol=1e-3)
  @unittest.skip("TODO: this swizzle isn't resolvable when there's a mask")
  def test_swizzle_failure_permute(self):
    a = Tensor.empty(45,65).T.reshape(65,1,45).pad((None,None,(0,45))).expand(65,45,90)
    b = Tensor.empty(45,65)
    a_reduce = a.sum(axis=(2,), keepdim=True).sum(axis=(1,))
    b_reduce = b.sum(axis=(0,))
    t = a_reduce+b_reduce
    run_linear(*check_schedule(t, 1))
  def test_parallel_reduce_possible(self):
    # two same-shape reduces added together are expected to fuse into one kernel
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 2, 2).realize()
    y = Tensor.randn(4, 2, 2).realize()
    t = x.sum(axis=1)+y.sum(axis=1)
    run_linear(*check_schedule(t, 1))
    np.testing.assert_allclose(t.numpy(), x.numpy().sum(axis=1)+y.numpy().sum(axis=1), atol=1e-6, rtol=1e-3)
  # kernels can only have 1 or n in each dim
  # NOTE(review): despite the test name, exactly one kernel is asserted here
  def test_dont_parallelize_different_n(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 2, 2).realize()
    y = Tensor.randn(4, 3, 2).realize()
    t = x.sum(axis=1)+y.sum(axis=1)
    run_linear(*check_schedule(t, 1))
    np.testing.assert_allclose(t.numpy(), x.numpy().sum(axis=1)+y.numpy().sum(axis=1), atol=1e-6, rtol=1e-3)
  def test_unsafe_pad(self):
    # pad applied after x * reciprocal(row sum); expects 3 kernels and a zero padded row
    x = Tensor.full((2,2), 1.0).contiguous()
    y = x*x.sum((1,)).reciprocal()
    t = y.pad(((0,1),None))
    run_linear(*check_schedule(t, 3))
    np.testing.assert_equal(t.numpy(), [[0.5, 0.5], [0.5, 0.5], [0., 0.]])
# UPat matching a CONST uop whose arg is 0; not referenced in the visible part of this file
zero_pm = UPat(Ops.CONST, arg=0)
class TestView(unittest.TestCase):
  """Padded/shrunk views that are fully, partially, or not masked out, and how they schedule."""
  def test_all_masked_out(self):
    # start with non CONST Ops
    a = Tensor.rand(10, 10).realize()
    # all masked out, degrades to const 0
    b = a.pad(((0, 10), None))[10:]
    sched = check_schedule(b.contiguous(), 1)
    run_linear(*sched)
    np.testing.assert_equal(b.numpy(), 0)
  def test_mask_dim_1(self):
    # mask out dim = 1 works too
    a = Tensor.rand(10, 10).realize()
    b = a.pad((None, (0, 10)))[:, 10:]
    assert b.shape == (10, 10)
    sched = check_schedule(b.contiguous(), 1)
    run_linear(*sched)
    np.testing.assert_equal(b.numpy(), 0)
  def test_partial_mask(self):
    # partial masked out does not degrade into CONST
    a = Tensor.rand(10, 10).realize()
    b = a.pad(((0, 5), None))[5:]
    assert b.shape == (10, 10)
    sched = check_schedule(b.contiguous(), 1)
    run_linear(*sched)
    np.testing.assert_allclose(b.numpy(), np.pad(a.numpy(), ((0, 5), (0, 0)))[5:])
  # a*VIEW(x), where VIEW(x) = 0
  # x collapses along with its children
  def test_parent_view_collapses(self):
    a = Tensor([1, 2])
    b = Tensor.arange(3).clone()
    # b has 3 elements; after padding to 5, the last two are both padding
    bv = b.pad(((0, 2),))[-2:]
    # this becomes a late a*0
    late_mul = a*bv
    run_linear(*check_schedule(late_mul, 2))
    # NOTE(review): the two commented-out assertions below are not currently checked
    # the arange doesn't realize
    #self.assertIsNone(b.uop.base.realized)
    # mul doesn't realize
    #self.assertIsNone(late_mul.uop.base.realized)
    self.assertEqual(late_mul.tolist(), [0, 0])
  # SINK has two branches:
  # a*VIEW(x), where VIEW(x) = 0
  # x+2
  # as long as one child realizes, x does not collapse
  def test_parent_multiple_children_no_collapse(self):
    a = Tensor([1, 2])
    b = Tensor.arange(3).clone()
    bv = b.pad(((0, 2),))[-2:]
    late_mul = a*bv
    other_child = b+2
    s = check_schedule([late_mul, other_child], 3)
    # the arange becomes a BUFFER
    self.assertIs(b.uop.base.op, Ops.BUFFER)
    # NOTE: no longer checked
    # mul still collapses
    #self.assertIs(late_mul.uop.base.op, Ops.CONST)
    run_linear(*s)
    self.assertEqual(other_child.tolist(), [2, 3, 4])
@unittest.skipIf(Device.DEFAULT == "CPU", "tests copy from another device to cpu")
class TestCopyFolding(unittest.TestCase):
  """COPY scheduling: copies to CPU, copies to the same device, and copies from disk, checked for call counts and values."""
  def test_const_copy_is_free(self):
    # copying a const and scaling it is expected to need no calls at all (filter_sink=False counts every call)
    b = Tensor(1).to("CPU") * 4
    run_linear(*check_schedule(b, 0, filter_sink=False))
    assert b.item() == 4
  def test_one_hot_with_copy(self):
    # only the call count is checked; nothing is run
    y = Tensor([1, 2, 3]).to("CPU")
    x = y.one_hot(10)
    check_schedule(x, 3, filter_sink=False)
  def test_late_const_copy_folding(self):
    # a*zeros is all zeros before the copy; the whole expression is expected to take one call
    a = Tensor.arange(3).clone().realize()
    zeros = Tensor.zeros(3, buffer=False).realize()
    b = (a*zeros).to("CPU") + 1
    run_linear(*check_schedule(b, 1, filter_sink=False))
    self.assertListEqual(b.tolist(), [1, 1, 1])
    self.assertEqual(b.device, "CPU")
  def test_alu_after_copy(self):
    a = Tensor.ones((4,)).to("CPU")
    b = Tensor.empty(4, device="CPU")
    add = a+b
    # the failure message lists the same sources the check inspects (add.uop.src), so a device mismatch
    # is reported instead of the message formatting itself failing on a different attribute
    assert all_same([x.device for x in add.uop.src]), f"ALU has different devices! {[x.device for x in add.uop.src]}"
    add.schedule_linear()
  def test_alu_before_copy(self):
    buf = Tensor.ones(1).contiguous().realize()
    a = buf+1
    b = a.to("CPU")
    self.assertListEqual(b.tolist(), [2.])
  def test_copy_to_same_device(self):
    a = Tensor.empty(4).uop
    b = a.copy_to_device(a.device)
    check_schedule(b, 1, filter_sink=False) # TODO: 0?
  def test_copy_to_same_device_alt(self):
    a = Tensor.empty(4, 4).uop
    b = a.copy_to_device(a.device)
    check_schedule(b, 1, filter_sink=False) # TODO: 0?
  def test_copy_to_same_device_sched(self):
    # a same-device copy of an already-realized buffer must not show up as a COPY call in the linear
    a = Tensor.ones(4).contiguous().realize().uop.buf_uop
    t = Tensor(a.copy_to_device(a.device))
    linear, var_vals = t.linear_with_vars()
    assert len([call for call in linear.src if call.src[0].op is Ops.COPY]) == 0
    run_linear(linear, var_vals)
    assert t.uop.is_realized, f"didn't realize Tensor {t}"
    self.assertListEqual(t.tolist(), [1.,1.,1.,1.])
  def test_self_assign_same_device_copy(self):
    a = Tensor.ones(4, 4).contiguous().realize()
    # use copy_to_device to bypass Tensor.to() shortcircuit and force a real same-device COPY in the graph
    a.assign(Tensor(a.uop.copy_to_device(a.device), a.device))
    run_linear(*check_schedule(a, 2, filter_sink=False))
    self.assertListEqual(a.tolist(), [[1.]*4]*4)
  def test_clone(self):
    a = Tensor.empty(4)
    check_schedule(a.clone(), 1, filter_sink=False)
  def test_shrink_copy(self):
    # cloning a shrunk view allocates a buffer holding only the 2 visible elements
    a = Tensor.arange(4)
    view = a.shrink(((0, 2),))
    b = view.clone()
    run_linear(*check_schedule(b, 1, filter_sink=False))
    self.assertEqual(b.uop.base.buffer.size, 2)
    self.assertEqual(b.uop.numel(), 2)
    self.assertListEqual(b.tolist(), [0, 1])
  def test_expanded_copy(self):
    # cloning an expanded view materializes all 4 elements
    a = Tensor.arange(2)
    view = a.reshape(2, 1).expand(2, 2)
    b = view.clone()
    run_linear(*check_schedule(b, 1, filter_sink=False))
    self.assertEqual(b.uop.base.buffer.size, 4)
    self.assertEqual(b.uop.numel(), 4)
    self.assertListEqual(b.tolist(), [[0, 0], [1, 1]])
  def test_permuted_copy(self):
    a = Tensor.arange(4)
    b = a.reshape(2, 2).permute(1, 0)
    b.realize()
    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
  def test_permute_on_disk(self):
    # write arange(4) to a temp file, then read it back through a disk-backed tensor and permute before copying to CPU
    with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).clone().realize().uop.base.buffer.as_memoryview())
    a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
    b = a.reshape(2, 2).permute(1, 0).to("CPU")
    b.realize()
    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
  def test_permute_on_disk_contiguous(self):
    # same as test_permute_on_disk, with an explicit .contiguous() before the copy
    with open(temp('dt_arange_4_permute_contig'), "wb") as f: f.write(Tensor.arange(4).clone().realize().uop.base.buffer.as_memoryview())
    a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute_contig')}")
    b = a.reshape(2, 2).permute(1, 0).contiguous().to("CPU")
    b.realize()
    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
  def test_permute_after_shrink(self):
    a = Tensor.arange(5)
    b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CPU")
    b.realize()
    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
  # NOTE: disk permute must come after COPY
  def test_permute_after_shrink_on_disk(self):
    with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).clone().realize().uop.base.buffer.as_memoryview())
    a = Tensor.empty(5, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_5_permute')}")
    b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CPU")
    b.realize()
    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
class TestUOpBecome(unittest.TestCase):
  # assigning an offset, strided view of one buffer into a slice of another must write the right values
  def test_setitem_offset(self):
    src = Tensor.full((16,), 0.).contiguous().realize()
    dst = Tensor.full((16,), 1.).contiguous().realize()
    # elements 4,5,8,9 of src: a 2x2 window of src[4:] viewed as 3x4, flattened back to 4 elements
    src_window = src[4:].reshape(3, 4).shrink(((0,2),(0,2))).reshape((4,))
    dst.shrink(((0,4),)).assign(src_window).realize()
    expected = [0.0]*4 + [1.0]*12
    self.assertListEqual(dst.tolist(), expected)
class TestFusionOp(unittest.TestCase):
  def test_contiguous_add(self):
    # permute(1,0) applied twice is the identity, so an intermediate .contiguous() must not change the result
    def run(contig=False):
      base = Tensor(np.arange(16), dtype=dtypes.float32).reshape(4,4)
      transposed = base.permute(1,0)
      if contig: transposed = transposed.contiguous()
      return (transposed.permute(1,0) + base).data()
    plain, with_contig = run(), run(True)
    assert plain == with_contig
  def test_expand_fuse(self):
    # a (10,1) column of ones, doubled and expanded to (10,10), sums to 20 in every row
    col = Tensor(np.ones((10, 1)), dtype=dtypes.float32)
    out = (col*2).expand(10,10).sum(1)
    run_linear(*out.linear_with_vars())
    assert all(v == 20.0 for v in out.tolist())
if __name__ == '__main__':
  # run this module's tests directly, printing each test name and outcome
  unittest.main(verbosity=2)