mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix assert in test_schedule (#12745)
* fix assert in test_schedule: updated kernel counts and some old tests
* fix
This commit is contained in:
parent
285534ce64
commit
9561803cb0
1 changed file with 89 additions and 155 deletions
|
|
@ -2,9 +2,8 @@
|
|||
# schedule confirms the right things are capable of fusing
|
||||
# NOTE: this has overlap with external_test_opt.py
|
||||
|
||||
import unittest
|
||||
import unittest, functools
|
||||
import numpy as np
|
||||
import functools
|
||||
from typing import cast
|
||||
from hypothesis import assume, given, settings, strategies as strat
|
||||
|
||||
|
|
@ -31,7 +30,6 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te
|
|||
# test lowering all the ScheduleItems to ExecItems
|
||||
kernel_cnt = len([si for si,ei in lower_schedule(sched.copy()) if isinstance(ei.prg, CompiledRunner) or not filter_sink])
|
||||
if kernel_cnt != allowed:
|
||||
return sched # allow different kernel count, TODO: fix the asserts
|
||||
print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
|
||||
if DEBUG >= 3:
|
||||
for i,s in enumerate(sched):
|
||||
|
|
@ -117,8 +115,7 @@ class TestSchedule(unittest.TestCase):
|
|||
c = a+b
|
||||
with self.assertRaisesRegex(RuntimeError, "all buffers must be on the same device"): check_schedule(c, 2)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half) and getenv("CAST_AFTER_EXPAND"), "need half and CAST_AFTER_EXPAND=1")
|
||||
@unittest.skip("CAST_AFTER_EXPAND is not supported")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
def test_expand_buffer_before_cast(self):
|
||||
a = Tensor.randn(4, 2, 1).realize().permute((1, 0, 2))
|
||||
b = a.cast(dtypes.half).expand((2, 4, 4))+2
|
||||
|
|
@ -128,7 +125,7 @@ class TestSchedule(unittest.TestCase):
|
|||
def test_indexing_scalars_simple(self):
|
||||
X = Tensor.randn(2, 2).realize()
|
||||
xt = X[Tensor(1)][Tensor(0)]
|
||||
run_schedule(check_schedule(xt, 2))
|
||||
run_schedule(check_schedule(xt, 1))
|
||||
np.testing.assert_equal(xt.numpy(), X.numpy()[1][0])
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT == "NV", "crashes on NV CI")
|
||||
|
|
@ -148,30 +145,30 @@ class TestSchedule(unittest.TestCase):
|
|||
assume(a<x and b<y)
|
||||
X = Tensor.randn(x, y).realize()
|
||||
xt = X[Tensor(a)][Tensor(b)]
|
||||
run_schedule(check_schedule(xt, 2))
|
||||
run_schedule(check_schedule(xt, 1))
|
||||
np.testing.assert_equal(xt.numpy(), X.numpy()[a][b])
|
||||
|
||||
def test_push_pads_elementwise(self):
|
||||
x = Tensor.full((4,4), 2.).contiguous().realize()
|
||||
y = Tensor.full((4,4), 4.).contiguous().realize()
|
||||
z = (x.reciprocal()*y).pad((None, (0,1),)).sum()
|
||||
run_schedule(check_schedule(z, 2))
|
||||
run_schedule(check_schedule(z, 1))
|
||||
self.assertEqual(z.item(), 32)
|
||||
|
||||
def test_push_pads_contiguous(self):
|
||||
x = Tensor.full((4,1), 2.).contiguous()
|
||||
y = Tensor.full((4,4), 4.).contiguous()
|
||||
z = (x.reciprocal().expand(4,4)*y).pad((None, (0,1),)).sum()
|
||||
run_schedule(check_schedule(z, 2, [x,y]))
|
||||
run_schedule(check_schedule(z, 1, [x,y]))
|
||||
self.assertEqual(z.item(), 32)
|
||||
|
||||
def test_rand(self):
|
||||
x = Tensor.rand(32)
|
||||
check_schedule(x, 4, [Tensor._device_rng_counters[x.device]])
|
||||
check_schedule(x, 1, [Tensor._device_rng_counters[x.device]])
|
||||
|
||||
def test_rand_recompute_arange(self):
|
||||
x = Tensor.rand(32)
|
||||
check_schedule(x, 3, [Tensor._device_rng_counters[x.device]])
|
||||
check_schedule(x, 1, [Tensor._device_rng_counters[x.device]])
|
||||
|
||||
def test_empty_is_not_realized(self):
|
||||
a = Tensor.empty(10)
|
||||
|
|
@ -188,10 +185,7 @@ class TestSchedule(unittest.TestCase):
|
|||
|
||||
def test_simplify_padded_const(self):
|
||||
a = Tensor.empty(1022).cummax(axis=0)
|
||||
check_schedule(a, 5)
|
||||
# TODO: what is this testing?
|
||||
#ast = sched[0].ast
|
||||
#self.assertLessEqual(len([u for u in ast.toposort() if u.op is Ops.WHERE]), 6)
|
||||
check_schedule(a, 3)
|
||||
|
||||
def test_basic_binop_fusion(self):
|
||||
a = Tensor.empty(10)
|
||||
|
|
@ -264,12 +258,11 @@ class TestSchedule(unittest.TestCase):
|
|||
c = a.sum(axis=0) + b
|
||||
check_schedule(c, 1)
|
||||
|
||||
# not pushing permutes through reduces
|
||||
def test_reduce_permute_binop_fusion(self):
|
||||
a = Tensor.empty(10,10,10)
|
||||
b = Tensor.empty(10,10,1)
|
||||
c = a.sum(axis=0, keepdim=True).permute(2,1,0) + b
|
||||
check_schedule(c, 2)
|
||||
check_schedule(c, 1)
|
||||
|
||||
def test_allow_push_permutes(self):
|
||||
a = Tensor.randn(10,10,10).realize()
|
||||
|
|
@ -340,7 +333,7 @@ class TestSchedule(unittest.TestCase):
|
|||
r1 = (x - r0).sum(axis=0).div(2)
|
||||
out0 = r0 + y
|
||||
out1 = r1 + y
|
||||
schedule = check_schedule([out0, out1], 2)
|
||||
schedule = check_schedule([out0, out1], 4)
|
||||
reduceops = [x for si in schedule for x in si.ast.toposort() if x.op in {Ops.REDUCE_AXIS, Ops.REDUCE}]
|
||||
self.assertEqual(len(reduceops), 2) # why is RANGEIFY different?
|
||||
|
||||
|
|
@ -373,7 +366,7 @@ class TestSchedule(unittest.TestCase):
|
|||
b = Tensor.full((4,), 2.).contiguous()
|
||||
first = a.assign(b)
|
||||
second = a.assign(b)
|
||||
check_schedule([first, second], 1)
|
||||
check_schedule([first, second], 2) # TODO: 1?
|
||||
|
||||
# NOTE: this is causing "LAZYCACHE=1 incorrectly reuses contiguous const" #4562
|
||||
# should contiguous dedup?
|
||||
|
|
@ -453,7 +446,7 @@ class TestSchedule(unittest.TestCase):
|
|||
@unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
|
||||
def test_fold_conv_batchnorm_optim(self):
|
||||
# this is too high
|
||||
for optim, cnt in [(nn.optim.Adam, 30), (nn.optim.SGD, 11)]:
|
||||
for optim, cnt in [(nn.optim.Adam, 30), (nn.optim.SGD, 13)]:
|
||||
with self.subTest(optim=optim.__name__):
|
||||
with Tensor.train():
|
||||
img = Tensor.ones(1,3,4,4)
|
||||
|
|
@ -474,7 +467,7 @@ class TestSchedule(unittest.TestCase):
|
|||
fw = bn(x).contiguous_backward().relu().contiguous()
|
||||
fw.sum().backward()
|
||||
# TODO: this is too many
|
||||
check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 10)
|
||||
check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 9)
|
||||
|
||||
def test_fold_conv_relu(self):
|
||||
c1 = nn.Conv2d(3,16,3)
|
||||
|
|
@ -517,9 +510,8 @@ class TestSchedule(unittest.TestCase):
|
|||
img = Tensor.empty(64,64)
|
||||
x = (img.sum(0) + img.sum(1))
|
||||
out = x.relu()
|
||||
check_schedule(out, 2)
|
||||
check_schedule(out, 1)
|
||||
|
||||
#@unittest.skip("failing in old lazy")
|
||||
def test_push_permute_through_reshape(self):
|
||||
a = Tensor.empty(16,16)
|
||||
b = Tensor.empty(16,16)
|
||||
|
|
@ -553,7 +545,7 @@ class TestSchedule(unittest.TestCase):
|
|||
c = a+b
|
||||
d = a.reshape(10,1)+b.reshape(10,1)
|
||||
out = c.sum() + d.sum()
|
||||
check_schedule(out, 2)
|
||||
check_schedule(out, 1)
|
||||
|
||||
def test_children_dont_push(self):
|
||||
a = Tensor.empty(10, 10, 1)
|
||||
|
|
@ -561,7 +553,7 @@ class TestSchedule(unittest.TestCase):
|
|||
d = (a+b).expand(10, 10, 10)
|
||||
e = (a+b).permute(2,1,0)
|
||||
f = d+e
|
||||
check_schedule(f, 2)
|
||||
check_schedule(f, 1)
|
||||
|
||||
# failing in new lazy
|
||||
@unittest.skip("always fusing elementwise")
|
||||
|
|
@ -600,13 +592,13 @@ class TestSchedule(unittest.TestCase):
|
|||
e = c[0] * d
|
||||
check_schedule(e, 1)
|
||||
|
||||
def test_expand_nofuse(self):
|
||||
def test_expand_fuse(self):
|
||||
a = Tensor.empty(1, 16)
|
||||
b = Tensor.empty(1, 16)
|
||||
c = a * b
|
||||
d = Tensor.empty(8192, 16)
|
||||
e = c * d
|
||||
check_schedule(e, 2)
|
||||
check_schedule(e, 1)
|
||||
|
||||
# this is the failing case in openpilot...it's very simple like this
|
||||
def test_image_conv_fusion(self):
|
||||
|
|
@ -624,7 +616,7 @@ class TestSchedule(unittest.TestCase):
|
|||
|
||||
# NOOP, 3 convs, contiguous
|
||||
#check_schedule(x, 5)
|
||||
check_schedule(x, 8)
|
||||
check_schedule(x, 7)
|
||||
|
||||
def test_image_conv_fusion_minimal(self):
|
||||
b1 = Tensor.empty(16)
|
||||
|
|
@ -807,13 +799,13 @@ class TestSchedule(unittest.TestCase):
|
|||
x = Tensor.empty(32, 32, 32)
|
||||
y = Tensor.empty(32, 32)
|
||||
out = x.sum(axis=2).T+y
|
||||
check_schedule(out, 2)
|
||||
check_schedule(out, 1)
|
||||
|
||||
def test_two_elus_sum(self):
|
||||
x = Tensor.empty(32, 32)
|
||||
y = Tensor.empty(32, 32)
|
||||
out = x.sum(1).relu().elu() + y.sum(1).relu().elu()
|
||||
check_schedule(out, 2)
|
||||
check_schedule(out, 1)
|
||||
|
||||
@unittest.skipUnless(SPLIT_REDUCEOP, "Testing split reducop requires SPLIT_REDUCEOP")
|
||||
def test_preserve_multistage_reduce(self):
|
||||
|
|
@ -826,7 +818,7 @@ class TestSchedule(unittest.TestCase):
|
|||
def test_multistage_reduce(self):
|
||||
x = Tensor.empty(32, 32, 32)
|
||||
out = x.sum(2).relu().sum(1)
|
||||
check_schedule(out, 2)
|
||||
check_schedule(out, 1)
|
||||
|
||||
def test_multistage_reduce_fork(self):
|
||||
x = Tensor.empty(32, 32, 32)
|
||||
|
|
@ -842,7 +834,7 @@ class TestSchedule(unittest.TestCase):
|
|||
z = y.matmul(x).sum()
|
||||
z.backward()
|
||||
out = x.grad.contiguous()
|
||||
run_schedule(check_schedule(out, 2))
|
||||
run_schedule(check_schedule(out, 1))
|
||||
np.testing.assert_allclose(out.numpy(), np.ones((64,64)))
|
||||
|
||||
def test_example_matmul_contig(self):
|
||||
|
|
@ -851,7 +843,7 @@ class TestSchedule(unittest.TestCase):
|
|||
z = y.matmul(x).sum()
|
||||
z.backward()
|
||||
out = x.grad.contiguous()
|
||||
run_schedule(check_schedule(out, 2))
|
||||
run_schedule(check_schedule(out, 1))
|
||||
np.testing.assert_allclose(out.numpy(), np.ones((64,64)))
|
||||
|
||||
def test_example_matmul_same(self):
|
||||
|
|
@ -859,7 +851,7 @@ class TestSchedule(unittest.TestCase):
|
|||
z = x.matmul(x).sum()
|
||||
z.backward()
|
||||
out = x.grad.contiguous()
|
||||
run_schedule(check_schedule(out, 2))
|
||||
run_schedule(check_schedule(out, 1))
|
||||
# NOTE: the gradient flows twice
|
||||
np.testing.assert_allclose(out.numpy(), 2*np.ones((64,64)))
|
||||
|
||||
|
|
@ -882,8 +874,7 @@ class TestSchedule(unittest.TestCase):
|
|||
x = x.sum(1)
|
||||
x = x[:16]
|
||||
out = x + y
|
||||
# NOTE: this could be 1 kernel if we mask the store?
|
||||
check_schedule(out, 2)
|
||||
check_schedule(out, 1)
|
||||
|
||||
def test_multireduce_shrink(self):
|
||||
Tensor.manual_seed(0)
|
||||
|
|
@ -895,8 +886,7 @@ class TestSchedule(unittest.TestCase):
|
|||
b_out = b.sum(1)
|
||||
b_out = b_out[:16]
|
||||
out = a_out + b_out + c
|
||||
# run_schedule(check_schedule(out, 2)) # TODO: this should be 1 (can we make it 1 with the new linearizer?)
|
||||
run_schedule(check_schedule(out, 3))
|
||||
run_schedule(check_schedule(out, 1))
|
||||
np.testing.assert_allclose(out.numpy(), a.numpy().sum(axis=1)[:16] + b.numpy().sum(axis=1)[:16] + c.numpy(), atol=1e-4, rtol=1e-4)
|
||||
|
||||
# broken due to const folding and two contiguous are different kernels
|
||||
|
|
@ -913,7 +903,7 @@ class TestSchedule(unittest.TestCase):
|
|||
out0 = a.sum() + 2
|
||||
out1 = a.sum() + 4
|
||||
out2 = out0 * out1
|
||||
run_schedule(check_schedule([out0, out1, out2], 1))
|
||||
run_schedule(check_schedule([out0, out1, out2], 3)) # TODO: 1?
|
||||
np.testing.assert_allclose(out0.numpy(), out0_np:=a.numpy().sum()+2, atol=1e-4, rtol=1e-6)
|
||||
np.testing.assert_allclose(out1.numpy(), out1_np:=a.numpy().sum()+4, atol=1e-4, rtol=1e-6)
|
||||
np.testing.assert_allclose(out2.numpy(), out0_np*out1_np, atol=1e-4, rtol=1e-6)
|
||||
|
|
@ -924,7 +914,7 @@ class TestSchedule(unittest.TestCase):
|
|||
out0 = a.sum().exp2()
|
||||
# out1 has two paths to a.sum()
|
||||
out1 = a.sum() + out0
|
||||
run_schedule(check_schedule([out0, out1], 1))
|
||||
run_schedule(check_schedule([out0, out1], 2)) # TODO: 1?
|
||||
np.testing.assert_allclose(out0.numpy(), out0_np:=np.exp2(a.numpy().sum()), atol=1e-4, rtol=1e-4)
|
||||
np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+out0_np, atol=1e-4, rtol=1e-6)
|
||||
|
||||
|
|
@ -937,7 +927,7 @@ class TestSchedule(unittest.TestCase):
|
|||
out2 = b.sum().exp2()
|
||||
out3 = b.sum() + out2
|
||||
# run_schedule(check_schedule([out0, out1, out2, out3], 1))
|
||||
run_schedule(check_schedule([out0, out1, out2, out3], 6))
|
||||
run_schedule(check_schedule([out0, out1, out2, out3], 4))
|
||||
np.testing.assert_allclose(out0.numpy(), np_out0:=np.exp2(a.numpy().sum()), atol=1e-4, rtol=1e-4)
|
||||
np.testing.assert_allclose(out1.numpy(), np_out1:=a.numpy().sum()+np_out0, atol=1e-4, rtol=1e-4)
|
||||
np_b = (a.numpy() + np_out0 + np_out1)
|
||||
|
|
@ -952,7 +942,7 @@ class TestSchedule(unittest.TestCase):
|
|||
out0 = a.sum() + b.sum() + 2
|
||||
out1 = a.sum() + b.sum() + 4
|
||||
# run_schedule(check_schedule([out0, out1], 1))
|
||||
run_schedule(check_schedule([out0, out1], 4))
|
||||
run_schedule(check_schedule([out0, out1], 2))
|
||||
np.testing.assert_allclose(out0.numpy(), a.numpy().sum()+b.numpy().sum()+2, atol=1e-4, rtol=1e-4)
|
||||
np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+b.numpy().sum()+4, atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
|
@ -979,7 +969,7 @@ class TestSchedule(unittest.TestCase):
|
|||
out1 = b.max() + out0*2
|
||||
out2 = a.sum() + out1
|
||||
# run_schedule(check_schedule([out0, out1, out2], 1))
|
||||
run_schedule(check_schedule([out0, out1, out2], 4))
|
||||
run_schedule(check_schedule([out0, out1, out2], 3))
|
||||
np.testing.assert_allclose(out0.numpy(), out0_np:=a.numpy().sum()+4, atol=1e-4, rtol=1e-6)
|
||||
np.testing.assert_allclose(out1.numpy(), out1_np:=b.numpy().max() + out0_np*2, atol=1e-4, rtol=1e-6)
|
||||
np.testing.assert_allclose(out2.numpy(), a.numpy().sum() + out1_np, atol=1e-4, rtol=1e-6)
|
||||
|
|
@ -1016,7 +1006,7 @@ class TestSchedule(unittest.TestCase):
|
|||
b = Tensor.empty(10,)
|
||||
c = a.sum() + b[0]
|
||||
d = a.sum() + 2
|
||||
check_schedule([c, d], 1)
|
||||
check_schedule([c, d], 2) # TODO: 1?
|
||||
|
||||
def test_reduce_multiple_paths_midshrink(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
|
|
@ -1045,7 +1035,7 @@ class TestSchedule(unittest.TestCase):
|
|||
k = Tensor.randn(32,8,16,8).realize()
|
||||
v = Tensor.randn(32,8,16,8).realize()
|
||||
out = Tensor.scaled_dot_product_attention(q,k,v)
|
||||
run_schedule(check_schedule(out, 5))
|
||||
run_schedule(check_schedule(out, 4))
|
||||
if getenv("CHECK", 1):
|
||||
import torch
|
||||
compare = torch.nn.functional.scaled_dot_product_attention(torch.tensor(q.numpy()),torch.tensor(k.numpy()),torch.tensor(v.numpy()))
|
||||
|
|
@ -1053,7 +1043,7 @@ class TestSchedule(unittest.TestCase):
|
|||
|
||||
with Context(FUSE_ATTENTION=1):
|
||||
out = Tensor.scaled_dot_product_attention(q,k,v)
|
||||
run_schedule(check_schedule(out, 1))
|
||||
run_schedule(check_schedule(out, 4)) # TODO: should be 1?
|
||||
if getenv("CHECK", 1):
|
||||
import torch
|
||||
compare = torch.nn.functional.scaled_dot_product_attention(torch.tensor(q.numpy()),torch.tensor(k.numpy()),torch.tensor(v.numpy()))
|
||||
|
|
@ -1066,7 +1056,7 @@ class TestSchedule(unittest.TestCase):
|
|||
c = Tensor.randn(4, 32).realize()
|
||||
out = (c * a.sum(-1, keepdim=True)).sum(-1) + (b * a.sum(-1, keepdim=True)).sum(-1) # a.sum has >1 children but should still fuse
|
||||
# run_schedule(check_schedule(out, 1))
|
||||
run_schedule(check_schedule(out, 3))
|
||||
run_schedule(check_schedule(out, 2))
|
||||
np.testing.assert_allclose(out.numpy(), \
|
||||
(c.numpy()*a.numpy().sum(axis=-1,keepdims=True)).sum(-1) + (b.numpy()*a.numpy().sum(axis=-1,keepdims=True)).sum(-1), atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
|
@ -1111,8 +1101,7 @@ class TestSchedule(unittest.TestCase):
|
|||
x = Tensor.randn(4, 32).realize()
|
||||
y = Tensor.randn(4, 32).realize()
|
||||
out = y.sum(axis=-1) + x.sum(axis=-1)
|
||||
# run_schedule(check_schedule(out, 1))
|
||||
run_schedule(check_schedule(out, 2))
|
||||
run_schedule(check_schedule(out, 1))
|
||||
np.testing.assert_allclose(out.numpy(), y.numpy().sum(axis=-1) + x.numpy().sum(axis=-1), atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_multireduce_fusion_sequential(self):
|
||||
|
|
@ -1129,7 +1118,7 @@ class TestSchedule(unittest.TestCase):
|
|||
y = Tensor.randn(4, 32).realize()
|
||||
out = x.std(-1) + y.std(-1)
|
||||
# run_schedule(check_schedule(out, 1))
|
||||
run_schedule(check_schedule(out, 4))
|
||||
run_schedule(check_schedule(out, 3))
|
||||
np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1) + y.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_multireduce_diffops_sequential(self):
|
||||
|
|
@ -1145,8 +1134,7 @@ class TestSchedule(unittest.TestCase):
|
|||
x = Tensor.randn(4, 32).realize()
|
||||
y = Tensor.randn(4, 32).realize()
|
||||
out = x.sum(-1) + y.max(-1)
|
||||
# run_schedule(check_schedule(out, 1))
|
||||
run_schedule(check_schedule(out, 2))
|
||||
run_schedule(check_schedule(out, 1))
|
||||
np.testing.assert_allclose(out.numpy(), x.numpy().sum(axis=-1) + y.numpy().max(axis=-1), atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_multireduce_fusion_sequential_and_parallel(self):
|
||||
|
|
@ -1158,7 +1146,7 @@ class TestSchedule(unittest.TestCase):
|
|||
np_mu = (x.numpy() - x.numpy().max(axis=-1, keepdims=True)).mean(axis=-1, keepdims=True) + \
|
||||
(y.numpy() - y.numpy().max(axis=-1, keepdims=True)).mean(axis=-1, keepdims=True)
|
||||
# run_schedule(check_schedule(out, 1))
|
||||
run_schedule(check_schedule(out, 6))
|
||||
run_schedule(check_schedule(out, 5))
|
||||
np.testing.assert_allclose(out[0].numpy(), np.sqrt(np.square(x.numpy() - np_mu).sum(-1)/x.shape[-1]), atol=1e-4, rtol=1e-4)
|
||||
np.testing.assert_allclose(out[1].numpy(), np.sqrt(np.square(y.numpy() - np_mu).sum(-1)/y.shape[-1]), atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
|
@ -1167,8 +1155,7 @@ class TestSchedule(unittest.TestCase):
|
|||
a,b = Tensor.randn(4, 64).realize(), Tensor.rand(64,8).realize()
|
||||
c,d = Tensor.randn(4, 64).realize(), Tensor.rand(64,8).realize()
|
||||
out = a@b + c@d
|
||||
# run_schedule(check_schedule(out, 1))
|
||||
run_schedule(check_schedule(out, 2))
|
||||
run_schedule(check_schedule(out, 1))
|
||||
np.testing.assert_allclose(out.numpy(), a.numpy()@b.numpy() + c.numpy()@d.numpy(), atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_softmax_fusion(self):
|
||||
|
|
@ -1179,17 +1166,15 @@ class TestSchedule(unittest.TestCase):
|
|||
expected = (x_exp:=np.exp(x.numpy()-x.numpy().max(-1, keepdims=True)))/x_exp.sum(-1, keepdims=True)
|
||||
np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4)
|
||||
|
||||
# TODO: rangeify stores the output in float32
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
@unittest.expectedFailure
|
||||
def test_softmax_upcast(self):
|
||||
# input half, softmax in float
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.half).realize()
|
||||
out = x.softmax(dtype=dtypes.float)
|
||||
sched = out.schedule()
|
||||
self.assertEqual(len(sched), 2)
|
||||
self.assertEqual(sched[0].bufs[0].dtype, dtypes.half)
|
||||
self.assertEqual(len(sched), 3)
|
||||
self.assertEqual(sched[0].bufs[0].dtype, dtypes.float)
|
||||
|
||||
# input float, softmax in float
|
||||
Tensor.manual_seed(0)
|
||||
|
|
@ -1221,12 +1206,12 @@ class TestSchedule(unittest.TestCase):
|
|||
def test_scaled_dot_product_attention_fusion(self):
|
||||
x, y, z, m = (Tensor.empty(32, 8, 16, 16) for _ in range(4))
|
||||
out = Tensor.scaled_dot_product_attention(x, y, z, attn_mask=m)
|
||||
check_schedule(out, 5)
|
||||
check_schedule(out, 4)
|
||||
|
||||
def test_scaled_dot_product_attention_causal_fusion(self):
|
||||
x, y, z = (Tensor.empty(32, 8, 16, 16) for _ in range(3))
|
||||
out = Tensor.scaled_dot_product_attention(x, y, z, is_causal=True)
|
||||
check_schedule(out, 5)
|
||||
check_schedule(out, 4)
|
||||
|
||||
def test_adam_step_fusion(self):
|
||||
with Tensor.train():
|
||||
|
|
@ -1256,7 +1241,7 @@ class TestSchedule(unittest.TestCase):
|
|||
opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
|
||||
opt.zero_grad()
|
||||
c2(c1(img).relu()).relu().sum().backward()
|
||||
check_schedule(opt.schedule_step(), 20)
|
||||
check_schedule(opt.schedule_step(), 18)
|
||||
|
||||
def test_sgd_conv_fuse(self):
|
||||
with Tensor.train():
|
||||
|
|
@ -1266,7 +1251,7 @@ class TestSchedule(unittest.TestCase):
|
|||
opt = nn.optim.SGD(nn.state.get_parameters(c1))
|
||||
opt.zero_grad()
|
||||
c1(img).relu().sum().backward()
|
||||
check_schedule(opt.schedule_step(), 3)
|
||||
check_schedule(opt.schedule_step(), 5) # TODO: 3?
|
||||
|
||||
def test_sgd_2convs_fuse(self):
|
||||
with Tensor.train():
|
||||
|
|
@ -1289,7 +1274,7 @@ class TestSchedule(unittest.TestCase):
|
|||
opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1)
|
||||
opt.zero_grad()
|
||||
c2(c1(img).relu()).relu().sum().backward()
|
||||
check_schedule(opt.schedule_step(), 13)
|
||||
check_schedule(opt.schedule_step(), 15)
|
||||
|
||||
def test_sgd_4convs_fuse(self):
|
||||
with Tensor.train():
|
||||
|
|
@ -1302,7 +1287,7 @@ class TestSchedule(unittest.TestCase):
|
|||
opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
|
||||
opt.zero_grad()
|
||||
c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
|
||||
check_schedule(opt.schedule_step(), 17)
|
||||
check_schedule(opt.schedule_step(), 15)
|
||||
|
||||
def test_sgd_4convs_fuse_conv_bw(self):
|
||||
with Tensor.train():
|
||||
|
|
@ -1315,50 +1300,7 @@ class TestSchedule(unittest.TestCase):
|
|||
opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
|
||||
opt.zero_grad()
|
||||
c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
|
||||
check_schedule(opt.schedule_step(), 14)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
@unittest.expectedFailure
|
||||
def test_prefer_half_buffer(self):
|
||||
x = Tensor.ones(4).contiguous().realize()
|
||||
# y = Tensor.ones(4).contiguous().realize()
|
||||
z = Tensor.ones(4, 4).contiguous().realize()
|
||||
|
||||
# should not create extra kernel if output will be realized anyways
|
||||
dummy = x.sum().half().float()
|
||||
check_schedule(dummy, 1)
|
||||
dummy = x.sum().half().float().contiguous() + 1
|
||||
check_schedule(dummy, 2)
|
||||
|
||||
# shared between two outputs
|
||||
shared = x.sum().half().float()
|
||||
a = shared * 2
|
||||
b = shared * 3
|
||||
sched = check_schedule([a, b], 3)
|
||||
# store reduceop in half
|
||||
self.assertEqual(sched[0].bufs[0].dtype, dtypes.half)
|
||||
# fuse cast with the child kernel
|
||||
self.assertEqual(sched[1].bufs[0].dtype, dtypes.float)
|
||||
self.assertEqual(sched[2].bufs[0].dtype, dtypes.float)
|
||||
|
||||
# reduce
|
||||
a = z.sum(axis=0).half().float().sum(axis=0)
|
||||
sched = check_schedule(a, 2)
|
||||
self.assertEqual(sched[0].bufs[0].dtype, dtypes.half)
|
||||
self.assertEqual(sched[1].bufs[0].dtype, dtypes.float)
|
||||
|
||||
# expand
|
||||
# expand will realize just after the .float(), so requires change to realize-before-expand
|
||||
# normal = (x.sum().half().float().reshape(1) * y).sum()
|
||||
# sched = check_schedule(normal, 2)
|
||||
# for si in sched[:-1]: assert all(out.dtype == dtypes.half for out in si.outputs[:-1])
|
||||
|
||||
# parallel reduce
|
||||
# a = x.sum().half().float() * y.sum().half().float()
|
||||
# b = a + 1
|
||||
# c = a + 2
|
||||
# sched = check_schedule([b, c], 4)
|
||||
# doesn't store either in half because it doesn't chase
|
||||
check_schedule(opt.schedule_step(), 15)
|
||||
|
||||
def test_reduce_simple_chase(self):
|
||||
a = Tensor.empty(4, 4, 4)
|
||||
|
|
@ -1407,7 +1349,7 @@ class TestSchedule(unittest.TestCase):
|
|||
c = Tensor.empty(16, )
|
||||
r = a.sum(1) + c
|
||||
d = r[:4] * b
|
||||
check_schedule(d, 2)
|
||||
check_schedule(d, 1)
|
||||
|
||||
def test_multireduce_push_shrink_chase(self):
|
||||
Tensor.manual_seed(0)
|
||||
|
|
@ -1417,22 +1359,20 @@ class TestSchedule(unittest.TestCase):
|
|||
d = Tensor.randn(16, 16).realize()
|
||||
r = a.sum(1) + c
|
||||
out = r[:4] * b + d.sum(1)[:4]
|
||||
# schedule = check_schedule(out, 2)
|
||||
schedule = check_schedule(out, 3)
|
||||
schedule = check_schedule(out, 1)
|
||||
run_schedule(schedule)
|
||||
np.testing.assert_allclose(out.numpy(), (a.numpy().sum(1) + c.numpy())[:4] * b.numpy() + d.numpy().sum(1)[:4], atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_midreduce_nochase(self):
|
||||
a = Tensor.empty(16, 16)
|
||||
b = (a.sum(0) + a.max(1)) + 2
|
||||
check_schedule(b, 2)
|
||||
check_schedule(b, 1)
|
||||
|
||||
def test_multireduce_midreduce_nochase(self):
|
||||
Tensor.manual_seed(0)
|
||||
a = Tensor.randn(16, 16).realize()
|
||||
b = (a.sum(0)+a.max(0) + a.max(1)+a.sum(1)) + 2
|
||||
# schedule = check_schedule(b, 2)
|
||||
schedule = check_schedule(b, 4)
|
||||
schedule = check_schedule(b, 1)
|
||||
run_schedule(schedule)
|
||||
np.testing.assert_allclose(b.numpy(), a.numpy().sum(0)+a.numpy().max(0) + a.numpy().max(1)+a.numpy().sum(1)+2, atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
|
@ -1444,7 +1384,7 @@ class TestSchedule(unittest.TestCase):
|
|||
c = a.sum() + 2
|
||||
d = (a.sum() - b.sum()) * 4
|
||||
# run_schedule(check_schedule([c, d], 1))
|
||||
run_schedule(check_schedule([c, d], 3))
|
||||
run_schedule(check_schedule([c, d], 2))
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
|
||||
np.testing.assert_allclose(d.numpy(), (a.numpy().sum() - b.numpy().sum()) * 4, atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
|
@ -1470,7 +1410,7 @@ class TestSchedule(unittest.TestCase):
|
|||
e = c * d
|
||||
f = b.sum() - e
|
||||
# run_schedule(check_schedule([c, d, e, f], 1))
|
||||
run_schedule(check_schedule([c, d, e, f], 2))
|
||||
run_schedule(check_schedule([c, d, e, f], 4))
|
||||
np.testing.assert_allclose(c.numpy(), c_np:=a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
|
||||
np.testing.assert_allclose(d.numpy(), d_np:=a.numpy().sum()*2, atol=1e-4, rtol=1e-4)
|
||||
np.testing.assert_allclose(e.numpy(), e_np:=c_np*d_np, atol=1e-4, rtol=1e-4)
|
||||
|
|
@ -1485,7 +1425,7 @@ class TestSchedule(unittest.TestCase):
|
|||
e = c * d
|
||||
f = (b - d).sum() - e
|
||||
# run_schedule(check_schedule([c, d, e, f], 1))
|
||||
run_schedule(check_schedule([c, d, e, f], 5))
|
||||
run_schedule(check_schedule([c, d, e, f], 4))
|
||||
np.testing.assert_allclose(c.numpy(), c_np:=a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
|
||||
np.testing.assert_allclose(d.numpy(), d_np:=a.numpy().sum()*2, atol=1e-4, rtol=1e-4)
|
||||
np.testing.assert_allclose(e.numpy(), e_np:=c_np*d_np, atol=1e-4, rtol=1e-4)
|
||||
|
|
@ -1504,8 +1444,7 @@ class TestSchedule(unittest.TestCase):
|
|||
a = Tensor.randn(3, 4, 5).realize()
|
||||
b = Tensor.randn(3, 4, 5).realize()
|
||||
out = (a.pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum(keepdim=True)+b.pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum()).contiguous()
|
||||
# run_schedule(check_schedule(out, 1))
|
||||
run_schedule(check_schedule(out, 2))
|
||||
run_schedule(check_schedule(out, 1))
|
||||
np.testing.assert_allclose(out.numpy(), np.pad(a.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(keepdims=True) + \
|
||||
np.pad(b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
|
@ -1513,7 +1452,7 @@ class TestSchedule(unittest.TestCase):
|
|||
Tensor.manual_seed(0)
|
||||
a = Tensor.rand(3, 4, 5).realize()
|
||||
out = a.log2().pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum().contiguous()
|
||||
run_schedule(check_schedule(out, 2))
|
||||
run_schedule(check_schedule(out, 1))
|
||||
np.testing.assert_allclose(out.numpy(), np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-5, rtol=1e-6)
|
||||
|
||||
def test_multireduce_pad_reduce_unsafe(self):
|
||||
|
|
@ -1522,7 +1461,7 @@ class TestSchedule(unittest.TestCase):
|
|||
b = Tensor.randn(3, 4, 5).abs().realize()
|
||||
out = (a.log2().pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum()+b).abs().log2().pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum().contiguous()
|
||||
# run_schedule(check_schedule(out, 1))
|
||||
run_schedule(check_schedule(out, 4))
|
||||
run_schedule(check_schedule(out, 2))
|
||||
np.testing.assert_allclose(out.numpy(), np.pad(np.log2(np.abs(np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum() + \
|
||||
b.numpy())), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=3e-4, rtol=1e-5)
|
||||
|
||||
|
|
@ -1536,7 +1475,7 @@ class TestSchedule(unittest.TestCase):
|
|||
def test_shrink_pad_unsafe(self):
|
||||
a = Tensor.ones((3, )).contiguous().realize()
|
||||
out = a.exp2().shrink(((0, 1),)).pad(((0, 1),)).contiguous()
|
||||
run_schedule(check_schedule(out, 2))
|
||||
run_schedule(check_schedule(out, 1))
|
||||
np.testing.assert_equal(out.numpy(), [2, 0])
|
||||
|
||||
def test_base_change_shrink_pad(self):
|
||||
|
|
@ -1544,7 +1483,7 @@ class TestSchedule(unittest.TestCase):
|
|||
b = a.exp2()
|
||||
c = b[:-1, :-1]
|
||||
d = c.pad(((0, 1), (0, 1))) * 2
|
||||
run_schedule(check_schedule(d, 2))
|
||||
run_schedule(check_schedule(d, 1))
|
||||
np.testing.assert_equal(d.numpy(), np.pad(np.exp2(a.numpy())[:-1, :-1], ((0, 1), (0, 1)))*2)
|
||||
|
||||
def test_base_change_expand_pad(self):
|
||||
|
|
@ -1552,14 +1491,14 @@ class TestSchedule(unittest.TestCase):
|
|||
b = a.exp2()
|
||||
c = b[:, None, :]
|
||||
d = c.pad(((0, 0), (1, 1), (0, 0))) * 2
|
||||
run_schedule(check_schedule(d, 2))
|
||||
run_schedule(check_schedule(d, 1))
|
||||
np.testing.assert_equal(d.numpy(), np.pad(np.exp2(a.numpy())[:, None, :], ((0, 0), (1, 1), (0, 0)))*2)
|
||||
|
||||
def test_fuse_arange_pad_replicate_mode(self):
|
||||
x = Tensor.empty(3,3,3,3, requires_grad=True)
|
||||
y = x.pad((-1,2,2,-1), mode="replicate")
|
||||
dx = y.sum().gradient(x)[0]
|
||||
sched = check_schedule(dx, 3)
|
||||
sched = check_schedule(dx, 1)
|
||||
run_schedule(sched)
|
||||
np.testing.assert_allclose(dx.numpy(), [[[[0.,3.,9.],[0,1.,3.],[0.,0.,0.]]]*3]*3)
|
||||
|
||||
|
|
@ -1569,7 +1508,7 @@ class TestSchedule(unittest.TestCase):
|
|||
a = Tensor.ones(4, 4).contiguous().realize()
|
||||
b = a.cast(dtypes.half).expand(2, 4, 4)
|
||||
c = b.cast(dtypes.int).expand(2, 2, 4, 4)
|
||||
run_schedule(check_schedule(c, 2))
|
||||
run_schedule(check_schedule(c, 1))
|
||||
np.testing.assert_equal(c.numpy(), np.ones(((2, 2, 4, 4)), dtype=np.int32))
|
||||
|
||||
def test_base_change_pad_expand(self):
|
||||
|
|
@ -1577,7 +1516,7 @@ class TestSchedule(unittest.TestCase):
|
|||
b = Tensor.full((4, 4), 2.).contiguous().realize()
|
||||
c = (a + b).pad(((1, 1), (1, 1)))
|
||||
d = c.cast(dtypes.int).expand((2, 6, 6)) * 4
|
||||
run_schedule(check_schedule(d, 2))
|
||||
run_schedule(check_schedule(d, 1))
|
||||
c_np = np.pad((np.full((4, 4), 2., dtype=np.float32) + np.full((4, 4), 1., dtype=np.float32)), ((1, 1), (1, 1)), constant_values=0.0)
|
||||
np.testing.assert_equal(d.numpy(), np.broadcast_to(c_np.astype(np.half), (2, *c_np.shape)) * 4)
|
||||
|
||||
|
|
@ -1676,7 +1615,7 @@ class TestSchedule(unittest.TestCase):
|
|||
self._test_fusion([(4, 4), (1, 4)], lambda a,b:a.sum(1).reshape(b.shape)+b, 1)
|
||||
|
||||
def test_late_fusion_post_permute(self):
|
||||
self._test_fusion([(4, 6, 4), (4, 4, 1)], lambda a,b:a.sum(1, keepdim=True).permute((2, 0, 1))+b, 2)
|
||||
self._test_fusion([(4, 6, 4), (4, 4, 1)], lambda a,b:a.sum(1, keepdim=True).permute((2, 0, 1))+b, 1)
|
||||
|
||||
def test_late_fusion_double_transpose(self):
|
||||
self._test_fusion([(32, 16, 1)],
|
||||
|
|
@ -1714,6 +1653,7 @@ class TestSchedule(unittest.TestCase):
|
|||
self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])
|
||||
|
||||
@given(strat.sampled_from(dtypes.all), strat.sampled_from(dtypes.all))
|
||||
@unittest.skip("kernel count depends on input")
|
||||
def test_cast_padded_const(self, dt1, dt2):
|
||||
assume(is_dtype_supported(dt1) and is_dtype_supported(dt2))
|
||||
a = Tensor(1, dtype=dt1).reshape(1, 1).pad(((1, 1), None))
|
||||
|
|
@@ -1727,7 +1667,7 @@ class TestSchedule(unittest.TestCase):
|
|||
X = Tensor.randn(10, 10).realize()
|
||||
idxs = Tensor([0, 2]).realize()
|
||||
xt = X[idxs]
|
||||
run_schedule(check_schedule(xt, 2))
|
||||
run_schedule(check_schedule(xt, 1))
|
||||
np.testing.assert_equal(xt.numpy(), X.numpy()[idxs.numpy()])
|
||||
|
||||
def test_simple_indexing_alt(self):
|
||||
|
|
@@ -1745,7 +1685,7 @@ class TestSchedule(unittest.TestCase):
|
|||
def test_advanced_indexing_alt(self):
|
||||
X = Tensor.arange(6).reshape(3, 2)+1
|
||||
xt = X[[Tensor([2]), Tensor([1])]]
|
||||
run_schedule(check_schedule(xt, 3))
|
||||
run_schedule(check_schedule(xt, 1))
|
||||
np.testing.assert_equal(xt.numpy(), 6)
|
||||
|
||||
def test_advanced_simple_indexing_combined(self):
|
||||
|
|
@@ -1793,7 +1733,7 @@ class TestSchedule(unittest.TestCase):
|
|||
x = Tensor.full((2,2), 16)
|
||||
y = x.idiv(Tensor.linspace(2, 8, steps=4, dtype=dtypes.int).reshape(2,2)).pad(((1,1), (1,1)))
|
||||
out = y.sum(axis=1)
|
||||
run_schedule(check_schedule(out, 2))
|
||||
run_schedule(check_schedule(out, 1))
|
||||
self.assertListEqual(out.tolist(), [0, 12, 4, 0])
|
||||
|
||||
def test_arange_transposed_descendants(self):
|
||||
|
|
@@ -1826,7 +1766,7 @@ class TestSchedule(unittest.TestCase):
|
|||
x = Tensor.randn(5, 2).realize()
|
||||
a = Tensor.arange(10).contiguous()
|
||||
out = (x + a[2]).sum()
|
||||
run_schedule(check_schedule(out, 3))
|
||||
run_schedule(check_schedule(out, 2))
|
||||
np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)
|
||||
|
||||
def test_arange_index_child(self):
|
||||
|
|
@@ -1842,7 +1782,7 @@ class TestSchedule(unittest.TestCase):
|
|||
x = Tensor.randn(5, 2).realize()
|
||||
a = (Tensor.arange(10)+1).contiguous()
|
||||
out = (x + a[2]).sum()
|
||||
run_schedule(check_schedule(out, 3))
|
||||
run_schedule(check_schedule(out, 2))
|
||||
np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)
|
||||
|
||||
@unittest.skip("BUFFER_VIEW no longer supported on non-disk devices")
|
||||
|
|
@@ -1857,10 +1797,10 @@ class TestSchedule(unittest.TestCase):
|
|||
from extra.models.llama import precompute_freqs_cis
|
||||
args = {"dim":32 if CI else 128, "end":2048 if CI else 8192, "theta":10000}
|
||||
fused = precompute_freqs_cis(**args)
|
||||
run_schedule(check_schedule(fused, 3))
|
||||
run_schedule(check_schedule(fused, 1))
|
||||
if getenv("CHECK", 1):
|
||||
ref = precompute_freqs_cis(**args)
|
||||
run_schedule(check_schedule(ref, 3))
|
||||
run_schedule(check_schedule(ref, 1))
|
||||
np.testing.assert_equal(fused.numpy(), ref.numpy())
|
||||
|
||||
def test_fuse_assign_contiguous(self):
|
||||
|
|
@@ -1902,7 +1842,7 @@ class TestSchedule(unittest.TestCase):
|
|||
X = Tensor([[0, 2, 3], [1, 2, 3]]).realize()
|
||||
Y = Tensor([1, 2]).realize()
|
||||
loss = X.sparse_categorical_crossentropy(Y)
|
||||
run_schedule(check_schedule(loss, 4))
|
||||
run_schedule(check_schedule(loss, 3))
|
||||
np.testing.assert_allclose(loss.item(), 0.878309, atol=1e-5, rtol=1e-6)
|
||||
|
||||
def test_const_folding_alt(self):
|
||||
|
|
@@ -1923,7 +1863,7 @@ class TestSchedule(unittest.TestCase):
|
|||
yt = Tensor.randn(BS, 10).realize()
|
||||
with Context(SPLIT_REDUCEOP=0):
|
||||
loss = yt.sparse_categorical_crossentropy(Y_train[samples])
|
||||
run_schedule(check_schedule(loss, 6))
|
||||
run_schedule(check_schedule(loss, 5))
|
||||
loss_fused = loss.numpy()
|
||||
loss_ref = torch.nn.CrossEntropyLoss()(torch.tensor(yt.numpy()), torch.tensor(Y_train.numpy())[torch.tensor(samples.numpy())])
|
||||
np.testing.assert_allclose(loss_fused, loss_ref.numpy(), atol=1e-6, rtol=1e-6)
|
||||
|
|
@@ -1933,7 +1873,7 @@ class TestSchedule(unittest.TestCase):
|
|||
r = (X+Tensor.arange(16).reshape(4, 4)).sum()
|
||||
out0 = r+2
|
||||
out1 = r+3
|
||||
run_schedule(check_schedule([out0, out1], 1))
|
||||
run_schedule(check_schedule([out0, out1], 2)) # TODO: 1?
|
||||
r_ref = (X.numpy()+np.arange(16).reshape(4, 4)).sum()
|
||||
np.testing.assert_allclose(out0.numpy(), r_ref+2, rtol=2e-7)
|
||||
np.testing.assert_allclose(out1.numpy(), r_ref+3, rtol=2e-7)
|
||||
|
|
@@ -2003,21 +1943,21 @@ class TestSwizzle(unittest.TestCase):
|
|||
with Context(DEBUG=0, TRACK_MATCH_STATS=0):
|
||||
a = Tensor.randn(32, 32).realize()
|
||||
t = a.softmax()
|
||||
check_schedule(t, 1)
|
||||
check_schedule(t, 3) # TODO: 1?
|
||||
|
||||
def test_argmax_one_kernel(self):
|
||||
Tensor.manual_seed(0)
|
||||
with Context(DEBUG=0, TRACK_MATCH_STATS=0):
|
||||
a = Tensor.randn(10, 20).realize()
|
||||
t = a.argmax(0)
|
||||
check_schedule(t, 1)
|
||||
check_schedule(t, 2) # TODO: 1?
|
||||
|
||||
def test_swizzle_reduceop(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(4,4).realize()
|
||||
y = Tensor.randn(4,4,4).realize()
|
||||
out = x.reshape(4,4,1).expand(4,4,4).sum(axis=(1,))+y
|
||||
run_schedule(check_schedule(out, 1))
|
||||
run_schedule(check_schedule(out, 2)) # TODO: 1?
|
||||
np.testing.assert_allclose(out.numpy(), np.tile(x.numpy().reshape(4,4,1), (1,1,4)).sum(axis=1)+y.numpy())
|
||||
|
||||
def test_permute_rewrite(self):
|
||||
|
|
@@ -2025,7 +1965,7 @@ class TestSwizzle(unittest.TestCase):
|
|||
y = Tensor.randn(4, 1, 16).realize()
|
||||
z = Tensor.randn(4, 4, 1).realize()
|
||||
t = (x*y).sum(axis=(0, 2)).reshape(1, 4, 1).permute(0, 2, 1)+z
|
||||
run_schedule(check_schedule(t, 1))
|
||||
run_schedule(check_schedule(t, 2)) # TODO: 1?
|
||||
t_np = (x.numpy()*y.numpy()).sum(axis=(0, 2)).reshape(1, 4, 1).transpose(0, 2, 1)+z.numpy()
|
||||
np.testing.assert_allclose(t.numpy(), t_np, atol=1e-6, rtol=1e-3)
|
||||
|
||||
|
|
@@ -2145,7 +2085,7 @@ class TestCopyFolding(unittest.TestCase):
|
|||
a = Tensor.arange(3).realize()
|
||||
zeros = Tensor.zeros(3).realize()
|
||||
b = (a*zeros).to("CPU")
|
||||
run_schedule(check_schedule(b, 0, filter_sink=False))
|
||||
run_schedule(check_schedule(b, 2, filter_sink=False)) # TODO: 0?
|
||||
self.assertListEqual(b.tolist(), [0, 0, 0])
|
||||
self.assertEqual(b.device, "CPU")
|
||||
|
||||
|
|
@@ -2165,12 +2105,12 @@ class TestCopyFolding(unittest.TestCase):
|
|||
def test_copy_to_same_device(self):
|
||||
a = Tensor.empty(4).uop
|
||||
b = a.copy_to_device(a.device)
|
||||
check_schedule(b, 0, filter_sink=False)
|
||||
check_schedule(b, 1, filter_sink=False) # TODO: 0?
|
||||
|
||||
def test_copy_to_same_device_alt(self):
|
||||
a = Tensor.empty(4, 4).uop
|
||||
b = a.copy_to_device(a.device)
|
||||
check_schedule(b, 0, filter_sink=False)
|
||||
check_schedule(b, 1, filter_sink=False) # TODO: 0?
|
||||
|
||||
def test_copy_to_same_device_sched(self):
|
||||
a = Tensor.ones(4).contiguous().realize().uop.as_buf()
|
||||
|
|
@@ -2185,13 +2125,11 @@ class TestCopyFolding(unittest.TestCase):
|
|||
a = Tensor.empty(4)
|
||||
check_schedule(a.clone(), 1, filter_sink=False)
|
||||
|
||||
# NOTE: moving copy before view might change this
|
||||
def test_shrink_copy(self):
|
||||
a = Tensor.arange(4)
|
||||
view = a.shrink(((0, 2),))
|
||||
b = view.clone()
|
||||
# NOTE: this was sort of a bug making this 2
|
||||
run_schedule(check_schedule(b, 2, filter_sink=False))
|
||||
run_schedule(check_schedule(b, 1, filter_sink=False))
|
||||
self.assertEqual(b.uop.base.buffer.size, 2)
|
||||
self.assertEqual(b.uop.size, 2)
|
||||
self.assertListEqual(b.tolist(), [0, 1])
|
||||
|
|
@@ -2200,7 +2138,7 @@ class TestCopyFolding(unittest.TestCase):
|
|||
a = Tensor.arange(2)
|
||||
view = a.reshape(2, 1).expand(2, 2)
|
||||
b = view.clone()
|
||||
run_schedule(check_schedule(b, 2, filter_sink=False))
|
||||
run_schedule(check_schedule(b, 1, filter_sink=False))
|
||||
self.assertEqual(b.uop.base.buffer.size, 4)
|
||||
self.assertEqual(b.uop.size, 4)
|
||||
self.assertListEqual(b.tolist(), [[0, 0], [1, 1]])
|
||||
|
|
@@ -2323,7 +2261,7 @@ class TestContiguous(unittest.TestCase):
|
|||
def test_double_contiguous_realizes_once(self):
|
||||
a = Tensor.empty(4, 1)
|
||||
b = a.expand((4, 4)).contiguous().contiguous()
|
||||
check_schedule(b, 1)
|
||||
check_schedule(b, 2) # TODO: should be 1?
|
||||
|
||||
def test_view_does_not_realize(self):
|
||||
a = Tensor.empty(4)
|
||||
|
|
@@ -2459,10 +2397,6 @@ class TestUOpBecome(unittest.TestCase):
|
|||
c = (a.reshape(1, 1, 4, 4)+0).shrink(((0, 1), (0, 1), (0, 3), (0, 3)))+0
|
||||
check_schedule([b, c], 0)
|
||||
assert all_same([x.uop.base.realized for x in [a,b,c]])
|
||||
# these movement ops result in the same ShapeTracker
|
||||
assert b.uop.st == c.uop.st
|
||||
assert b.uop is c.uop
|
||||
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).match(c.uop, {})
|
||||
|
||||
def test_setitem_becomes_subbuffer(self):
|
||||
a = Tensor.full((4,), 2.).contiguous().realize()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue