# Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-06-24 02:14:17 +00:00).
# File size: 476 lines, 19 KiB. Language: Python.
# this will be the new test_ops for the next level
|
|
# schedule confirms the right things are capable of fusing
|
|
# NOTE: this has overlap with external_test_opt.py
|
|
|
|
import unittest
|
|
import numpy as np
|
|
|
|
from tinygrad import nn, dtypes, Device, Tensor, Variable
|
|
from tinygrad.uop.ops import UOp, Ops, UPat
|
|
from tinygrad.helpers import DEBUG, DEV, GlobalCounters, Context, all_same, temp
|
|
from tinygrad.engine.realize import compile_linear, run_linear
|
|
|
|
# result of the default device renderer's supported_dtypes() call, evaluated once at import time
# NOTE(review): not referenced again in the visible part of this file
supported_dtypes = Device[Device.DEFAULT].renderer.supported_dtypes()
|
|
|
|
class KernelCountException(Exception):
  """Raised by check_schedule when the counted kernels differ from the allowed number."""
|
|
def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True):
  """Schedule `t`, check the kernel count against `allowed`, compile the result and return it.

  Args:
    t: a Tensor, a non-empty list of Tensors, or a UOp (wrapped in a Tensor before scheduling).
    allowed: the exact kernel count expected.
    to_prerealize: optional Tensors realized (with DEBUG and TRACK_MATCH_STATS set to 0) before `t` is scheduled.
    filter_sink: when True only calls whose first source op is Ops.SINK are counted, otherwise every call is counted.

  Returns:
    The `(linear, var_vals)` pair produced by `linear_with_vars`; callers splat it into `run_linear`.

  Raises:
    KernelCountException: the counted kernels differ from `allowed`.
    AssertionError: `t` is not a Tensor, a non-empty list starting with a Tensor, or a UOp.
  """
  if to_prerealize:
    with Context(DEBUG=0, TRACK_MATCH_STATS=0): Tensor.realize(*to_prerealize)

  # the len() guard keeps an empty list from raising IndexError on t[0]; it reaches the assert below instead
  if isinstance(t, Tensor): linear, var_vals = t.linear_with_vars()
  elif isinstance(t, list) and len(t) > 0 and isinstance(t[0], Tensor): linear, var_vals = Tensor.linear_with_vars(*t)
  else:
    assert isinstance(t, UOp), f"can't schedule {t}"
    linear, var_vals = Tensor(t).linear_with_vars()

  # a call whose device is a tuple counts once per device entry, any other call counts once
  kernel_cnt = 0
  for call in linear.src:
    if call.src[0].op is not Ops.SINK and filter_sink: continue
    kernel_cnt += len(call.device) if isinstance(call.device, tuple) else 1

  if kernel_cnt != allowed:
    print(f"SCHEDULE ISSUE, expecting {allowed} got {kernel_cnt}")
    if DEBUG >= 3:
      # dump every call (not only the counted ones) to help find the unexpected kernel
      for i,call in enumerate(linear.src):
        print("kernel", i+1)
        print(call.src[0])
    raise KernelCountException(f"{kernel_cnt} != {allowed}")

  # test compiling the linear
  compile_linear(linear)
  return linear, var_vals
|
|
|
|
def _realize_weights(m):
  """Realize, one at a time, every parameter that nn.state.get_parameters returns for `m`."""
  params = nn.state.get_parameters(m)
  for param in params:
    param.realize()
|
|
|
|
class TestSchedule(unittest.TestCase):
  """Kernel-count and numerical checks on scheduled tensors; every test runs with SPLIT_REDUCEOP=0."""

  def setUp(self):
    # entered by hand so the context spans the whole test; tearDown leaves it
    self.ctx = Context(SPLIT_REDUCEOP=0)
    self.ctx.__enter__()

  def tearDown(self):
    self.ctx.__exit__(None, None, None)

  @unittest.skip("no longer supported")
  def test_double_from(self):
    src = Tensor([1,2,3,4])
    moved = src.to('python')
    check_schedule(moved, 0, filter_sink=False)

  def test_example_matmul_same(self):
    x = Tensor.eye(64).clone().realize()
    loss = x.matmul(x).sum()
    loss.backward()
    grad = x.grad.contiguous()
    linear, var_vals = check_schedule(grad, 1)
    run_linear(linear, var_vals)
    # NOTE: the gradient flows twice
    np.testing.assert_allclose(grad.numpy(), 2*np.ones((64,64)))

  def test_pad_reduce_scope_collision(self):
    b = Tensor.rand(4, 3).realize()
    # both operands end up as the same 6 rows: one zero row, the 4 rows of b, one zero row
    s1 = b.pad(((1, 1), (0, 0))).sum(axis=1)
    s2 = b.pad(((1, 2), (0, 0))).shrink(((0, 6), (0, 3))).sum(axis=1)
    out = s1 + s2
    linear, var_vals = check_schedule(out, 1)
    run_linear(linear, var_vals)
    expected = 2*np.pad(b.numpy(), ((1, 1), (0, 0))).sum(axis=1)
    np.testing.assert_allclose(out.numpy(), expected, rtol=1e-6)

  def test_cumsum_parallel_reduce_fused(self):
    # two-stage cumsum + ops triggers parallel REDUCEs in one kernel that must share an END (same nesting context = should merge)
    step, num_steps = 513, 10
    mask = [1,0,0,1,0,0,0,0,1,0]
    t = Tensor.arange(step).float().realize()
    phase = t.cumsum()
    tiled = phase.repeat((num_steps,)).reshape(num_steps, step)
    pattern = Tensor(mask).reshape(num_steps, 1)
    out = (tiled * pattern).flatten()
    # numpy reference: the same cumsum tiled num_steps times, then masked row by row
    expected = np.tile(np.arange(step).astype(np.float32).cumsum(), num_steps).reshape(num_steps, step)
    expected = (expected * np.array(mask).reshape(num_steps, 1)).flatten()
    np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4)

  @unittest.skipIf(Device.DEFAULT == "CL", "TODO: fails on CI CL")
  def test_reduce_different_nesting_depth(self):
    # two REDUCEs sharing the same RANGE at different nesting depths must NOT merge
    x = Tensor.arange(768).reshape(3, 256).float()
    got = (x.sum(axis=1) + x.sum(axis=1).sum()).numpy()
    row_sums = x.numpy().sum(axis=1)
    np.testing.assert_allclose(got, row_sums + row_sums.sum())

  def test_fuse_assign_contiguous(self):
    x = Tensor.zeros(4, 4, dtype=dtypes.int).contiguous().realize()
    a = Tensor.arange(8).reshape(4, 2)
    linear, var_vals = check_schedule(x.shrink((None, (0, 2))).assign(a.clone()), 2)
    run_linear(linear, var_vals)
    np.testing.assert_equal(x.numpy(), [[0, 1, 0, 0], [2, 3, 0, 0], [4, 5, 0, 0], [6, 7, 0, 0]])

  def test_assign_non_contiguous_alt(self): self.test_assign_non_contiguous(alt=True)

  def test_assign_non_contiguous(self, alt=False):
    x = (Tensor.arange(16)-100).reshape(4,4).clone().realize()
    xref = x.numpy()
    if not alt:
      # default: assign into the first two columns
      y = Tensor.randint(4, 2).contiguous().realize()
      a = Tensor.arange(8).reshape(4, 2)+y
      tst = x.shrink((None, (0, 2))).assign(a).realize()
      xref[:, :2] = np.arange(8).reshape(4, 2)+y.numpy()
    else:
      # alt: assign into the first two rows
      y = Tensor.randint(2, 4).contiguous().realize()
      a = Tensor.arange(8).reshape(2, 4)+y
      tst = x.shrink(((0, 2), None)).assign(a).realize()
      xref[:2, :] = np.arange(8).reshape(2, 4)+y.numpy()
    np.testing.assert_equal(x.numpy(), xref)
    np.testing.assert_equal(tst.numpy(), a.numpy())

  def test_setitem_sched(self, mop=lambda x:x, expected_kcount=1):
    # mop is a movement op applied to the tensor before it is added back onto itself
    a = Tensor.arange(16).reshape(4, 4).clone(device="CPU").realize()
    moved = mop(a)
    expected = (a+moved).tolist()
    a.assign(a+moved)
    linear, var_vals = a.linear_with_vars()
    n_calls = len(linear.src)
    run_linear(linear, var_vals)
    self.assertListEqual(a.tolist(), expected)
    self.assertEqual(n_calls, expected_kcount)

  def test_setitem_permuted_sched(self): self.test_setitem_sched(lambda x: x.T, 2)

  def test_setitem_paddded_sched(self): self.test_setitem_sched(lambda x: x.shrink_to(4, 1).pad_to(4, 4), 1)

  def test_setitem_const_fused(self):
    # https://github.com/tinygrad/tinygrad/issues/10690
    a = Tensor.arange(16).clone().realize()
    GlobalCounters.reset()
    a[4] = 3
    # the setitem alone must not launch anything
    self.assertEqual(GlobalCounters.kernel_count, 0)
    a.realize()
    self.assertEqual(GlobalCounters.kernel_count, 1)
    self.assertListEqual(a.tolist(), [0, 1, 2, 3, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])

  def test_no_extra_contiguous_on_setitem_assign_back(self):
    # pattern: contiguous copy, advanced setitem, assign back (e.g. torch backend _view_write)
    base = Tensor.arange(16).reshape(4, 4).clone()
    flat_base = base.reshape(16).contiguous()
    idx = Tensor([1,2,5,6], dtype=dtypes.int32)
    flat_base[idx] = Tensor([99,99,99,99])
    base.assign(flat_base.reshape(4, 4))
    linear, var_vals = check_schedule(base, 4)
    run_linear(linear, var_vals)
    # reference: 0..15 with positions 1, 2, 5 and 6 overwritten by 99
    expected = [99 if i in (1, 2, 5, 6) else i for i in range(16)]
    np.testing.assert_equal(base.reshape(16).numpy(), expected)

  def test_const_folding_alt(self):
    t = Tensor.full((2,), 1.)
    lt = (t < 0.)
    a = Tensor.empty(2).assign(t*lt.where(-1., 0.))
    b = Tensor.empty(2, dtype=dtypes.bool).assign(lt)
    Tensor.realize(a, b)
    self.assertEqual(a.tolist(), [0., 0.])
    self.assertEqual(b.tolist(), [False, False])

  def test_self_assign_no_empty_kernel(self):
    for shape in [(3, 3), (4, 4)]:
      rows, cols = shape
      a = Tensor.ones(*shape).contiguous().realize()
      a.assign(a / 1)
      linear, var_vals = check_schedule(a, 0, filter_sink=False)
      run_linear(linear, var_vals)
      self.assertListEqual(a.tolist(), [[1.]*cols]*rows)

  def test_deviceless_materialize_localizes_to_target(self):
    # pick a CPU device string that differs from the default device
    dev = "CPU:1" if Device.DEFAULT == "CPU" else "CPU"
    t = Tensor.arange(Variable("s", 1, 128).bind(64)).cumsum().clone(dev)
    self.assertEqual(t.device, dev)
    np.testing.assert_equal(t[:64].numpy(), np.arange(64).cumsum())
|
|
|
|
class TestLimitBufs(unittest.TestCase):
  """Expressions that read from many buffers at once must still realize correctly."""

  @unittest.skipIf(DEV.interface.startswith("MOCK") and Device.DEFAULT == "NV", "crashes in ocelot")
  def test_limit_bufs_with_var(self):
    N = 31
    with Context(TRACK_MATCH_STATS=0, DEBUG=0):
      bufs = [Tensor([1]*10).contiguous().realize() for _ in range(N)]

    vi = Variable("i", 0, 9).bind(1)
    vj = Variable("j", 0, 9).bind(2)
    # every buffer holds only ones, so each one adds 1 (index vi) + 1 (index vj) to the total
    root = bufs[0][vi] + bufs[0][vj]
    for buf in bufs[1:]: root = root + buf[vi] + buf[vj]
    self.assertEqual(root.item(), N * 2)

  def test_limit_bufs_arange_condition(self):
    # WHERE with arange-based condition (pure index math, no device) and many buffer loads should not crash limit_bufs
    with Context(MAX_KERNEL_BUFFERS=8):
      N = 8
      idx = Tensor.arange(N)
      base = Tensor.zeros(N)
      for i in range(4):
        a, b = Tensor.rand(N).realize(), Tensor.rand(N).realize()
        base = (idx >= i).where(a + b, base)
      assert all(x > 0 for x in base.tolist())
|
|
|
|
class TestSwizzle(unittest.TestCase):
  """Reduce and movement-op combinations: expected kernel counts plus numpy cross-checks."""

  def test_swizzle_simple(self):
    Tensor.manual_seed(0)
    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
      a = Tensor.randint(32, 32).realize()
    r = (a+a).sum(1).sum(0)
    # double reduce collapses to a single reduce
    linear, var_vals = check_schedule(r, 1)
    run_linear(linear, var_vals)
    self.assertEqual(r.numpy(), (a.numpy()+a.numpy()).sum(1).sum(0))

  def test_single_swizzle(self):
    Tensor.manual_seed(0)
    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
      a = Tensor.randint(4, 1).realize()
      b = Tensor.ones((1, 1), dtype=a.dtype).contiguous().realize()
    # ADD(REDUCE(RESHAPE(LOAD)), LOAD) to ADD(REDUCE(RESHAPE(LOAD))), RESHAPE(LOAD)
    r = a.sum(0)+b
    linear, var_vals = check_schedule(r, 1)
    run_linear(linear, var_vals)
    self.assertEqual(r.numpy(), a.numpy().sum(0)+1)

  def test_double_swizzle_possible(self):
    Tensor.manual_seed(0)
    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
      a = Tensor.randint(4,).realize()
      b = Tensor.randint(4,).realize()
    # parallel reduce!
    add = a.sum(0)+b.sum(0)
    linear, var_vals = check_schedule(add, 1)
    run_linear(linear, var_vals)
    self.assertEqual(add.numpy(), a.numpy().sum(0)+b.numpy().sum(0))

  def test_swizzle_reduceop(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(4,4).realize()
    y = Tensor.randn(4,4,4).realize()
    out = x.reshape(4,4,1).expand(4,4,4).sum(axis=(1,))+y
    linear, var_vals = check_schedule(out, 2) # TODO: 1?
    run_linear(linear, var_vals)
    expected = np.tile(x.numpy().reshape(4,4,1), (1,1,4)).sum(axis=1)+y.numpy()
    np.testing.assert_allclose(out.numpy(), expected)

  def test_permute_rewrite(self):
    x = Tensor.randn(4, 4, 16).realize()
    y = Tensor.randn(4, 1, 16).realize()
    z = Tensor.randn(4, 4, 1).realize()
    t = (x*y).sum(axis=(0, 2)).reshape(1, 4, 1).permute(0, 2, 1)+z
    linear, var_vals = check_schedule(t, 2) # TODO: 1?
    run_linear(linear, var_vals)
    t_np = (x.numpy()*y.numpy()).sum(axis=(0, 2)).reshape(1, 4, 1).transpose(0, 2, 1)+z.numpy()
    np.testing.assert_allclose(t.numpy(), t_np, atol=1e-6, rtol=1e-3)

  @unittest.skip("TODO: this swizzle isn't resolvable when there's a mask")
  def test_swizzle_failure_permute(self):
    a = Tensor.empty(45,65).T.reshape(65,1,45).pad((None,None,(0,45))).expand(65,45,90)
    b = Tensor.empty(45,65)
    a_reduce = a.sum(axis=(2,), keepdim=True).sum(axis=(1,))
    b_reduce = b.sum(axis=(0,))
    t = a_reduce+b_reduce
    linear, var_vals = check_schedule(t, 1)
    run_linear(linear, var_vals)

  def test_parallel_reduce_possible(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 2, 2).realize()
    y = Tensor.randn(4, 2, 2).realize()
    t = x.sum(axis=1)+y.sum(axis=1)
    linear, var_vals = check_schedule(t, 1)
    run_linear(linear, var_vals)
    expected = x.numpy().sum(axis=1)+y.numpy().sum(axis=1)
    np.testing.assert_allclose(t.numpy(), expected, atol=1e-6, rtol=1e-3)

  # kernels can only have 1 or n in each dim
  def test_dont_parallelize_different_n(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 2, 2).realize()
    y = Tensor.randn(4, 3, 2).realize()
    t = x.sum(axis=1)+y.sum(axis=1)
    linear, var_vals = check_schedule(t, 1)
    run_linear(linear, var_vals)
    expected = x.numpy().sum(axis=1)+y.numpy().sum(axis=1)
    np.testing.assert_allclose(t.numpy(), expected, atol=1e-6, rtol=1e-3)

  def test_unsafe_pad(self):
    x = Tensor.full((2,2), 1.0).contiguous()
    # each row of ones sums to 2, so the product is 0.5 everywhere; the padded row is 0
    y = x*x.sum((1,)).reciprocal()
    t = y.pad(((0,1),None))
    linear, var_vals = check_schedule(t, 3)
    run_linear(linear, var_vals)
    np.testing.assert_equal(t.numpy(), [[0.5, 0.5], [0.5, 0.5], [0., 0.]])
|
|
|
|
# pattern for a CONST op whose arg is 0
# NOTE(review): not referenced by any test visible in this file
zero_pm = UPat(Ops.CONST, arg=0)
|
|
class TestView(unittest.TestCase):
  """Pad-then-slice views that are fully or partly masked out."""

  def test_all_masked_out(self):
    # start with non CONST Ops
    a = Tensor.rand(10, 10).realize()
    # all masked out, degrades to const 0
    b = a.pad(((0, 10), None))[10:]
    linear, var_vals = check_schedule(b.contiguous(), 1)
    run_linear(linear, var_vals)
    np.testing.assert_equal(b.numpy(), 0)

  def test_mask_dim_1(self):
    # mask out dim = 1 works too
    a = Tensor.rand(10, 10).realize()
    b = a.pad((None, (0, 10)))[:, 10:]
    assert b.shape == (10, 10)
    linear, var_vals = check_schedule(b.contiguous(), 1)
    run_linear(linear, var_vals)
    np.testing.assert_equal(b.numpy(), 0)

  def test_partial_mask(self):
    # partial masked out does not degrade into CONST
    a = Tensor.rand(10, 10).realize()
    b = a.pad(((0, 5), None))[5:]
    assert b.shape == (10, 10)
    linear, var_vals = check_schedule(b.contiguous(), 1)
    run_linear(linear, var_vals)
    expected = np.pad(a.numpy(), ((0, 5), (0, 0)))[5:]
    np.testing.assert_allclose(b.numpy(), expected)

  # a*VIEW(x), where VIEW(x) = 0
  # x collapses along with its children
  def test_parent_view_collapses(self):
    a = Tensor([1, 2])
    b = Tensor.arange(3).clone()
    bv = b.pad(((0, 2),))[-2:]
    # this becomes a late a*0
    late_mul = a*bv
    linear, var_vals = check_schedule(late_mul, 2)
    run_linear(linear, var_vals)
    # the arange doesn't realize
    #self.assertIsNone(b.uop.base.realized)
    # mul doesn't realize
    #self.assertIsNone(late_mul.uop.base.realized)
    self.assertEqual(late_mul.tolist(), [0, 0])

  # SINK has two branches:
  # a*VIEW(x), where VIEW(x) = 0
  # x+2
  # as long as one child realizes, x does not collapse
  def test_parent_multiple_children_no_collapse(self):
    a = Tensor([1, 2])
    b = Tensor.arange(3).clone()
    bv = b.pad(((0, 2),))[-2:]
    late_mul = a*bv
    other_child = b+2
    linear, var_vals = check_schedule([late_mul, other_child], 3)
    # the arange becomes a BUFFER
    self.assertIs(b.uop.base.op, Ops.BUFFER)
    # NOTE: no longer checked
    # mul still collapses
    #self.assertIs(late_mul.uop.base.op, Ops.CONST)
    run_linear(linear, var_vals)
    self.assertEqual(other_child.tolist(), [2, 3, 4])
|
|
|
|
@unittest.skipIf(Device.DEFAULT == "CPU", "tests copy from another device to cpu")
class TestCopyFolding(unittest.TestCase):
  """Copies between devices (and onto the same device): expected kernel counts and results.

  The whole class is skipped when the default device is already CPU, since the tests copy to "CPU".
  """

  def test_const_copy_is_free(self):
    b = Tensor(1).to("CPU") * 4
    linear, var_vals = check_schedule(b, 0, filter_sink=False)
    run_linear(linear, var_vals)
    assert b.item() == 4

  def test_one_hot_with_copy(self):
    y = Tensor([1, 2, 3]).to("CPU")
    x = y.one_hot(10)
    check_schedule(x, 3, filter_sink=False)

  def test_late_const_copy_folding(self):
    a = Tensor.arange(3).clone().realize()
    zeros = Tensor.zeros(3, buffer=False).realize()
    b = (a*zeros).to("CPU") + 1
    linear, var_vals = check_schedule(b, 1, filter_sink=False)
    run_linear(linear, var_vals)
    self.assertListEqual(b.tolist(), [1, 1, 1])
    self.assertEqual(b.device, "CPU")

  def test_alu_after_copy(self):
    a = Tensor.ones((4,)).to("CPU")
    b = Tensor.empty(4, device="CPU")
    add = a+b
    # build the device list once from add.uop.src: the message used to read add.src, so a failing
    # check would have raised AttributeError on the Tensor instead of reporting the devices
    devices = [x.device for x in add.uop.src]
    assert all_same(devices), f"ALU has different devices! {devices}"
    add.schedule_linear()

  def test_alu_before_copy(self):
    buf = Tensor.ones(1).contiguous().realize()
    a = buf+1
    b = a.to("CPU")
    self.assertListEqual(b.tolist(), [2.])

  def test_copy_to_same_device(self):
    a = Tensor.empty(4).uop
    b = a.copy_to_device(a.device)
    check_schedule(b, 1, filter_sink=False) # TODO: 0?

  def test_copy_to_same_device_alt(self):
    a = Tensor.empty(4, 4).uop
    b = a.copy_to_device(a.device)
    check_schedule(b, 1, filter_sink=False) # TODO: 0?

  def test_copy_to_same_device_sched(self):
    a = Tensor.ones(4).contiguous().realize().uop.buf_uop
    t = Tensor(a.copy_to_device(a.device))
    linear, var_vals = t.linear_with_vars()
    # no call in the linear may be a COPY
    assert len([call for call in linear.src if call.src[0].op is Ops.COPY]) == 0
    run_linear(linear, var_vals)
    assert t.uop.is_realized, f"didn't realize Tensor {t}"
    self.assertListEqual(t.tolist(), [1.,1.,1.,1.])

  def test_self_assign_same_device_copy(self):
    a = Tensor.ones(4, 4).contiguous().realize()
    # use copy_to_device to bypass Tensor.to() shortcircuit and force a real same-device COPY in the graph
    a.assign(Tensor(a.uop.copy_to_device(a.device), a.device))
    linear, var_vals = check_schedule(a, 2, filter_sink=False)
    run_linear(linear, var_vals)
    self.assertListEqual(a.tolist(), [[1.]*4]*4)

  def test_clone(self):
    a = Tensor.empty(4)
    check_schedule(a.clone(), 1, filter_sink=False)

  def test_shrink_copy(self):
    a = Tensor.arange(4)
    view = a.shrink(((0, 2),))
    b = view.clone()
    linear, var_vals = check_schedule(b, 1, filter_sink=False)
    run_linear(linear, var_vals)
    # the clone's buffer holds only the 2 shrunk elements
    self.assertEqual(b.uop.base.buffer.size, 2)
    self.assertEqual(b.uop.numel(), 2)
    self.assertListEqual(b.tolist(), [0, 1])

  def test_expanded_copy(self):
    a = Tensor.arange(2)
    view = a.reshape(2, 1).expand(2, 2)
    b = view.clone()
    linear, var_vals = check_schedule(b, 1, filter_sink=False)
    run_linear(linear, var_vals)
    # the clone's buffer holds all 4 expanded elements
    self.assertEqual(b.uop.base.buffer.size, 4)
    self.assertEqual(b.uop.numel(), 4)
    self.assertListEqual(b.tolist(), [[0, 0], [1, 1]])

  def test_permuted_copy(self):
    a = Tensor.arange(4)
    b = a.reshape(2, 2).permute(1, 0)
    b.realize()
    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])

  def test_permute_on_disk(self):
    # write arange(4) to a temp file, then read it back through a disk tensor
    with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).clone().realize().uop.base.buffer.as_memoryview())
    a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
    b = a.reshape(2, 2).permute(1, 0).to("CPU")
    b.realize()
    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])

  def test_permute_on_disk_contiguous(self):
    with open(temp('dt_arange_4_permute_contig'), "wb") as f: f.write(Tensor.arange(4).clone().realize().uop.base.buffer.as_memoryview())
    a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute_contig')}")
    b = a.reshape(2, 2).permute(1, 0).contiguous().to("CPU")
    b.realize()
    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])

  def test_permute_after_shrink(self):
    a = Tensor.arange(5)
    b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CPU")
    b.realize()
    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])

  # NOTE: disk permute must come after COPY
  def test_permute_after_shrink_on_disk(self):
    with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).clone().realize().uop.base.buffer.as_memoryview())
    a = Tensor.empty(5, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_5_permute')}")
    b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CPU")
    b.realize()
    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
|
|
|
|
class TestUOpBecome(unittest.TestCase):
  """Assigning a view of one realized tensor into a shrunk view of another."""

  def test_setitem_offset(self):
    a = Tensor.full((16,), 0.).contiguous().realize()
    b = Tensor.full((16,), 1.).contiguous().realize()
    a_view = a[4:].reshape(3, 4).shrink(((0,2),(0,2))).reshape((4,))
    b.shrink(((0,4),)).assign(a_view).realize()
    # the first four entries of b now come from a (zeros); the other twelve keep their 1.0
    self.assertListEqual(b.tolist(), [0.0]*4 + [1.0]*12)
|
|
|
|
class TestFusionOp(unittest.TestCase):
  """Small end-to-end checks on permute/contiguous and expand/sum expressions."""

  def test_contiguous_add(self):
    # the result must be the same whether or not the permuted tensor is made contiguous first
    def run(contig=False):
      bt = Tensor(np.arange(16), dtype=dtypes.float32).reshape(4,4)
      x = bt.permute(1,0)
      if contig: x = x.contiguous()
      return (x.permute(1,0) + bt).data()
    assert run() == run(True)

  def test_expand_fuse(self):
    bt = Tensor(np.ones((10, 1)), dtype=dtypes.float32)
    out = (bt*2).expand(10,10).sum(1)
    linear, var_vals = out.linear_with_vars()
    run_linear(linear, var_vals)
    # each row is ten copies of 1*2, so every sum is 20
    outd = out.tolist()
    assert all(x == 20.0 for x in outd)
|
|
|
|
# when executed as a script, run this module's tests with verbose output
if __name__ == '__main__':
  unittest.main(verbosity=2)
|