Make `Tensor.full` with `buffer=False` deviceless (#16483)

This affects `arange` and `eye`; the affected tests now append `.clone()` to those tensors.
This commit is contained in:
chenyu 2026-06-03 12:35:59 -04:00 committed by GitHub
commit 8a4203638a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 39 additions and 47 deletions

View file

@ -569,8 +569,8 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
"aten.__rshift__.Scalar": lambda x,y: x>>y,
"aten.__irshift__.Scalar": lambda x,y: x>>y,
# inplace ops using replace for fusion
"aten.zero_": lambda x: x.const_like(0),
"aten.fill_.Scalar": lambda x, y: x.const_like(y),
"aten.zero_": lambda x: Tensor.full(x.shape, 0, dtype=x.dtype, device=x.device, buffer=False),
"aten.fill_.Scalar": lambda x, y: Tensor.full(x.shape, y, dtype=x.dtype, device=x.device, buffer=False),
"aten.add_.Tensor": lambda self, other, alpha=1.0: self + other * alpha,
"aten.add_.Scalar": lambda self, other, alpha=1.0: self + other * alpha,
"aten.mul_.Tensor": lambda self, other: self * other,
@ -622,7 +622,7 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
"aten.asinh": Tensor.asinh,
"aten.mul": Tensor.mul,
"aten.atanh": Tensor.atanh,
"aten.fill_.Tensor": lambda self, value: self.const_like(value.reshape(()).item()),
"aten.fill_.Tensor": lambda self, value: Tensor.full(self.shape, value.reshape(()).item(), dtype=self.dtype, device=self.device, buffer=False),
"aten.flip": Tensor.flip,
"aten.scatter_reduce.two": Tensor.scatter_reduce,
"aten.squeeze_.dim": Tensor.squeeze,

View file

@ -16,8 +16,8 @@ class TestArange(unittest.TestCase):
return estimate_uop(linear.src[-1]).ops
def test_arange_complexity(self):
self.assertEqual(self._get_flops(Tensor.arange(256), np.arange(256)), 0)
self.assertEqual(self._get_flops(Tensor.arange(2560), np.arange(2560)), 0)
self.assertEqual(self._get_flops(Tensor.arange(256).clone(), np.arange(256)), 0)
self.assertEqual(self._get_flops(Tensor.arange(2560).clone(), np.arange(2560)), 0)
@unittest.skipIf(Device.DEFAULT == "CL", "flaky in CI")
def test_arange_cumsum(self):
@ -30,7 +30,7 @@ class TestArange(unittest.TestCase):
def test_eye_complexity(self):
with Context(NOOPT=1):
# NOTE: not every backend supports CMPEQ
self.assertLessEqual(self._get_flops(Tensor.eye(2560).contiguous(), np.eye(2560)), 2*2560*2560)
self.assertLessEqual(self._get_flops(Tensor.eye(2560).clone(), np.eye(2560)), 2*2560*2560)
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX indexing is weird")
def test_tri_complexity(self):

View file

@ -311,12 +311,12 @@ class TestLinearizer(unittest.TestCase):
assert len(reg_stores) == 0, "STORE to reg should have been simplified"
assert len([u for u in uops if u.op is Ops.MAX]) <= max_ops, "no unnecessary MAX ops"
helper(Tensor.arange(5.5, (3.5*300), 3.5), max_ops=2)
helper(Tensor.arange(-1, -100, -5), max_ops=2)
helper(Tensor.arange(5.5, (3.5*300), 3.5).clone(), max_ops=2)
helper(Tensor.arange(-1, -100, -5).clone(), max_ops=2)
# NOTE: both of these split the reduce (this just wasn't tracked before)
#helper(Tensor.arange(-3.2, 6.7, 0.64), max_ops=2)
#helper(Tensor.arange(256), max_ops=2)
helper(Tensor.arange(255), max_ops=2)
helper(Tensor.arange(255).clone(), max_ops=2)
@unittest.skip("test implicitly depends on certain optimizations")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")

View file

@ -71,7 +71,7 @@ class TestSchedule(unittest.TestCase):
def test_arange_avgpool2d(self, kcount=1):
x = Tensor.arange(25).reshape(1,1,5,5).cast(dtypes.float32)
t = x.avg_pool2d(padding=1)
t = x.avg_pool2d(padding=1).clone()
linear, var_vals = t.linear_with_vars()
self.assertEqual(len(linear.src), kcount)
run_linear(linear, var_vals)
@ -85,17 +85,17 @@ class TestSchedule(unittest.TestCase):
# when we're fusing a reduce, all ReduceOps must have the same N in the dimensions
# all permutes, reshapes, expands and shrinks push through the reduce
def test_arange_sum(self):
a = Tensor.arange(6).reshape(3, 2).sum(axis=1)
a = Tensor.arange(6).reshape(3, 2).sum(axis=1).clone()
run_linear(*check_schedule(a, 1))
self.assertListEqual(a.tolist(), [1, 5, 9])
def test_arange_sum_alt(self):
a = (Tensor.arange(5).reshape(1,5).expand(6,5)*Tensor(2)).reshape(1,6,5).sum(axis=2)
a = (Tensor.arange(5).reshape(1,5).expand(6,5)*Tensor(2)).reshape(1,6,5).sum(axis=2).clone()
run_linear(*check_schedule(a, 1))
np.testing.assert_equal(a.numpy(), 20)
def test_permute_arange(self):
a = Tensor.arange(6).reshape(6, 1, 1).permute(2, 0, 1).sum(axis=1)
a = Tensor.arange(6).reshape(6, 1, 1).permute(2, 0, 1).sum(axis=1).clone()
run_linear(*check_schedule(a, 1))
self.assertListEqual(a.tolist(), [[15]])
@ -230,7 +230,7 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_allclose(out.numpy(), np.ones((64,64)))
def test_example_matmul_same(self):
x = Tensor.eye(64)
x = Tensor.eye(64).clone().realize()
z = x.matmul(x).sum()
z.backward()
out = x.grad.contiguous()
@ -847,7 +847,7 @@ class TestSchedule(unittest.TestCase):
def test_cast_const_view(self):
a = Tensor.ones((4, 4), dtype=dtypes.float32, buffer=False)
casted_view = a.cast(dtypes.int32)
run_linear(*check_schedule(casted_view, 1))
run_linear(*check_schedule(casted_view, 0))
realized_const_view = casted_view.contiguous()
run_linear(*check_schedule(realized_const_view, 0))
self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])
@ -927,7 +927,7 @@ class TestSchedule(unittest.TestCase):
def test_div_padded_arange(self):
x = Tensor.full((2,2), 16, buffer=False)
y = x.div(Tensor.linspace(2, 8, steps=4, dtype=dtypes.int).reshape(2,2), rounding_mode="trunc").pad(((1,1), (1,1)))
out = y.sum(axis=1)
out = y.sum(axis=1).clone()
run_linear(*check_schedule(out, 1))
self.assertListEqual(out.tolist(), [0, 12, 4, 0])
@ -976,10 +976,10 @@ class TestSchedule(unittest.TestCase):
def test_precompute_freqs_cis(self):
from extra.models.llama import precompute_freqs_cis
args = {"dim":32, "end":2048, "theta":10000}
fused = precompute_freqs_cis(**args)
fused = precompute_freqs_cis(**args).clone()
run_linear(*check_schedule(fused, 1))
if getenv("CHECK", 1):
ref = precompute_freqs_cis(**args)
ref = precompute_freqs_cis(**args).clone()
run_linear(*check_schedule(ref, 1))
np.testing.assert_equal(fused.numpy(), ref.numpy())

View file

@ -68,7 +68,7 @@ class TestFuse(unittest.TestCase):
self._test_fuse(lambda a: a.softmax(axis=-1, dtype='half'), a, atol=3e-4)
def test_fuse_arange_eye(self):
self._test_fuse(lambda: Tensor.arange(10).reshape(10,1).expand(10,10) == Tensor.arange(10).reshape(1,10).expand(10,10))
self._test_fuse(lambda: (Tensor.arange(10).reshape(10,1).expand(10,10) == Tensor.arange(10).reshape(1,10).expand(10,10)).clone())
@unittest.skip("needs RANGEIFY>1")
def test_double_gemm(self):

View file

@ -30,7 +30,7 @@ class TestAttention(unittest.TestCase):
rope_noprune(Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32), v_pos.bind(1))
rope_prune(Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32), v_pos.bind(1))
assert_jit_cache_len(rope_prune, 1)
assert_jit_cache_len(rope_noprune, 3)
assert_jit_cache_len(rope_noprune, 2)
if __name__ == '__main__':
unittest.main()

View file

@ -58,14 +58,14 @@ class TestMultiRamUsage(unittest.TestCase):
X = Tensor.ones(256, buffer=False).realize()
self.assertUsed(0)
X.shard_(devices_4).realize()
self.assertUsed(256 * 4 * 4) # TODO: can be zero
self.assertUsed(0)
def test_sharded_memory_axis_const(self):
devices_4 = tuple(f"NULL:{i+1}" for i in range(4))
X = Tensor.ones(256, buffer=False).realize()
self.assertUsed(0)
X.shard_(devices_4, axis=0).realize()
self.assertUsed(256 * 4) # TODO: can be zero
self.assertUsed(0)
def test_zeros_per_device(self):
_ = Tensor.zeros(self.N, self.N, device="NULL").contiguous().realize()

View file

@ -102,7 +102,7 @@ class TestIdxUpcast(unittest.TestCase):
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)), "PTX and NIR always converts Ops.INDEX to int64")
def test_symfold(self):
# This would cause an overflow, but after sym fold it's within int32
a = Tensor.arange(65535)
a = Tensor.arange(65535).clone()
uops = self._schedule_render(a)
assert all(uop.dtype is not dtypes.long for uop in uops)

View file

@ -323,7 +323,7 @@ class TestKernelOpts(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_arange_opts(self):
a = Tensor.arange(128)
a = Tensor.arange(128).clone()
# NOTE: arange no longer has reduce ops available for opt
helper_linearizer_opt(a, [
#[Opt(OptOps.GROUP, 0, 32)],

View file

@ -83,7 +83,7 @@ class TestLinAlg(unittest.TestCase):
reconstruction_helper([U, s_diag, V], a)
def test_svd_identity_4x4(self):
a = Tensor.eye(4)
a = Tensor.eye(4).clone()
U,S,V = a.svd()
assert not np.isnan(U.numpy()).any()
assert not np.isnan(S.numpy()).any()

View file

@ -9,11 +9,8 @@ class TestSetitemInto(unittest.TestCase):
t[1] = 5
self.assertEqual(GlobalCounters.kernel_count, 0)
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertEqual(GlobalCounters.global_mem, 16)
t[1].realize()
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertEqual(GlobalCounters.kernel_count, 0)
self.assertEqual(GlobalCounters.global_mem, 0)
self.assertListEqual(t.tolist(), [[0, 1], [5, 5]])
def test_setitem_into_unrealized_sliced_compute(self):
@ -25,8 +22,8 @@ class TestSetitemInto(unittest.TestCase):
w[1] = 99
self.assertEqual(GlobalCounters.kernel_count, 0)
w.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertEqual(GlobalCounters.global_mem, 4*4)
self.assertEqual(GlobalCounters.kernel_count, 0)
self.assertEqual(GlobalCounters.global_mem, 0)
self.assertListEqual(w.tolist(), [4, 99, 8, 10])
def test_setitem_into_empty(self):
@ -87,11 +84,8 @@ class TestSetitemInto(unittest.TestCase):
t[1] = 5
self.assertEqual(GlobalCounters.kernel_count, 0)
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertEqual(GlobalCounters.global_mem, 4*4)
t[1].realize()
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertEqual(GlobalCounters.kernel_count, 0)
self.assertEqual(GlobalCounters.global_mem, 0)
self.assertListEqual(t.tolist(), [1, 5, 1, 1])
def test_setitem_into_const_alu(self):
@ -100,11 +94,8 @@ class TestSetitemInto(unittest.TestCase):
t[1] = 5
self.assertEqual(GlobalCounters.kernel_count, 0)
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertEqual(GlobalCounters.global_mem, 4*4)
t[1].realize()
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertEqual(GlobalCounters.kernel_count, 0)
self.assertEqual(GlobalCounters.global_mem, 0)
self.assertListEqual(t.tolist(), [2, 5, 2, 2])
def test_setitem_into_arange(self):
@ -116,7 +107,7 @@ class TestSetitemInto(unittest.TestCase):
t[1] = 5
self.assertEqual(GlobalCounters.kernel_count, 0)
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertEqual(GlobalCounters.kernel_count, 0)
self.assertListEqual(t.tolist(), [0, 5, 2, 3])
def test_setitem_slice_const(self):

View file

@ -40,7 +40,7 @@ class _function(Generic[ReturnType]):
params = get_state_dict((args, kwargs), tensor_type=(Tensor, UOp)).values()
# deduplicate input_uops, keeping the first occurrence index for each unique uop
call_uops: list[UOp] = dedup([u for t in params if not ((u:=(t.uop if isinstance(t, Tensor) else t)).base.op is Ops.CONST and u.device is None)])
call_uops: list[UOp] = dedup([u for t in params if (u:=(t.uop if isinstance(t, Tensor) else t)).device is not None])
# disable realize/schedule while this is running
# run it and do surgery later

View file

@ -42,7 +42,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
new_shape = argfix(shape)
dt = to_dtype(dtype) if dtype is not None else None
if isinstance(fill_value, UOp): val = cls.const(dt or fill_value.dtype, fill_value)
else: val = cls.const(dt or dtypes.from_py(fill_value), fill_value, None if buffer else canonicalize_device(device))
else: val = cls.const(dt or dtypes.from_py(fill_value), fill_value, canonicalize_device(device) if device is not None else None)
val = val.reshape((1,)*len(new_shape)).expand(new_shape)
return val.clone(device=device) if buffer else val

View file

@ -6,6 +6,7 @@ from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, g
from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored, Context, SPEC
from tinygrad.device import canonicalize_device
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.SLICE,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
@ -71,8 +72,8 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
else:
# the Bufferize before a COPY is not removable. there should be a better way to do this
removable = x.op is not Ops.COPY and s.op not in ALWAYS_CONTIGUOUS
# None in the device assigns it a number later
opts = BufferizeOpts(device=s.device, removable=removable) if len(ctx.range_map[s][1]) == len(realized_ranges) else \
# LOCAL: None in the device assigns it a number later
opts = BufferizeOpts(device=canonicalize_device(s.device), removable=removable) if len(ctx.range_map[s][1]) == len(realized_ranges) else \
BufferizeOpts(device=s.device, addrspace=AddrSpace.LOCAL, removable=removable)
new_src = UOp(Ops.STAGE, s.dtype, src=(new_src,)+closed_ranges, arg=opts)
if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])