Merge branch 'master' into shrink_in_render

This commit is contained in:
George Hotz 2026-06-01 14:43:17 -07:00 committed by GitHub
commit 46541d70f4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 95 additions and 81 deletions

View file

@ -25,10 +25,12 @@ def calculate_storage_offset(x: Tensor) -> int:
u_strides = strides_for_shape(u.src[0].shape)
for i, (start, _) in enumerate(u.marg): offset += start * u_strides[i]
return offset
def wrap(x: Tensor) -> torch.Tensor:
def wrap(x: Tensor, dev: torch.device|None=None) -> torch.Tensor:
x._strides = strides_for_shape(x.shape) # always recalculate
if (not hasattr(x, '_storage_offset')) or (not x.uop.is_realized): x._storage_offset = calculate_storage_offset(x)
return mod.wrap(x, _to_torch_dtype(x.dtype), _to_torch_device(x.device).index)
# a deviceless tinygrad value takes the device from the op context
idx = _to_torch_device(x.device).index if x.device is not None else (dev.index if dev is not None else 0)
return mod.wrap(x, _to_torch_dtype(x.dtype), idx)
def _update_torch_metadata(tensor: torch.Tensor, tiny: Tensor) -> None:
tiny._strides = strides_for_shape(tiny.shape)
tiny._storage_offset = calculate_storage_offset(tiny)
@ -545,14 +547,17 @@ def wrap_out(f):
assigned = f(*args, **kwargs)
if getenv("ALLOW_DTYPE_MISMATCH", 1): assigned = assigned.cast(out.dtype)
assert out.shape == assigned.shape, f"shape mismatch: {assigned.shape} -> {out.shape}"
assert out.device == assigned.device, f"device mismatch: {assigned.device} -> {out.device}"
assert out.device == assigned.device or out.device is None or assigned.device is None, f"device mismatch: {assigned.device} -> {out.device}"
assert out.dtype == assigned.dtype, f"dtype mismatch: {assigned.dtype} -> {out.dtype}"
if out.device is None and assigned.device is not None: out.replace(out.empty_like(device=assigned.device))
return out.assign(assigned)
return _wrap_out
def _inplace_op(t, new_value):
if not hasattr(t, "_view_base") and not getattr(canonical_base(t), "_views", set()): t.replace(new_value)
else: _apply_inplace(t, new_value)
else:
if (base:=canonical_base(t)).device is None and new_value.device is not None: base.replace(base.empty_like(device=new_value.device))
_apply_inplace(t, new_value)
return t
tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
@ -679,10 +684,11 @@ def wrap_fxn(k,f):
if TORCH_DEBUG:
print(k, len(args), [x.shape if isinstance(x, torch.Tensor) else x for x in args],
{k:v.shape if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()})
dev = next((a.device for a in args if isinstance(a, torch.Tensor) and a.device.type == "tiny"), None)
args, kwargs = unwrap_args(args, kwargs)
out = f(*args, **kwargs)
if isinstance(out, Tensor): return wrap(out)
elif isinstance(out, tuple): return tuple(wrap(x) for x in out)
if isinstance(out, Tensor): return wrap(out, dev)
elif isinstance(out, tuple): return tuple(wrap(x, dev) for x in out)
else: raise RuntimeError(f"unknown output type {type(out)}")
return nf

View file

@ -22,30 +22,30 @@ def _test_uop_result(inputs:list[Tensor], sink:UOp, local_size=None):
def _setup_and_test_alu(alu_op:Ops, input_val:ConstType, *alu_src_uops:UOp):
dtype = alu_src_uops[0].dtype
a = UOp.param(0, dtype.ptr())
b = UOp.param(1, dtype.ptr())
a = UOp.param(0, dtype.ptr(1))
b = UOp.param(1, dtype.ptr(1))
idx = UOp.const(dtypes.int, 0)
ld = b.index(idx)
ld = b.index(idx, ptr=True).load()
alu = ld.alu(alu_op, *alu_src_uops)
store = UOp.store(a.index(idx), alu)
store = UOp.store(a.index(idx, ptr=True), alu)
return _test_uop_result([Tensor([input_val])], UOp(Ops.SINK, dtypes.void, (store,), arg=KernelInfo()))[0]
class TestRendererFailures(unittest.TestCase):
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer")
def test_gated_store_with_alu(self):
a = UOp.param(0, dtypes.int.ptr())
a = UOp.param(0, dtypes.int.ptr(4))
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0.valid(gate_alu)), UOp.const(dtypes.int, 1)))
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0.valid(gate_alu), ptr=True), UOp.const(dtypes.int, 1)))
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,), arg=KernelInfo())
ret = _test_uop_result([], sink, local_size=[4, 1, 1])[0]
np.testing.assert_equal(ret, [0, 1, 1, 1])
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer")
def test_gated_store_with_alu_2d(self):
a = UOp.param(0, dtypes.int.ptr())
a = UOp.param(0, dtypes.int.ptr(8))
gate_alu_0 = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 2),), 'lidx1')).ne(0)
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index((lidx0+lidx1*4).valid(gate_alu_0&gate_alu_1)), UOp.const(dtypes.int, 1)))
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index((lidx0+lidx1*4).valid(gate_alu_0&gate_alu_1), ptr=True), UOp.const(dtypes.int, 1)))
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,), arg=KernelInfo())
ret = _test_uop_result([], sink, local_size=[4, 2, 1])[0]
np.testing.assert_equal(ret, [0, 0, 0, 0, 0, 1, 1, 1])

View file

@ -27,8 +27,8 @@ def uop(uops:list[UOp], op:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:
def _test_single_value(vals, op, dts):
uops = []
output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1]
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(), (), 0)
buf_loads = [uop(uops, Ops.PARAM, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)]
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(1), (), 0)
buf_loads = [uop(uops, Ops.PARAM, dtype.ptr(1), (), i+1) for i,dtype in enumerate(dts)]
loads = (buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0)) for i, dtype in enumerate(dts))
alu = uop(uops, op, output_dtype, loads)
out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True), alu))
@ -42,7 +42,7 @@ def _test_single_value(vals, op, dts):
def _test_single_value_const(vals, op, dts):
uops = []
output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1]
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(), (), 0)
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(1), (), 0)
loads = (uop(uops, Ops.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
alu = uop(uops, op, output_dtype, loads)
out = buf_store[UOp.const(dtypes.int32, 0)].store(alu)
@ -54,7 +54,7 @@ def _test_single_value_const(vals, op, dts):
def _test_uops_result(output_dtype, uops, res):
# uops = []
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(), (), 0)
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(1), (), 0)
# res = output_fn(uops)
out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), res))
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
@ -221,8 +221,8 @@ class TestLocalAccess(unittest.TestCase):
@unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "This only tests assembly backends")
class TestAssembly(unittest.TestCase):
def test_bitshift_left(self):
g1 = UOp.param(0, dtypes.int32.ptr())
out = UOp.param(1, dtypes.int32.ptr())
g1 = UOp.param(0, dtypes.int32.ptr(3))
out = UOp.param(1, dtypes.int32.ptr(2))
c1 = UOp.const(dtypes.int, 2)
c2 = UOp.const(dtypes.int, 3)
l1 = g1.index(c1)
@ -249,7 +249,7 @@ class TestAssembly(unittest.TestCase):
self.assertGreaterEqual(len([x.op for x in uops if x.op is Ops.MULACC]), 4)
def test_mulacc_shl(self):
g1 = UOp.param(0, dtypes.int32.ptr())
g1 = UOp.param(0, dtypes.int32.ptr(2))
c1 = UOp.const(dtypes.int, 0)
c2 = UOp.const(dtypes.int, 1)
expr = g1.index(c1) * UOp.const(dtypes.int, 4096) + g1.index(c2)
@ -258,7 +258,7 @@ class TestAssembly(unittest.TestCase):
self.assertIn(Ops.MULACC, [x.op for x in uops])
def test_use_cmpeq(self):
g = UOp.param(0, dtypes.uint32.ptr())
g = UOp.param(0, dtypes.uint32.ptr(8))
c = UOp.const(dtypes.uint, 7)
comp = g.index(c).ne(c).ne(True)
uops = to_uops_list([comp], ren=Device[Device.DEFAULT].renderer)

View file

@ -82,8 +82,8 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None, vals:tuple
for buf_dt, data in inputs or []:
bufs.append(buf:=allocator.alloc(len(data) * buf_dt.itemsize))
allocator._copyin(buf, memoryview(struct.pack(str(len(data)) + (buf_dt.fmt or ""), *data)))
g = UOp.param(0, uop.dtype.ptr())
prg = to_program(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(arg=KernelInfo()), PythonRenderer(Target("PYTHON")))
g = UOp.param(0, uop.dtype.ptr(1))
prg = to_program(UOp.store(g.index(UOp.const(dtypes.int, 0), ptr=True), uop).sink(arg=KernelInfo()), PythonRenderer(Target("PYTHON")))
prog = PythonProgram("run", PythonCompiler().compile(prg.src[3].arg))
prog(out_buf:=allocator.alloc(uop.dtype.itemsize), *bufs, vals=vals)
return out_buf.cast(uop.dtype.fmt or "").tolist()[0]

View file

@ -968,7 +968,7 @@ class TestSchedule(unittest.TestCase):
def test_const_schedule_contig(self):
constv = Tensor.empty(2, 2).uop.const_like(10).contiguous()
check_schedule(constv, 1)
check_schedule(constv, 0)
def test_advanced_simple_indexing_combined(self):
X = Tensor.arange(16).reshape(4, 4)

View file

@ -10,8 +10,8 @@ class TestTranscendentalFunctions(unittest.TestCase):
def test_payne_hanek_reduction(self):
# TODO: Test constant input when constant folding is fixed (or maybe test both variants)
# Load input value from a buffer to prevent constant folding
input_buf = UOp.param(1, dtypes.double.ptr())
loaded_value = input_buf.index(UOp.const(dtypes.int, 0))
input_buf = UOp.param(1, dtypes.double.ptr(1))
loaded_value = input_buf.index(UOp.const(dtypes.int, 0), ptr=True).load()
def eval_payne_hanek_reduction(v:float) -> tuple[float, int]:
return tuple(eval_uop(u, [(dtypes.float64, [v])]) for u in payne_hanek_reduction(loaded_value))

View file

@ -380,22 +380,22 @@ class TestUOpGraph(unittest.TestCase):
self.assertEqual(uops[-2], wmma) # -2 to skip SINK
def test_cast_alu_fold(self):
d0 = UOp.param(0, dtypes.bool.ptr())
d1 = UOp.param(1, dtypes.int.ptr())
d0 = UOp.param(0, dtypes.bool.ptr(1))
d1 = UOp.param(1, dtypes.int.ptr(1))
idx = UOp.const(dtypes.int, 0)
ld = d1.index(idx)
alu = (ld<1).cast(dtypes.bool)
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu))
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx, ptr=True), alu))
uops = to_uops_list([out])
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 0)
def test_double_cast_fold(self):
d0 = UOp.param(0, dtypes.float.ptr())
d1 = UOp.param(1, dtypes.int.ptr())
d0 = UOp.param(0, dtypes.float.ptr(1))
d1 = UOp.param(1, dtypes.int.ptr(1))
idx = UOp.const(dtypes.int, 0)
ld = d1.index(idx)
alu = ld.cast(dtypes.float).cast(dtypes.float)
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu))
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx, ptr=True), alu))
uops = to_uops_list([out])
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1)
@ -414,7 +414,7 @@ class TestUOpGraph(unittest.TestCase):
def test_bitcast_to_same_dtype_fold(self):
for dt in dtypes.ints + dtypes.floats + (dtypes.bool,):
d0 = UOp.param(0, dt.ptr())
d0 = UOp.param(0, dt.ptr(1))
v = d0.index(UOp.const(dtypes.int, 0))
uops = to_uops_list([v.bitcast(dt)])
self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST and x.dtype is dt]), 0, f"dtype = {dt}")
@ -427,18 +427,18 @@ class TestUOpGraph(unittest.TestCase):
def test_where_on_gated_load_fold(self):
ridx0 = UOp.range(100, 0)
d0 = UOp.param(0, dtypes.long.ptr())
d0 = UOp.param(0, dtypes.long.ptr(100))
ld = d0.index(ridx0.valid(ridx0<50))
w = (ridx0<50).where(ld, 5)
out = UOp.param(1, dtypes.long.ptr())
uops = to_uops_list([out.index(ridx0).store(w)])
out = UOp.param(1, dtypes.long.ptr(100))
uops = to_uops_list([out.index(ridx0, ptr=True).store(w)])
for u in uops:
assert u.op is not Ops.WHERE
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg==5
def test_where_on_gated_load_folds_swapped_branches(self):
ridx0 = UOp.range(100, 0)
d0 = UOp.param(0, dtypes.long.ptr())
d0 = UOp.param(0, dtypes.long.ptr(100))
ld = d0.index(ridx0.valid((ridx0<50).logical_not()))
w = (ridx0<50).where(5, ld)
uops = to_uops_list([w])
@ -448,40 +448,40 @@ class TestUOpGraph(unittest.TestCase):
def test_where_on_gated_load_with_cast(self):
ridx0 = UOp.range(100, 0)
d0 = UOp.param(0, dtypes.int.ptr())
d0 = UOp.param(0, dtypes.int.ptr(100))
gate_idx = ridx0.valid((ridx0<50))
ld = d0.index(gate_idx).cast(dtypes.float)
w = (ridx0<50).where(ld, 5.0)
out = UOp.param(1, dtypes.float.ptr())
uops = to_uops_list([out.index(ridx0).store(w)])
out = UOp.param(1, dtypes.float.ptr(100))
uops = to_uops_list([out.index(ridx0, ptr=True).store(w)])
for u in uops:
assert u.op is not Ops.WHERE
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg == 5
def test_where_on_casted_gated_load_extra_cond(self):
ridx0 = UOp.range(100, 0)
d0 = UOp.param(0, dtypes.float.ptr())
d0 = UOp.param(0, dtypes.float.ptr(100))
ld = d0.index(ridx0.valid(ridx0<50))
w = ((ridx0<50) & (ridx0>30)).where(ld, UOp.const(dtypes.float, 0)).cast(dtypes.half)
out = UOp.param(1, dtypes.half.ptr())
uops = to_uops_list([out.index(ridx0).store(w)])
out = UOp.param(1, dtypes.half.ptr(100))
uops = to_uops_list([out.index(ridx0, ptr=True).store(w)])
for u in uops:
assert u.op is not Ops.WHERE
def test_where_on_casted_gated_load_extra_cond_swapped(self):
ridx0 = UOp.range(100, 0)
d0 = UOp.param(0, dtypes.float.ptr())
d0 = UOp.param(0, dtypes.float.ptr(100))
ld = d0.index(ridx0.valid(ridx0<50))
w = ((ridx0<50) & (ridx0>30)).where(UOp.const(dtypes.float, 0), ld).cast(dtypes.half)
out = UOp.param(1, dtypes.half.ptr())
uops = to_uops_list([out.index(ridx0).store(w)])
out = UOp.param(1, dtypes.half.ptr(100))
uops = to_uops_list([out.index(ridx0, ptr=True).store(w)])
for u in uops:
assert u.op is not Ops.WHERE
def test_where_in_store_becomes_gate(self):
ridx0 = UOp.range(100, 0)
d0 = UOp.param(0, dtypes.long.ptr())
idx = d0.index(ridx0)
d0 = UOp.param(0, dtypes.long.ptr(100))
idx = d0.index(ridx0, ptr=True)
ld = idx.load()
val = (ridx0<50).where(5, ld)
st = idx.store(val).end(ridx0)
@ -529,33 +529,33 @@ class TestUOpGraph(unittest.TestCase):
self.assertNotEqual(u.dtype, dtypes.long)
def test_fold_gated_load(self):
glbl0 = UOp.param(0, dtypes.int.ptr())
glbl1 = UOp.param(1, dtypes.int.ptr())
glbl2 = UOp.param(2, dtypes.int.ptr())
glbl0 = UOp.param(0, dtypes.int.ptr(1))
glbl1 = UOp.param(1, dtypes.int.ptr(1))
glbl2 = UOp.param(2, dtypes.int.ptr(1))
idx = UOp.const(dtypes.int, 0)
ld0 = glbl1.index(UOp.invalid())
ld1 = glbl2.index(idx.valid(UOp.const(dtypes.bool, True)))
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx, ptr=True), ld1+ld0))])
ld0 = uops[-2].src[-1] # -2 to skip SINK
# the gate and invalid value are deleted from ld1
self.assertEqual(ld0, UOp.load(glbl2.index(idx, ptr=True), dtype=dtypes.int))
def test_fold_gated_load_local(self):
glbl0 = UOp.param(0, dtypes.int.ptr())
glbl0 = UOp.param(0, dtypes.int.ptr(16))
smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, addrspace=AddrSpace.LOCAL), (), "temp")
lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx, ptr=True), glbl0.index(lidx, ptr=True).load()))
barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
ld0 = smem.after(barrier).index(UOp.invalid())
ld1 = smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True)))
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))])
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx, ptr=True), ld1+ld0))])
ld0 = uops[-2].src[-1] # -2 to skip SINK
# the gate and invalid value are deleted from ld1
self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2, ptr=True))
def test_fold_gated_store(self):
glbl = UOp.param(0, dtypes.int.ptr())
glbl = UOp.param(0, dtypes.int.ptr(1))
idx0 = UOp.const(dtypes.int, 0)
idx1 = UOp.const(dtypes.int, 0)
val = UOp.const(dtypes.int, 42)

View file

@ -951,8 +951,8 @@ class TestSymbolic(unittest.TestCase):
expr = cond.where(a, b).cast(dtypes.half)
# TODO: copied from render, render does not support cast
glbl = UOp.param(0, dtypes.int.ptr())
uops = get_uops(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink())
glbl = UOp.param(0, dtypes.int.ptr(1))
uops = get_uops(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0), ptr=True), expr)).sink())
rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[1]
self.assertEqual(rewritten_uop, cond.where(a.cast(dtypes.half), b.cast(dtypes.half)))

View file

@ -110,10 +110,10 @@ class TestExecALU(unittest.TestCase):
class TestGatedStoreRewrite(unittest.TestCase):
def test_tiny_gate_store(self):
gmem = UOp.param(0, dtypes.float.ptr())
gmem = UOp.param(0, dtypes.float.ptr(8))
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
gate = gidx0<UOp.const(dtypes.int, 1)
idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, (gidx0 * UOp.const(dtypes.int, 2)).valid(gate)))
idx = UOp(Ops.INDEX, dtypes.float.ptr(8), (gmem, (gidx0 * UOp.const(dtypes.int, 2)).valid(gate)))
val = UOp.const(dtypes.float, 42.0)
store = UOp(Ops.STORE, dtypes.void, (idx, val))
uops = to_uops_list([store])
@ -126,12 +126,12 @@ class TestGatedStoreRewrite(unittest.TestCase):
self.assertEqual(len(gated_uops[-1].src), 2)
def test_gate_some_stores(self):
gmem0 = UOp.param(0, dtypes.float.ptr())
gmem1 = UOp.param(1, dtypes.float.ptr())
gmem0 = UOp.param(0, dtypes.float.ptr(8))
gmem1 = UOp.param(1, dtypes.float.ptr(8))
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
idx = gidx0 * UOp.const(dtypes.int, 2)
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx.valid(gidx0<UOp.const(dtypes.int, 1))))
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx))
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(8), (gmem0, idx.valid(gidx0<UOp.const(dtypes.int, 1))))
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(8), (gmem1, idx))
val = UOp.const(dtypes.float, 42.0)
stores = [UOp.store(idx0, val), UOp.store(idx1, val)]
uops = to_uops_list(stores)
@ -146,13 +146,13 @@ class TestGatedStoreRewrite(unittest.TestCase):
# scaled down version of TestLinearizerDumb.test_unmerged_ifs
@unittest.skip("we don't merge ifs anymore")
def test_merge_ifs_alt(self):
gmem0 = UOp.param(0, dtypes.float.ptr())
gmem1 = UOp.param(1, dtypes.float.ptr())
gmem0 = UOp.param(0, dtypes.float.ptr(8))
gmem1 = UOp.param(1, dtypes.float.ptr(8))
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
idx = gidx0*UOp.const(dtypes.int, 2)
gate = gidx0<UOp.const(dtypes.int, 1)
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx.valid(gate)))
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx.valid(gate)))
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(8), (gmem0, idx.valid(gate)))
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(8), (gmem1, idx.valid(gate)))
val = UOp.const(dtypes.float, 42.0)
stores = [UOp.store(idx0, val), UOp.store(idx1, val)]
uops = to_uops_list(stores)
@ -170,7 +170,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
class TestFastIdiv(unittest.TestCase):
def test_division_power_of_two(self):
for dt in (dtypes.int32, dtypes.uint32):
g = UOp.param(0, dt.ptr())
g = UOp.param(0, dt.ptr(3))
c = UOp.const(dt, 2)
l = g.index(c)
a = UOp(Ops.CDIV, dt, (l, c))
@ -183,7 +183,7 @@ class TestFastIdiv(unittest.TestCase):
def test_floormod_power_of_two(self):
# FLOORMOD by a power of two lowers to AND (correct floor mod for any sign in two's complement)
for dt in (dtypes.int32, dtypes.uint32):
g = UOp.param(0, dt.ptr())
g = UOp.param(0, dt.ptr(9))
c = UOp.const(dt, 8)
a = UOp(Ops.FLOORMOD, dt, (g.index(c), c))
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
@ -195,7 +195,7 @@ class TestFastIdiv(unittest.TestCase):
def test_floordiv_power_of_two_uint(self):
# uint FLOORDIV by a power of two lowers to a shift, leaving no IDIV/FLOORDIV in the kernel
for dt in (dtypes.uint32, dtypes.uint64):
g = UOp.param(0, dt.ptr())
g = UOp.param(0, dt.ptr(3))
c = UOp.const(dt, 2)
a = UOp(Ops.FLOORDIV, dt, (g.index(c), c))
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
@ -207,7 +207,7 @@ class TestFastIdiv(unittest.TestCase):
@Context(DISABLE_FAST_IDIV=0)
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't support long")
def test_fast_idiv_and_mod(self):
g = UOp.param(0, dtypes.uint32.ptr())
g = UOp.param(0, dtypes.uint32.ptr(4))
c = UOp.const(dtypes.uint, 3)
l = g.index(c)
a = UOp(Ops.CDIV, dtypes.uint, (l, c))
@ -242,7 +242,7 @@ class TestFastIdiv(unittest.TestCase):
@unittest.expectedFailure
def test_fast_idiv_overflow(self):
# This will be possible with a slightly different method for fast_idiv
g = UOp.param(0, dtypes.uint32.ptr())
g = UOp.param(0, dtypes.uint32.ptr(8))
c = UOp.const(dtypes.uint, 7)
l = UOp(Ops.LOAD, dtypes.uint, (g.index(c),))
a = UOp(Ops.CDIV, dtypes.uint, (l, c))
@ -253,7 +253,7 @@ class TestFastIdiv(unittest.TestCase):
self.assertNotIn(Ops.CDIV, ops)
def test_disable_fast_idiv(self):
g = UOp.param(0, dtypes.uint32.ptr())
g = UOp.param(0, dtypes.uint32.ptr(4))
c = UOp.const(dtypes.uint, 3)
l = g.index(c)
a = UOp(Ops.CDIV, dtypes.uint, (l, c))

View file

@ -20,7 +20,7 @@ class TestWinograd(unittest.TestCase):
out = Tensor.conv2d(x,w, padding=1)
out.mean().backward()
backward_schedule = x.grad.schedule_linear(w.grad)
self.assertEqual(len(backward_schedule.src), 4)
self.assertEqual(len(backward_schedule.src), 2)
@unittest.skip("this requires optimizations")
def test_counters(self):

View file

@ -130,9 +130,15 @@ class PythonProgram:
elif u.op is Ops.CAST:
values[u] = [truncate.get(u.dtype, lambda dt: dt)(u.dtype.const(x)) for x in src_values[0]]
elif u.op is Ops.LOAD:
<<<<<<< shrink_in_render
load_sz = u.max_numel()
if load_sz > 1:
values[u] = [load([src_values[k][j] if k != 0 else src_values[k] \
=======
if (load_sz := u.max_numel()) > 1:
# buf and gate are not vecs
values[u] = [load([src_values[k] if k in [0,2] else src_values[k][j] \
>>>>>>> master
for k in range(len(src_values))], j, u.dtype.scalar()) for j in range(load_sz)]
else:
values[u] = load(src_values, 0, u.dtype)

View file

@ -55,7 +55,7 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
return UOp.usum(*[c.pad(((s,numel-e),)) for (s,e),c in zip(chunks, copied_chunks)]).reshape(shape)
def create_allreduce_function(buf:UOp, red:UOp, output:UOp|None=None) -> UOp|None:
if output is None: output = UOp.const(red.dtype, Invalid, red.device, red.shape).clone()
if output is None: output = UOp.const(red.dtype, Invalid, shape=red.shape).clone(device=red.device)
to = red.param_like(0)
src = buf.param_like(1)
red = src.allreduce(red.arg, red.src[1])

View file

@ -851,6 +851,7 @@ class Tensor(OpMixin):
# clear contexts
for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient)):
assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
if g.device is None and t.device is not None: g = g.clone(device=t.device)
if t.grad is None: t.grad = g
else: t.grad.assign(t.grad + g.to(t.grad.device))
return self

View file

@ -489,8 +489,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return perm.index(*non_slice_args, ptr=True)
return self.index(*[UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in idx])
def const_like(self, b:ConstLike, dtype:DType|None=None):
# constants can optionally have a DEVICE source
ret = UOp.const(dtype or self.dtype.base, b, device=self.device, shape=self.shard_shape if self.axis is not None else self._shape)
# multi constants can optionally have a DEVICE source # TODO: no const with DEVICE
dev = self.device if isinstance(self.device, tuple) else None
ret = UOp.const(dtype or self.dtype.base, b, device=dev, shape=self.shard_shape if self.axis is not None else self._shape)
return ret.multi(self.axis) if self.axis is not None else ret
def ufix(self, x):
if isinstance(x, UOp): return x
@ -761,7 +762,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if self.op is Ops.BUFFER: return AddrSpace.GLOBAL
if self.op is Ops.DEFINE_LOCAL: return AddrSpace.LOCAL
if self.op is Ops.DEFINE_REG: return AddrSpace.REG
if self.op is Ops.LOAD: return AddrSpace.ANON # LOAD brings things into registers
if self.op is Ops.LOAD: return AddrSpace.ANON # LOAD brings things into anonymous registers
if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP}:
return self.src[0].addrspace
if self.op in GroupOp.Movement: return self.src[0].addrspace

View file

@ -3,7 +3,7 @@ from typing import cast, Any
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, AxisType, KernelInfo, ParamArg
from tinygrad.uop.render import print_uops, pyrender
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid, ConstFloat
from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata, panic, CHECK_OOB
from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata, panic, CHECK_OOB, all_same
# ***** uop helpers *****
@ -170,7 +170,7 @@ spec_tensor = PatternMatcher([
# MULTI/MSELECT/MSTACK
(UPat(Ops.MULTI, name="multi"), lambda multi: all(x.dtype == multi.dtype for x in multi.src) and isinstance(multi.arg, int)),
(UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),
(UPat(Ops.MSTACK, name="x"), lambda x: all(isinstance(x.device, str) for x in x.src)),
(UPat(Ops.MSTACK, name="x"), lambda x: all(isinstance(s.device, str) for s in x.src) or (all_same(x.src) and x.src[0].device is None)),
# CONTIGUOUS ensures the source UOp realizes
(UPat((Ops.DETACH, Ops.CONTIGUOUS, Ops.CONTIGUOUS_BACKWARD), name="root", src=(UPat.var("x"),), arg=None),