mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Merge branch 'master' into shrink_in_render
This commit is contained in:
commit
46541d70f4
15 changed files with 95 additions and 81 deletions
|
|
@ -25,10 +25,12 @@ def calculate_storage_offset(x: Tensor) -> int:
|
|||
u_strides = strides_for_shape(u.src[0].shape)
|
||||
for i, (start, _) in enumerate(u.marg): offset += start * u_strides[i]
|
||||
return offset
|
||||
def wrap(x: Tensor) -> torch.Tensor:
|
||||
def wrap(x: Tensor, dev: torch.device|None=None) -> torch.Tensor:
|
||||
x._strides = strides_for_shape(x.shape) # always recalculate
|
||||
if (not hasattr(x, '_storage_offset')) or (not x.uop.is_realized): x._storage_offset = calculate_storage_offset(x)
|
||||
return mod.wrap(x, _to_torch_dtype(x.dtype), _to_torch_device(x.device).index)
|
||||
# a deviceless tinygrad value takes the device from the op context
|
||||
idx = _to_torch_device(x.device).index if x.device is not None else (dev.index if dev is not None else 0)
|
||||
return mod.wrap(x, _to_torch_dtype(x.dtype), idx)
|
||||
def _update_torch_metadata(tensor: torch.Tensor, tiny: Tensor) -> None:
|
||||
tiny._strides = strides_for_shape(tiny.shape)
|
||||
tiny._storage_offset = calculate_storage_offset(tiny)
|
||||
|
|
@ -545,14 +547,17 @@ def wrap_out(f):
|
|||
assigned = f(*args, **kwargs)
|
||||
if getenv("ALLOW_DTYPE_MISMATCH", 1): assigned = assigned.cast(out.dtype)
|
||||
assert out.shape == assigned.shape, f"shape mismatch: {assigned.shape} -> {out.shape}"
|
||||
assert out.device == assigned.device, f"device mismatch: {assigned.device} -> {out.device}"
|
||||
assert out.device == assigned.device or out.device is None or assigned.device is None, f"device mismatch: {assigned.device} -> {out.device}"
|
||||
assert out.dtype == assigned.dtype, f"dtype mismatch: {assigned.dtype} -> {out.dtype}"
|
||||
if out.device is None and assigned.device is not None: out.replace(out.empty_like(device=assigned.device))
|
||||
return out.assign(assigned)
|
||||
return _wrap_out
|
||||
|
||||
def _inplace_op(t, new_value):
|
||||
if not hasattr(t, "_view_base") and not getattr(canonical_base(t), "_views", set()): t.replace(new_value)
|
||||
else: _apply_inplace(t, new_value)
|
||||
else:
|
||||
if (base:=canonical_base(t)).device is None and new_value.device is not None: base.replace(base.empty_like(device=new_value.device))
|
||||
_apply_inplace(t, new_value)
|
||||
return t
|
||||
|
||||
tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
|
||||
|
|
@ -679,10 +684,11 @@ def wrap_fxn(k,f):
|
|||
if TORCH_DEBUG:
|
||||
print(k, len(args), [x.shape if isinstance(x, torch.Tensor) else x for x in args],
|
||||
{k:v.shape if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()})
|
||||
dev = next((a.device for a in args if isinstance(a, torch.Tensor) and a.device.type == "tiny"), None)
|
||||
args, kwargs = unwrap_args(args, kwargs)
|
||||
out = f(*args, **kwargs)
|
||||
if isinstance(out, Tensor): return wrap(out)
|
||||
elif isinstance(out, tuple): return tuple(wrap(x) for x in out)
|
||||
if isinstance(out, Tensor): return wrap(out, dev)
|
||||
elif isinstance(out, tuple): return tuple(wrap(x, dev) for x in out)
|
||||
else: raise RuntimeError(f"unknown output type {type(out)}")
|
||||
return nf
|
||||
|
||||
|
|
|
|||
|
|
@ -22,30 +22,30 @@ def _test_uop_result(inputs:list[Tensor], sink:UOp, local_size=None):
|
|||
|
||||
def _setup_and_test_alu(alu_op:Ops, input_val:ConstType, *alu_src_uops:UOp):
|
||||
dtype = alu_src_uops[0].dtype
|
||||
a = UOp.param(0, dtype.ptr())
|
||||
b = UOp.param(1, dtype.ptr())
|
||||
a = UOp.param(0, dtype.ptr(1))
|
||||
b = UOp.param(1, dtype.ptr(1))
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = b.index(idx)
|
||||
ld = b.index(idx, ptr=True).load()
|
||||
alu = ld.alu(alu_op, *alu_src_uops)
|
||||
store = UOp.store(a.index(idx), alu)
|
||||
store = UOp.store(a.index(idx, ptr=True), alu)
|
||||
return _test_uop_result([Tensor([input_val])], UOp(Ops.SINK, dtypes.void, (store,), arg=KernelInfo()))[0]
|
||||
|
||||
class TestRendererFailures(unittest.TestCase):
|
||||
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer")
|
||||
def test_gated_store_with_alu(self):
|
||||
a = UOp.param(0, dtypes.int.ptr())
|
||||
a = UOp.param(0, dtypes.int.ptr(4))
|
||||
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
|
||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0.valid(gate_alu)), UOp.const(dtypes.int, 1)))
|
||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0.valid(gate_alu), ptr=True), UOp.const(dtypes.int, 1)))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,), arg=KernelInfo())
|
||||
ret = _test_uop_result([], sink, local_size=[4, 1, 1])[0]
|
||||
np.testing.assert_equal(ret, [0, 1, 1, 1])
|
||||
|
||||
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer")
|
||||
def test_gated_store_with_alu_2d(self):
|
||||
a = UOp.param(0, dtypes.int.ptr())
|
||||
a = UOp.param(0, dtypes.int.ptr(8))
|
||||
gate_alu_0 = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
|
||||
gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 2),), 'lidx1')).ne(0)
|
||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index((lidx0+lidx1*4).valid(gate_alu_0&gate_alu_1)), UOp.const(dtypes.int, 1)))
|
||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index((lidx0+lidx1*4).valid(gate_alu_0&gate_alu_1), ptr=True), UOp.const(dtypes.int, 1)))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,), arg=KernelInfo())
|
||||
ret = _test_uop_result([], sink, local_size=[4, 2, 1])[0]
|
||||
np.testing.assert_equal(ret, [0, 0, 0, 0, 0, 1, 1, 1])
|
||||
|
|
|
|||
|
|
@ -27,8 +27,8 @@ def uop(uops:list[UOp], op:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:
|
|||
def _test_single_value(vals, op, dts):
|
||||
uops = []
|
||||
output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1]
|
||||
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(), (), 0)
|
||||
buf_loads = [uop(uops, Ops.PARAM, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)]
|
||||
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(1), (), 0)
|
||||
buf_loads = [uop(uops, Ops.PARAM, dtype.ptr(1), (), i+1) for i,dtype in enumerate(dts)]
|
||||
loads = (buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0)) for i, dtype in enumerate(dts))
|
||||
alu = uop(uops, op, output_dtype, loads)
|
||||
out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True), alu))
|
||||
|
|
@ -42,7 +42,7 @@ def _test_single_value(vals, op, dts):
|
|||
def _test_single_value_const(vals, op, dts):
|
||||
uops = []
|
||||
output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1]
|
||||
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(), (), 0)
|
||||
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(1), (), 0)
|
||||
loads = (uop(uops, Ops.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
|
||||
alu = uop(uops, op, output_dtype, loads)
|
||||
out = buf_store[UOp.const(dtypes.int32, 0)].store(alu)
|
||||
|
|
@ -54,7 +54,7 @@ def _test_single_value_const(vals, op, dts):
|
|||
|
||||
def _test_uops_result(output_dtype, uops, res):
|
||||
# uops = []
|
||||
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(), (), 0)
|
||||
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(1), (), 0)
|
||||
# res = output_fn(uops)
|
||||
out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), res))
|
||||
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
|
||||
|
|
@ -221,8 +221,8 @@ class TestLocalAccess(unittest.TestCase):
|
|||
@unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "This only tests assembly backends")
|
||||
class TestAssembly(unittest.TestCase):
|
||||
def test_bitshift_left(self):
|
||||
g1 = UOp.param(0, dtypes.int32.ptr())
|
||||
out = UOp.param(1, dtypes.int32.ptr())
|
||||
g1 = UOp.param(0, dtypes.int32.ptr(3))
|
||||
out = UOp.param(1, dtypes.int32.ptr(2))
|
||||
c1 = UOp.const(dtypes.int, 2)
|
||||
c2 = UOp.const(dtypes.int, 3)
|
||||
l1 = g1.index(c1)
|
||||
|
|
@ -249,7 +249,7 @@ class TestAssembly(unittest.TestCase):
|
|||
self.assertGreaterEqual(len([x.op for x in uops if x.op is Ops.MULACC]), 4)
|
||||
|
||||
def test_mulacc_shl(self):
|
||||
g1 = UOp.param(0, dtypes.int32.ptr())
|
||||
g1 = UOp.param(0, dtypes.int32.ptr(2))
|
||||
c1 = UOp.const(dtypes.int, 0)
|
||||
c2 = UOp.const(dtypes.int, 1)
|
||||
expr = g1.index(c1) * UOp.const(dtypes.int, 4096) + g1.index(c2)
|
||||
|
|
@ -258,7 +258,7 @@ class TestAssembly(unittest.TestCase):
|
|||
self.assertIn(Ops.MULACC, [x.op for x in uops])
|
||||
|
||||
def test_use_cmpeq(self):
|
||||
g = UOp.param(0, dtypes.uint32.ptr())
|
||||
g = UOp.param(0, dtypes.uint32.ptr(8))
|
||||
c = UOp.const(dtypes.uint, 7)
|
||||
comp = g.index(c).ne(c).ne(True)
|
||||
uops = to_uops_list([comp], ren=Device[Device.DEFAULT].renderer)
|
||||
|
|
|
|||
|
|
@ -82,8 +82,8 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None, vals:tuple
|
|||
for buf_dt, data in inputs or []:
|
||||
bufs.append(buf:=allocator.alloc(len(data) * buf_dt.itemsize))
|
||||
allocator._copyin(buf, memoryview(struct.pack(str(len(data)) + (buf_dt.fmt or ""), *data)))
|
||||
g = UOp.param(0, uop.dtype.ptr())
|
||||
prg = to_program(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(arg=KernelInfo()), PythonRenderer(Target("PYTHON")))
|
||||
g = UOp.param(0, uop.dtype.ptr(1))
|
||||
prg = to_program(UOp.store(g.index(UOp.const(dtypes.int, 0), ptr=True), uop).sink(arg=KernelInfo()), PythonRenderer(Target("PYTHON")))
|
||||
prog = PythonProgram("run", PythonCompiler().compile(prg.src[3].arg))
|
||||
prog(out_buf:=allocator.alloc(uop.dtype.itemsize), *bufs, vals=vals)
|
||||
return out_buf.cast(uop.dtype.fmt or "").tolist()[0]
|
||||
|
|
|
|||
|
|
@ -968,7 +968,7 @@ class TestSchedule(unittest.TestCase):
|
|||
|
||||
def test_const_schedule_contig(self):
|
||||
constv = Tensor.empty(2, 2).uop.const_like(10).contiguous()
|
||||
check_schedule(constv, 1)
|
||||
check_schedule(constv, 0)
|
||||
|
||||
def test_advanced_simple_indexing_combined(self):
|
||||
X = Tensor.arange(16).reshape(4, 4)
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ class TestTranscendentalFunctions(unittest.TestCase):
|
|||
def test_payne_hanek_reduction(self):
|
||||
# TODO: Test constant input when constant folding is fixed (or maybe test both variants)
|
||||
# Load input value from a buffer to prevent constant folding
|
||||
input_buf = UOp.param(1, dtypes.double.ptr())
|
||||
loaded_value = input_buf.index(UOp.const(dtypes.int, 0))
|
||||
input_buf = UOp.param(1, dtypes.double.ptr(1))
|
||||
loaded_value = input_buf.index(UOp.const(dtypes.int, 0), ptr=True).load()
|
||||
def eval_payne_hanek_reduction(v:float) -> tuple[float, int]:
|
||||
return tuple(eval_uop(u, [(dtypes.float64, [v])]) for u in payne_hanek_reduction(loaded_value))
|
||||
|
||||
|
|
|
|||
|
|
@ -380,22 +380,22 @@ class TestUOpGraph(unittest.TestCase):
|
|||
self.assertEqual(uops[-2], wmma) # -2 to skip SINK
|
||||
|
||||
def test_cast_alu_fold(self):
|
||||
d0 = UOp.param(0, dtypes.bool.ptr())
|
||||
d1 = UOp.param(1, dtypes.int.ptr())
|
||||
d0 = UOp.param(0, dtypes.bool.ptr(1))
|
||||
d1 = UOp.param(1, dtypes.int.ptr(1))
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = d1.index(idx)
|
||||
alu = (ld<1).cast(dtypes.bool)
|
||||
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu))
|
||||
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx, ptr=True), alu))
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 0)
|
||||
|
||||
def test_double_cast_fold(self):
|
||||
d0 = UOp.param(0, dtypes.float.ptr())
|
||||
d1 = UOp.param(1, dtypes.int.ptr())
|
||||
d0 = UOp.param(0, dtypes.float.ptr(1))
|
||||
d1 = UOp.param(1, dtypes.int.ptr(1))
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = d1.index(idx)
|
||||
alu = ld.cast(dtypes.float).cast(dtypes.float)
|
||||
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu))
|
||||
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx, ptr=True), alu))
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1)
|
||||
|
||||
|
|
@ -414,7 +414,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
|
||||
def test_bitcast_to_same_dtype_fold(self):
|
||||
for dt in dtypes.ints + dtypes.floats + (dtypes.bool,):
|
||||
d0 = UOp.param(0, dt.ptr())
|
||||
d0 = UOp.param(0, dt.ptr(1))
|
||||
v = d0.index(UOp.const(dtypes.int, 0))
|
||||
uops = to_uops_list([v.bitcast(dt)])
|
||||
self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST and x.dtype is dt]), 0, f"dtype = {dt}")
|
||||
|
|
@ -427,18 +427,18 @@ class TestUOpGraph(unittest.TestCase):
|
|||
|
||||
def test_where_on_gated_load_fold(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp.param(0, dtypes.long.ptr())
|
||||
d0 = UOp.param(0, dtypes.long.ptr(100))
|
||||
ld = d0.index(ridx0.valid(ridx0<50))
|
||||
w = (ridx0<50).where(ld, 5)
|
||||
out = UOp.param(1, dtypes.long.ptr())
|
||||
uops = to_uops_list([out.index(ridx0).store(w)])
|
||||
out = UOp.param(1, dtypes.long.ptr(100))
|
||||
uops = to_uops_list([out.index(ridx0, ptr=True).store(w)])
|
||||
for u in uops:
|
||||
assert u.op is not Ops.WHERE
|
||||
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg==5
|
||||
|
||||
def test_where_on_gated_load_folds_swapped_branches(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp.param(0, dtypes.long.ptr())
|
||||
d0 = UOp.param(0, dtypes.long.ptr(100))
|
||||
ld = d0.index(ridx0.valid((ridx0<50).logical_not()))
|
||||
w = (ridx0<50).where(5, ld)
|
||||
uops = to_uops_list([w])
|
||||
|
|
@ -448,40 +448,40 @@ class TestUOpGraph(unittest.TestCase):
|
|||
|
||||
def test_where_on_gated_load_with_cast(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp.param(0, dtypes.int.ptr())
|
||||
d0 = UOp.param(0, dtypes.int.ptr(100))
|
||||
gate_idx = ridx0.valid((ridx0<50))
|
||||
ld = d0.index(gate_idx).cast(dtypes.float)
|
||||
w = (ridx0<50).where(ld, 5.0)
|
||||
out = UOp.param(1, dtypes.float.ptr())
|
||||
uops = to_uops_list([out.index(ridx0).store(w)])
|
||||
out = UOp.param(1, dtypes.float.ptr(100))
|
||||
uops = to_uops_list([out.index(ridx0, ptr=True).store(w)])
|
||||
for u in uops:
|
||||
assert u.op is not Ops.WHERE
|
||||
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg == 5
|
||||
|
||||
def test_where_on_casted_gated_load_extra_cond(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp.param(0, dtypes.float.ptr())
|
||||
d0 = UOp.param(0, dtypes.float.ptr(100))
|
||||
ld = d0.index(ridx0.valid(ridx0<50))
|
||||
w = ((ridx0<50) & (ridx0>30)).where(ld, UOp.const(dtypes.float, 0)).cast(dtypes.half)
|
||||
out = UOp.param(1, dtypes.half.ptr())
|
||||
uops = to_uops_list([out.index(ridx0).store(w)])
|
||||
out = UOp.param(1, dtypes.half.ptr(100))
|
||||
uops = to_uops_list([out.index(ridx0, ptr=True).store(w)])
|
||||
for u in uops:
|
||||
assert u.op is not Ops.WHERE
|
||||
|
||||
def test_where_on_casted_gated_load_extra_cond_swapped(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp.param(0, dtypes.float.ptr())
|
||||
d0 = UOp.param(0, dtypes.float.ptr(100))
|
||||
ld = d0.index(ridx0.valid(ridx0<50))
|
||||
w = ((ridx0<50) & (ridx0>30)).where(UOp.const(dtypes.float, 0), ld).cast(dtypes.half)
|
||||
out = UOp.param(1, dtypes.half.ptr())
|
||||
uops = to_uops_list([out.index(ridx0).store(w)])
|
||||
out = UOp.param(1, dtypes.half.ptr(100))
|
||||
uops = to_uops_list([out.index(ridx0, ptr=True).store(w)])
|
||||
for u in uops:
|
||||
assert u.op is not Ops.WHERE
|
||||
|
||||
def test_where_in_store_becomes_gate(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp.param(0, dtypes.long.ptr())
|
||||
idx = d0.index(ridx0)
|
||||
d0 = UOp.param(0, dtypes.long.ptr(100))
|
||||
idx = d0.index(ridx0, ptr=True)
|
||||
ld = idx.load()
|
||||
val = (ridx0<50).where(5, ld)
|
||||
st = idx.store(val).end(ridx0)
|
||||
|
|
@ -529,33 +529,33 @@ class TestUOpGraph(unittest.TestCase):
|
|||
self.assertNotEqual(u.dtype, dtypes.long)
|
||||
|
||||
def test_fold_gated_load(self):
|
||||
glbl0 = UOp.param(0, dtypes.int.ptr())
|
||||
glbl1 = UOp.param(1, dtypes.int.ptr())
|
||||
glbl2 = UOp.param(2, dtypes.int.ptr())
|
||||
glbl0 = UOp.param(0, dtypes.int.ptr(1))
|
||||
glbl1 = UOp.param(1, dtypes.int.ptr(1))
|
||||
glbl2 = UOp.param(2, dtypes.int.ptr(1))
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld0 = glbl1.index(UOp.invalid())
|
||||
ld1 = glbl2.index(idx.valid(UOp.const(dtypes.bool, True)))
|
||||
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
|
||||
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx, ptr=True), ld1+ld0))])
|
||||
ld0 = uops[-2].src[-1] # -2 to skip SINK
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assertEqual(ld0, UOp.load(glbl2.index(idx, ptr=True), dtype=dtypes.int))
|
||||
|
||||
def test_fold_gated_load_local(self):
|
||||
glbl0 = UOp.param(0, dtypes.int.ptr())
|
||||
glbl0 = UOp.param(0, dtypes.int.ptr(16))
|
||||
smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, addrspace=AddrSpace.LOCAL), (), "temp")
|
||||
lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")
|
||||
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx, ptr=True), glbl0.index(lidx, ptr=True).load()))
|
||||
barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
|
||||
ld0 = smem.after(barrier).index(UOp.invalid())
|
||||
ld1 = smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True)))
|
||||
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))])
|
||||
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx, ptr=True), ld1+ld0))])
|
||||
|
||||
ld0 = uops[-2].src[-1] # -2 to skip SINK
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2, ptr=True))
|
||||
|
||||
def test_fold_gated_store(self):
|
||||
glbl = UOp.param(0, dtypes.int.ptr())
|
||||
glbl = UOp.param(0, dtypes.int.ptr(1))
|
||||
idx0 = UOp.const(dtypes.int, 0)
|
||||
idx1 = UOp.const(dtypes.int, 0)
|
||||
val = UOp.const(dtypes.int, 42)
|
||||
|
|
|
|||
|
|
@ -951,8 +951,8 @@ class TestSymbolic(unittest.TestCase):
|
|||
expr = cond.where(a, b).cast(dtypes.half)
|
||||
|
||||
# TODO: copied from render, render does not support cast
|
||||
glbl = UOp.param(0, dtypes.int.ptr())
|
||||
uops = get_uops(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink())
|
||||
glbl = UOp.param(0, dtypes.int.ptr(1))
|
||||
uops = get_uops(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0), ptr=True), expr)).sink())
|
||||
rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[1]
|
||||
|
||||
self.assertEqual(rewritten_uop, cond.where(a.cast(dtypes.half), b.cast(dtypes.half)))
|
||||
|
|
|
|||
|
|
@ -110,10 +110,10 @@ class TestExecALU(unittest.TestCase):
|
|||
|
||||
class TestGatedStoreRewrite(unittest.TestCase):
|
||||
def test_tiny_gate_store(self):
|
||||
gmem = UOp.param(0, dtypes.float.ptr())
|
||||
gmem = UOp.param(0, dtypes.float.ptr(8))
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
|
||||
gate = gidx0<UOp.const(dtypes.int, 1)
|
||||
idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, (gidx0 * UOp.const(dtypes.int, 2)).valid(gate)))
|
||||
idx = UOp(Ops.INDEX, dtypes.float.ptr(8), (gmem, (gidx0 * UOp.const(dtypes.int, 2)).valid(gate)))
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
store = UOp(Ops.STORE, dtypes.void, (idx, val))
|
||||
uops = to_uops_list([store])
|
||||
|
|
@ -126,12 +126,12 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
|||
self.assertEqual(len(gated_uops[-1].src), 2)
|
||||
|
||||
def test_gate_some_stores(self):
|
||||
gmem0 = UOp.param(0, dtypes.float.ptr())
|
||||
gmem1 = UOp.param(1, dtypes.float.ptr())
|
||||
gmem0 = UOp.param(0, dtypes.float.ptr(8))
|
||||
gmem1 = UOp.param(1, dtypes.float.ptr(8))
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
|
||||
idx = gidx0 * UOp.const(dtypes.int, 2)
|
||||
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx.valid(gidx0<UOp.const(dtypes.int, 1))))
|
||||
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx))
|
||||
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(8), (gmem0, idx.valid(gidx0<UOp.const(dtypes.int, 1))))
|
||||
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(8), (gmem1, idx))
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
stores = [UOp.store(idx0, val), UOp.store(idx1, val)]
|
||||
uops = to_uops_list(stores)
|
||||
|
|
@ -146,13 +146,13 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
|||
# scaled down version of TestLinearizerDumb.test_unmerged_ifs
|
||||
@unittest.skip("we don't merge ifs anymore")
|
||||
def test_merge_ifs_alt(self):
|
||||
gmem0 = UOp.param(0, dtypes.float.ptr())
|
||||
gmem1 = UOp.param(1, dtypes.float.ptr())
|
||||
gmem0 = UOp.param(0, dtypes.float.ptr(8))
|
||||
gmem1 = UOp.param(1, dtypes.float.ptr(8))
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
|
||||
idx = gidx0*UOp.const(dtypes.int, 2)
|
||||
gate = gidx0<UOp.const(dtypes.int, 1)
|
||||
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx.valid(gate)))
|
||||
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx.valid(gate)))
|
||||
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(8), (gmem0, idx.valid(gate)))
|
||||
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(8), (gmem1, idx.valid(gate)))
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
stores = [UOp.store(idx0, val), UOp.store(idx1, val)]
|
||||
uops = to_uops_list(stores)
|
||||
|
|
@ -170,7 +170,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
|||
class TestFastIdiv(unittest.TestCase):
|
||||
def test_division_power_of_two(self):
|
||||
for dt in (dtypes.int32, dtypes.uint32):
|
||||
g = UOp.param(0, dt.ptr())
|
||||
g = UOp.param(0, dt.ptr(3))
|
||||
c = UOp.const(dt, 2)
|
||||
l = g.index(c)
|
||||
a = UOp(Ops.CDIV, dt, (l, c))
|
||||
|
|
@ -183,7 +183,7 @@ class TestFastIdiv(unittest.TestCase):
|
|||
def test_floormod_power_of_two(self):
|
||||
# FLOORMOD by a power of two lowers to AND (correct floor mod for any sign in two's complement)
|
||||
for dt in (dtypes.int32, dtypes.uint32):
|
||||
g = UOp.param(0, dt.ptr())
|
||||
g = UOp.param(0, dt.ptr(9))
|
||||
c = UOp.const(dt, 8)
|
||||
a = UOp(Ops.FLOORMOD, dt, (g.index(c), c))
|
||||
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
|
||||
|
|
@ -195,7 +195,7 @@ class TestFastIdiv(unittest.TestCase):
|
|||
def test_floordiv_power_of_two_uint(self):
|
||||
# uint FLOORDIV by a power of two lowers to a shift, leaving no IDIV/FLOORDIV in the kernel
|
||||
for dt in (dtypes.uint32, dtypes.uint64):
|
||||
g = UOp.param(0, dt.ptr())
|
||||
g = UOp.param(0, dt.ptr(3))
|
||||
c = UOp.const(dt, 2)
|
||||
a = UOp(Ops.FLOORDIV, dt, (g.index(c), c))
|
||||
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
|
||||
|
|
@ -207,7 +207,7 @@ class TestFastIdiv(unittest.TestCase):
|
|||
@Context(DISABLE_FAST_IDIV=0)
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't support long")
|
||||
def test_fast_idiv_and_mod(self):
|
||||
g = UOp.param(0, dtypes.uint32.ptr())
|
||||
g = UOp.param(0, dtypes.uint32.ptr(4))
|
||||
c = UOp.const(dtypes.uint, 3)
|
||||
l = g.index(c)
|
||||
a = UOp(Ops.CDIV, dtypes.uint, (l, c))
|
||||
|
|
@ -242,7 +242,7 @@ class TestFastIdiv(unittest.TestCase):
|
|||
@unittest.expectedFailure
|
||||
def test_fast_idiv_overflow(self):
|
||||
# This will be possible with a slightly different method for fast_idiv
|
||||
g = UOp.param(0, dtypes.uint32.ptr())
|
||||
g = UOp.param(0, dtypes.uint32.ptr(8))
|
||||
c = UOp.const(dtypes.uint, 7)
|
||||
l = UOp(Ops.LOAD, dtypes.uint, (g.index(c),))
|
||||
a = UOp(Ops.CDIV, dtypes.uint, (l, c))
|
||||
|
|
@ -253,7 +253,7 @@ class TestFastIdiv(unittest.TestCase):
|
|||
self.assertNotIn(Ops.CDIV, ops)
|
||||
|
||||
def test_disable_fast_idiv(self):
|
||||
g = UOp.param(0, dtypes.uint32.ptr())
|
||||
g = UOp.param(0, dtypes.uint32.ptr(4))
|
||||
c = UOp.const(dtypes.uint, 3)
|
||||
l = g.index(c)
|
||||
a = UOp(Ops.CDIV, dtypes.uint, (l, c))
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class TestWinograd(unittest.TestCase):
|
|||
out = Tensor.conv2d(x,w, padding=1)
|
||||
out.mean().backward()
|
||||
backward_schedule = x.grad.schedule_linear(w.grad)
|
||||
self.assertEqual(len(backward_schedule.src), 4)
|
||||
self.assertEqual(len(backward_schedule.src), 2)
|
||||
|
||||
@unittest.skip("this requires optimizations")
|
||||
def test_counters(self):
|
||||
|
|
|
|||
|
|
@ -130,9 +130,15 @@ class PythonProgram:
|
|||
elif u.op is Ops.CAST:
|
||||
values[u] = [truncate.get(u.dtype, lambda dt: dt)(u.dtype.const(x)) for x in src_values[0]]
|
||||
elif u.op is Ops.LOAD:
|
||||
<<<<<<< shrink_in_render
|
||||
load_sz = u.max_numel()
|
||||
if load_sz > 1:
|
||||
values[u] = [load([src_values[k][j] if k != 0 else src_values[k] \
|
||||
=======
|
||||
if (load_sz := u.max_numel()) > 1:
|
||||
# buf and gate are not vecs
|
||||
values[u] = [load([src_values[k] if k in [0,2] else src_values[k][j] \
|
||||
>>>>>>> master
|
||||
for k in range(len(src_values))], j, u.dtype.scalar()) for j in range(load_sz)]
|
||||
else:
|
||||
values[u] = load(src_values, 0, u.dtype)
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
|
|||
return UOp.usum(*[c.pad(((s,numel-e),)) for (s,e),c in zip(chunks, copied_chunks)]).reshape(shape)
|
||||
|
||||
def create_allreduce_function(buf:UOp, red:UOp, output:UOp|None=None) -> UOp|None:
|
||||
if output is None: output = UOp.const(red.dtype, Invalid, red.device, red.shape).clone()
|
||||
if output is None: output = UOp.const(red.dtype, Invalid, shape=red.shape).clone(device=red.device)
|
||||
to = red.param_like(0)
|
||||
src = buf.param_like(1)
|
||||
red = src.allreduce(red.arg, red.src[1])
|
||||
|
|
|
|||
|
|
@ -851,6 +851,7 @@ class Tensor(OpMixin):
|
|||
# clear contexts
|
||||
for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient)):
|
||||
assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
|
||||
if g.device is None and t.device is not None: g = g.clone(device=t.device)
|
||||
if t.grad is None: t.grad = g
|
||||
else: t.grad.assign(t.grad + g.to(t.grad.device))
|
||||
return self
|
||||
|
|
|
|||
|
|
@ -489,8 +489,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return perm.index(*non_slice_args, ptr=True)
|
||||
return self.index(*[UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in idx])
|
||||
def const_like(self, b:ConstLike, dtype:DType|None=None):
|
||||
# constants can optionally have a DEVICE source
|
||||
ret = UOp.const(dtype or self.dtype.base, b, device=self.device, shape=self.shard_shape if self.axis is not None else self._shape)
|
||||
# multi constants can optionally have a DEVICE source # TODO: no const with DEVICE
|
||||
dev = self.device if isinstance(self.device, tuple) else None
|
||||
ret = UOp.const(dtype or self.dtype.base, b, device=dev, shape=self.shard_shape if self.axis is not None else self._shape)
|
||||
return ret.multi(self.axis) if self.axis is not None else ret
|
||||
def ufix(self, x):
|
||||
if isinstance(x, UOp): return x
|
||||
|
|
@ -761,7 +762,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
if self.op is Ops.BUFFER: return AddrSpace.GLOBAL
|
||||
if self.op is Ops.DEFINE_LOCAL: return AddrSpace.LOCAL
|
||||
if self.op is Ops.DEFINE_REG: return AddrSpace.REG
|
||||
if self.op is Ops.LOAD: return AddrSpace.ANON # LOAD brings things into registers
|
||||
if self.op is Ops.LOAD: return AddrSpace.ANON # LOAD brings things into anonymous registers
|
||||
if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP}:
|
||||
return self.src[0].addrspace
|
||||
if self.op in GroupOp.Movement: return self.src[0].addrspace
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import cast, Any
|
|||
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, AxisType, KernelInfo, ParamArg
|
||||
from tinygrad.uop.render import print_uops, pyrender
|
||||
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid, ConstFloat
|
||||
from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata, panic, CHECK_OOB
|
||||
from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata, panic, CHECK_OOB, all_same
|
||||
|
||||
# ***** uop helpers *****
|
||||
|
||||
|
|
@ -170,7 +170,7 @@ spec_tensor = PatternMatcher([
|
|||
# MULTI/MSELECT/MSTACK
|
||||
(UPat(Ops.MULTI, name="multi"), lambda multi: all(x.dtype == multi.dtype for x in multi.src) and isinstance(multi.arg, int)),
|
||||
(UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),
|
||||
(UPat(Ops.MSTACK, name="x"), lambda x: all(isinstance(x.device, str) for x in x.src)),
|
||||
(UPat(Ops.MSTACK, name="x"), lambda x: all(isinstance(s.device, str) for s in x.src) or (all_same(x.src) and x.src[0].device is None)),
|
||||
|
||||
# CONTIGUOUS ensures the source UOp realizes
|
||||
(UPat((Ops.DETACH, Ops.CONTIGUOUS, Ops.CONTIGUOUS_BACKWARD), name="root", src=(UPat.var("x"),), arg=None),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue