mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
405866f2b7
commit
8a4203638a
14 changed files with 39 additions and 47 deletions
|
|
@ -569,8 +569,8 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
|
|||
"aten.__rshift__.Scalar": lambda x,y: x>>y,
|
||||
"aten.__irshift__.Scalar": lambda x,y: x>>y,
|
||||
# inplace ops using replace for fusion
|
||||
"aten.zero_": lambda x: x.const_like(0),
|
||||
"aten.fill_.Scalar": lambda x, y: x.const_like(y),
|
||||
"aten.zero_": lambda x: Tensor.full(x.shape, 0, dtype=x.dtype, device=x.device, buffer=False),
|
||||
"aten.fill_.Scalar": lambda x, y: Tensor.full(x.shape, y, dtype=x.dtype, device=x.device, buffer=False),
|
||||
"aten.add_.Tensor": lambda self, other, alpha=1.0: self + other * alpha,
|
||||
"aten.add_.Scalar": lambda self, other, alpha=1.0: self + other * alpha,
|
||||
"aten.mul_.Tensor": lambda self, other: self * other,
|
||||
|
|
@ -622,7 +622,7 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
|
|||
"aten.asinh": Tensor.asinh,
|
||||
"aten.mul": Tensor.mul,
|
||||
"aten.atanh": Tensor.atanh,
|
||||
"aten.fill_.Tensor": lambda self, value: self.const_like(value.reshape(()).item()),
|
||||
"aten.fill_.Tensor": lambda self, value: Tensor.full(self.shape, value.reshape(()).item(), dtype=self.dtype, device=self.device, buffer=False),
|
||||
"aten.flip": Tensor.flip,
|
||||
"aten.scatter_reduce.two": Tensor.scatter_reduce,
|
||||
"aten.squeeze_.dim": Tensor.squeeze,
|
||||
|
|
|
|||
|
|
@ -16,8 +16,8 @@ class TestArange(unittest.TestCase):
|
|||
return estimate_uop(linear.src[-1]).ops
|
||||
|
||||
def test_arange_complexity(self):
|
||||
self.assertEqual(self._get_flops(Tensor.arange(256), np.arange(256)), 0)
|
||||
self.assertEqual(self._get_flops(Tensor.arange(2560), np.arange(2560)), 0)
|
||||
self.assertEqual(self._get_flops(Tensor.arange(256).clone(), np.arange(256)), 0)
|
||||
self.assertEqual(self._get_flops(Tensor.arange(2560).clone(), np.arange(2560)), 0)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "CL", "flaky in CI")
|
||||
def test_arange_cumsum(self):
|
||||
|
|
@ -30,7 +30,7 @@ class TestArange(unittest.TestCase):
|
|||
def test_eye_complexity(self):
|
||||
with Context(NOOPT=1):
|
||||
# NOTE: not every backend supports CMPEQ
|
||||
self.assertLessEqual(self._get_flops(Tensor.eye(2560).contiguous(), np.eye(2560)), 2*2560*2560)
|
||||
self.assertLessEqual(self._get_flops(Tensor.eye(2560).clone(), np.eye(2560)), 2*2560*2560)
|
||||
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX indexing is weird")
|
||||
def test_tri_complexity(self):
|
||||
|
|
|
|||
|
|
@ -311,12 +311,12 @@ class TestLinearizer(unittest.TestCase):
|
|||
assert len(reg_stores) == 0, "STORE to reg should have been simplified"
|
||||
assert len([u for u in uops if u.op is Ops.MAX]) <= max_ops, "no unnecessary MAX ops"
|
||||
|
||||
helper(Tensor.arange(5.5, (3.5*300), 3.5), max_ops=2)
|
||||
helper(Tensor.arange(-1, -100, -5), max_ops=2)
|
||||
helper(Tensor.arange(5.5, (3.5*300), 3.5).clone(), max_ops=2)
|
||||
helper(Tensor.arange(-1, -100, -5).clone(), max_ops=2)
|
||||
# NOTE: both of these split the reduce (this just wasn't tracked before)
|
||||
#helper(Tensor.arange(-3.2, 6.7, 0.64), max_ops=2)
|
||||
#helper(Tensor.arange(256), max_ops=2)
|
||||
helper(Tensor.arange(255), max_ops=2)
|
||||
helper(Tensor.arange(255).clone(), max_ops=2)
|
||||
|
||||
@unittest.skip("test implicitly depends on certain optimizations")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ class TestSchedule(unittest.TestCase):
|
|||
|
||||
def test_arange_avgpool2d(self, kcount=1):
|
||||
x = Tensor.arange(25).reshape(1,1,5,5).cast(dtypes.float32)
|
||||
t = x.avg_pool2d(padding=1)
|
||||
t = x.avg_pool2d(padding=1).clone()
|
||||
linear, var_vals = t.linear_with_vars()
|
||||
self.assertEqual(len(linear.src), kcount)
|
||||
run_linear(linear, var_vals)
|
||||
|
|
@ -85,17 +85,17 @@ class TestSchedule(unittest.TestCase):
|
|||
# when we're fusing a reduce, all ReduceOps must have the same N in the dimensions
|
||||
# all permutes, reshapes, expands and shrinks push through the reduce
|
||||
def test_arange_sum(self):
|
||||
a = Tensor.arange(6).reshape(3, 2).sum(axis=1)
|
||||
a = Tensor.arange(6).reshape(3, 2).sum(axis=1).clone()
|
||||
run_linear(*check_schedule(a, 1))
|
||||
self.assertListEqual(a.tolist(), [1, 5, 9])
|
||||
|
||||
def test_arange_sum_alt(self):
|
||||
a = (Tensor.arange(5).reshape(1,5).expand(6,5)*Tensor(2)).reshape(1,6,5).sum(axis=2)
|
||||
a = (Tensor.arange(5).reshape(1,5).expand(6,5)*Tensor(2)).reshape(1,6,5).sum(axis=2).clone()
|
||||
run_linear(*check_schedule(a, 1))
|
||||
np.testing.assert_equal(a.numpy(), 20)
|
||||
|
||||
def test_permute_arange(self):
|
||||
a = Tensor.arange(6).reshape(6, 1, 1).permute(2, 0, 1).sum(axis=1)
|
||||
a = Tensor.arange(6).reshape(6, 1, 1).permute(2, 0, 1).sum(axis=1).clone()
|
||||
run_linear(*check_schedule(a, 1))
|
||||
self.assertListEqual(a.tolist(), [[15]])
|
||||
|
||||
|
|
@ -230,7 +230,7 @@ class TestSchedule(unittest.TestCase):
|
|||
np.testing.assert_allclose(out.numpy(), np.ones((64,64)))
|
||||
|
||||
def test_example_matmul_same(self):
|
||||
x = Tensor.eye(64)
|
||||
x = Tensor.eye(64).clone().realize()
|
||||
z = x.matmul(x).sum()
|
||||
z.backward()
|
||||
out = x.grad.contiguous()
|
||||
|
|
@ -847,7 +847,7 @@ class TestSchedule(unittest.TestCase):
|
|||
def test_cast_const_view(self):
|
||||
a = Tensor.ones((4, 4), dtype=dtypes.float32, buffer=False)
|
||||
casted_view = a.cast(dtypes.int32)
|
||||
run_linear(*check_schedule(casted_view, 1))
|
||||
run_linear(*check_schedule(casted_view, 0))
|
||||
realized_const_view = casted_view.contiguous()
|
||||
run_linear(*check_schedule(realized_const_view, 0))
|
||||
self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])
|
||||
|
|
@ -927,7 +927,7 @@ class TestSchedule(unittest.TestCase):
|
|||
def test_div_padded_arange(self):
|
||||
x = Tensor.full((2,2), 16, buffer=False)
|
||||
y = x.div(Tensor.linspace(2, 8, steps=4, dtype=dtypes.int).reshape(2,2), rounding_mode="trunc").pad(((1,1), (1,1)))
|
||||
out = y.sum(axis=1)
|
||||
out = y.sum(axis=1).clone()
|
||||
run_linear(*check_schedule(out, 1))
|
||||
self.assertListEqual(out.tolist(), [0, 12, 4, 0])
|
||||
|
||||
|
|
@ -976,10 +976,10 @@ class TestSchedule(unittest.TestCase):
|
|||
def test_precompute_freqs_cis(self):
|
||||
from extra.models.llama import precompute_freqs_cis
|
||||
args = {"dim":32, "end":2048, "theta":10000}
|
||||
fused = precompute_freqs_cis(**args)
|
||||
fused = precompute_freqs_cis(**args).clone()
|
||||
run_linear(*check_schedule(fused, 1))
|
||||
if getenv("CHECK", 1):
|
||||
ref = precompute_freqs_cis(**args)
|
||||
ref = precompute_freqs_cis(**args).clone()
|
||||
run_linear(*check_schedule(ref, 1))
|
||||
np.testing.assert_equal(fused.numpy(), ref.numpy())
|
||||
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ class TestFuse(unittest.TestCase):
|
|||
self._test_fuse(lambda a: a.softmax(axis=-1, dtype='half'), a, atol=3e-4)
|
||||
|
||||
def test_fuse_arange_eye(self):
|
||||
self._test_fuse(lambda: Tensor.arange(10).reshape(10,1).expand(10,10) == Tensor.arange(10).reshape(1,10).expand(10,10))
|
||||
self._test_fuse(lambda: (Tensor.arange(10).reshape(10,1).expand(10,10) == Tensor.arange(10).reshape(1,10).expand(10,10)).clone())
|
||||
|
||||
@unittest.skip("needs RANGEIFY>1")
|
||||
def test_double_gemm(self):
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class TestAttention(unittest.TestCase):
|
|||
rope_noprune(Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32), v_pos.bind(1))
|
||||
rope_prune(Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32), v_pos.bind(1))
|
||||
assert_jit_cache_len(rope_prune, 1)
|
||||
assert_jit_cache_len(rope_noprune, 3)
|
||||
assert_jit_cache_len(rope_noprune, 2)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -58,14 +58,14 @@ class TestMultiRamUsage(unittest.TestCase):
|
|||
X = Tensor.ones(256, buffer=False).realize()
|
||||
self.assertUsed(0)
|
||||
X.shard_(devices_4).realize()
|
||||
self.assertUsed(256 * 4 * 4) # TODO: can be zero
|
||||
self.assertUsed(0)
|
||||
|
||||
def test_sharded_memory_axis_const(self):
|
||||
devices_4 = tuple(f"NULL:{i+1}" for i in range(4))
|
||||
X = Tensor.ones(256, buffer=False).realize()
|
||||
self.assertUsed(0)
|
||||
X.shard_(devices_4, axis=0).realize()
|
||||
self.assertUsed(256 * 4) # TODO: can be zero
|
||||
self.assertUsed(0)
|
||||
|
||||
def test_zeros_per_device(self):
|
||||
_ = Tensor.zeros(self.N, self.N, device="NULL").contiguous().realize()
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ class TestIdxUpcast(unittest.TestCase):
|
|||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)), "PTX and NIR always converts Ops.INDEX to int64")
|
||||
def test_symfold(self):
|
||||
# This would cause an overflow, but after sym fold it's within int32
|
||||
a = Tensor.arange(65535)
|
||||
a = Tensor.arange(65535).clone()
|
||||
uops = self._schedule_render(a)
|
||||
assert all(uop.dtype is not dtypes.long for uop in uops)
|
||||
|
||||
|
|
|
|||
|
|
@ -323,7 +323,7 @@ class TestKernelOpts(unittest.TestCase):
|
|||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
|
||||
def test_arange_opts(self):
|
||||
a = Tensor.arange(128)
|
||||
a = Tensor.arange(128).clone()
|
||||
# NOTE: arange no longer has reduce ops available for opt
|
||||
helper_linearizer_opt(a, [
|
||||
#[Opt(OptOps.GROUP, 0, 32)],
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ class TestLinAlg(unittest.TestCase):
|
|||
reconstruction_helper([U, s_diag, V], a)
|
||||
|
||||
def test_svd_identity_4x4(self):
|
||||
a = Tensor.eye(4)
|
||||
a = Tensor.eye(4).clone()
|
||||
U,S,V = a.svd()
|
||||
assert not np.isnan(U.numpy()).any()
|
||||
assert not np.isnan(S.numpy()).any()
|
||||
|
|
|
|||
|
|
@ -9,11 +9,8 @@ class TestSetitemInto(unittest.TestCase):
|
|||
t[1] = 5
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(GlobalCounters.global_mem, 16)
|
||||
t[1].realize()
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
self.assertEqual(GlobalCounters.global_mem, 0)
|
||||
self.assertListEqual(t.tolist(), [[0, 1], [5, 5]])
|
||||
|
||||
def test_setitem_into_unrealized_sliced_compute(self):
|
||||
|
|
@ -25,8 +22,8 @@ class TestSetitemInto(unittest.TestCase):
|
|||
w[1] = 99
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
w.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(GlobalCounters.global_mem, 4*4)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
self.assertEqual(GlobalCounters.global_mem, 0)
|
||||
self.assertListEqual(w.tolist(), [4, 99, 8, 10])
|
||||
|
||||
def test_setitem_into_empty(self):
|
||||
|
|
@ -87,11 +84,8 @@ class TestSetitemInto(unittest.TestCase):
|
|||
t[1] = 5
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(GlobalCounters.global_mem, 4*4)
|
||||
t[1].realize()
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
self.assertEqual(GlobalCounters.global_mem, 0)
|
||||
self.assertListEqual(t.tolist(), [1, 5, 1, 1])
|
||||
|
||||
def test_setitem_into_const_alu(self):
|
||||
|
|
@ -100,11 +94,8 @@ class TestSetitemInto(unittest.TestCase):
|
|||
t[1] = 5
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(GlobalCounters.global_mem, 4*4)
|
||||
t[1].realize()
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
self.assertEqual(GlobalCounters.global_mem, 0)
|
||||
self.assertListEqual(t.tolist(), [2, 5, 2, 2])
|
||||
|
||||
def test_setitem_into_arange(self):
|
||||
|
|
@ -116,7 +107,7 @@ class TestSetitemInto(unittest.TestCase):
|
|||
t[1] = 5
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
self.assertListEqual(t.tolist(), [0, 5, 2, 3])
|
||||
|
||||
def test_setitem_slice_const(self):
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class _function(Generic[ReturnType]):
|
|||
params = get_state_dict((args, kwargs), tensor_type=(Tensor, UOp)).values()
|
||||
|
||||
# deduplicate input_uops, keeping the first occurrence index for each unique uop
|
||||
call_uops: list[UOp] = dedup([u for t in params if not ((u:=(t.uop if isinstance(t, Tensor) else t)).base.op is Ops.CONST and u.device is None)])
|
||||
call_uops: list[UOp] = dedup([u for t in params if (u:=(t.uop if isinstance(t, Tensor) else t)).device is not None])
|
||||
|
||||
# disable realize/schedule while this is running
|
||||
# run it and do surgery later
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
|||
new_shape = argfix(shape)
|
||||
dt = to_dtype(dtype) if dtype is not None else None
|
||||
if isinstance(fill_value, UOp): val = cls.const(dt or fill_value.dtype, fill_value)
|
||||
else: val = cls.const(dt or dtypes.from_py(fill_value), fill_value, None if buffer else canonicalize_device(device))
|
||||
else: val = cls.const(dt or dtypes.from_py(fill_value), fill_value, canonicalize_device(device) if device is not None else None)
|
||||
val = val.reshape((1,)*len(new_shape)).expand(new_shape)
|
||||
return val.clone(device=device) if buffer else val
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, g
|
|||
from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink
|
||||
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
|
||||
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored, Context, SPEC
|
||||
from tinygrad.device import canonicalize_device
|
||||
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.SLICE,
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
|
||||
|
|
@ -71,8 +72,8 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
|||
else:
|
||||
# the Bufferize before a COPY is not removable. there should be a better way to do this
|
||||
removable = x.op is not Ops.COPY and s.op not in ALWAYS_CONTIGUOUS
|
||||
# None in the device assigns it a number later
|
||||
opts = BufferizeOpts(device=s.device, removable=removable) if len(ctx.range_map[s][1]) == len(realized_ranges) else \
|
||||
# LOCAL: None in the device assigns it a number later
|
||||
opts = BufferizeOpts(device=canonicalize_device(s.device), removable=removable) if len(ctx.range_map[s][1]) == len(realized_ranges) else \
|
||||
BufferizeOpts(device=s.device, addrspace=AddrSpace.LOCAL, removable=removable)
|
||||
new_src = UOp(Ops.STAGE, s.dtype, src=(new_src,)+closed_ranges, arg=opts)
|
||||
if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue