mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Merge branch 'master' into shrink_in_render
This commit is contained in:
commit
ef9c60238e
7 changed files with 88 additions and 79 deletions
1
.github/actions/setup-tinygrad/action.yml
vendored
1
.github/actions/setup-tinygrad/action.yml
vendored
|
|
@ -308,7 +308,6 @@ runs:
|
|||
|
||||
CMAKE_ARGS="-Wno-dev -G Ninja -DOCELOT_BUILD_TOOLS=OFF -DCMAKE_BUILD_ALWAYS=0 -DBUILD_TESTS_CUDA=OFF -DCMAKE_POLICY_VERSION_MINIMUM=3.5"
|
||||
if [[ "${{ runner.os }}" == "macOS" ]]; then
|
||||
sudo xcode-select -s /Applications/Xcode_16.2.app/Contents/Developer
|
||||
CMAKE_ARGS="$CMAKE_ARGS -DBoost_INCLUDE_DIR=$(brew --prefix boost)/include -DBoost_LIBRARY_DIR=$(brew --prefix boost)/lib"
|
||||
else
|
||||
CMAKE_ARGS="$CMAKE_ARGS -DLLVM_DIR=$(llvm-config-15 --cmakedir)"
|
||||
|
|
|
|||
|
|
@ -4,18 +4,20 @@ from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
|
|||
|
||||
@functools.cache
|
||||
def _custom_fused_ce_loss_fwd(loss_out:UOp, max_out:UOp, lse_out:UOp, logits:UOp, targets:UOp,
|
||||
vocab:int, rows:int, label_smoothing:float) -> UOp:
|
||||
vocab:int, rows:int, seq:int, label_smoothing:float) -> UOp:
|
||||
row = UOp.range(rows, 0)
|
||||
b = row // seq
|
||||
s = row % seq
|
||||
|
||||
v_max = UOp.range(vocab, 1, axis_type=AxisType.REDUCE)
|
||||
row_max = logits[row, v_max].cast(dtypes.float).reduce(v_max, arg=Ops.MAX)
|
||||
row_max = logits[b, s, v_max].cast(dtypes.float).reduce(v_max, arg=Ops.MAX)
|
||||
|
||||
v_lse = UOp.range(vocab, 2, axis_type=AxisType.REDUCE)
|
||||
row_lse = (logits[row, v_lse].cast(dtypes.float) - row_max).exp().reduce(v_lse, arg=Ops.ADD).log() + row_max
|
||||
row_lse = (logits[b, s, v_lse].cast(dtypes.float) - row_max).exp().reduce(v_lse, arg=Ops.ADD).log() + row_max
|
||||
|
||||
v_smooth = UOp.range(vocab, 3, axis_type=AxisType.REDUCE)
|
||||
target = logits[row, targets[row].cast(dtypes.weakint)].cast(dtypes.float)
|
||||
mean_logits = logits[row, v_smooth].cast(dtypes.float).reduce(v_smooth, arg=Ops.ADD) / vocab
|
||||
target = logits[b, s, targets[row].cast(dtypes.weakint)].cast(dtypes.float)
|
||||
mean_logits = logits[b, s, v_smooth].cast(dtypes.float).reduce(v_smooth, arg=Ops.ADD) / vocab
|
||||
loss = row_lse - (1.0 - label_smoothing) * target - label_smoothing * mean_logits
|
||||
stores = UOp.group(loss_out[row].store(loss), max_out[row].store(row_max), lse_out[row].store(row_lse))
|
||||
|
||||
|
|
@ -23,37 +25,42 @@ def _custom_fused_ce_loss_fwd(loss_out:UOp, max_out:UOp, lse_out:UOp, logits:UOp
|
|||
|
||||
@functools.cache
|
||||
def _custom_fused_ce_loss_bwd(d_logits:UOp, logits:UOp, lse:UOp, targets:UOp, scale:UOp,
|
||||
vocab:int, rows:int, label_smoothing:float) -> UOp:
|
||||
vocab:int, rows:int, seq:int, label_smoothing:float) -> UOp:
|
||||
row = UOp.range(rows, 0)
|
||||
v = UOp.range(vocab, 1)
|
||||
b = row // seq
|
||||
s = row % seq
|
||||
|
||||
prob = (logits[row, v].cast(dtypes.float) - lse[row]).exp()
|
||||
prob = (logits[b, s, v].cast(dtypes.float) - lse[row]).exp()
|
||||
target = v.eq(targets[row].cast(dtypes.weakint)).where(1.0 - label_smoothing, 0.0)
|
||||
smooth = label_smoothing / vocab
|
||||
grad = (prob - target - smooth) * scale[0]
|
||||
|
||||
return d_logits[row, v].store(grad.cast(d_logits.dtype.base)).end(v, row).sink(arg=KernelInfo(f"fused_ce_loss_bwd_{rows}_{vocab}"))
|
||||
return d_logits[b, s, v].store(grad.cast(d_logits.dtype.base)).end(v, row).sink(arg=KernelInfo(f"fused_ce_loss_bwd_{rows}_{vocab}"))
|
||||
|
||||
def _fused_ce_loss_bwd(gradient:UOp, kernel:UOp, label_smoothing:float):
|
||||
# NOTE: forward inputs are (loss_out, max_out, lse_out, logits, targets)
|
||||
# gradient is the upstream grad w.r.t. per-row loss (shape: (rows,) fp32)
|
||||
_, _, lse_u, logits_u, targets_u = kernel.src[1:]
|
||||
device = logits_u.device
|
||||
rows, VOCAB = logits_u.shape # (rows, VOCAB) after reshape
|
||||
MBS, SEQ, VOCAB = logits_u.shape
|
||||
if isinstance(device, tuple):
|
||||
axis = logits_u.axis
|
||||
ndev = len(device)
|
||||
d_logits = Tensor(Tensor.invalids(rows // ndev, VOCAB, dtype=dtypes.bfloat16, device=device).uop.multi(axis), device=device)
|
||||
rows_per_dev = rows // ndev
|
||||
local_shape = tuple(s//ndev if i == axis else s for i,s in enumerate((MBS, SEQ, VOCAB)))
|
||||
d_logits = Tensor(Tensor.invalids(*local_shape, dtype=dtypes.bfloat16, device=device).uop.multi(axis), device=device)
|
||||
rows_per_dev = local_shape[0] * local_shape[1]
|
||||
seq_per_dev = local_shape[1]
|
||||
else:
|
||||
d_logits = Tensor.invalids(rows, VOCAB, dtype=dtypes.bfloat16, device=device)
|
||||
rows_per_dev = rows
|
||||
d_logits = Tensor.invalids(MBS, SEQ, VOCAB, dtype=dtypes.bfloat16, device=device)
|
||||
rows_per_dev = MBS * SEQ
|
||||
seq_per_dev = SEQ
|
||||
# NOTE: .mean() backward gives same grad per row (1/N), so broadcast is safe; take scalar
|
||||
scale = Tensor(gradient, device=device).float().reshape(-1)[0:1].contiguous()
|
||||
logits_t = Tensor(logits_u.after(kernel), device=device)
|
||||
lse_t = Tensor(lse_u.after(kernel), device=device)
|
||||
targets_t = Tensor(targets_u, device=device)
|
||||
fxn = functools.partial(_custom_fused_ce_loss_bwd, vocab=VOCAB, rows=rows_per_dev, label_smoothing=label_smoothing)
|
||||
fxn = functools.partial(_custom_fused_ce_loss_bwd, vocab=VOCAB, rows=rows_per_dev, seq=seq_per_dev, label_smoothing=label_smoothing)
|
||||
d_logits, *_ = Tensor.custom_kernel(d_logits, logits_t, lse_t, targets_t, scale, fxn=fxn)
|
||||
return (None, None, None, d_logits.uop, None)
|
||||
|
||||
|
|
@ -73,17 +80,19 @@ def fused_ce_loss(logits:Tensor, targets:Tensor, label_smoothing:float=0.1) -> T
|
|||
device=logits.device)
|
||||
lse_out = Tensor(Tensor.invalids(rows // ndev, dtype=dtypes.float32, device=logits.device).uop.multi(0),
|
||||
device=logits.device)
|
||||
rows_per_dev = rows // ndev
|
||||
local_shape = tuple(s//ndev if i == axis else s for i,s in enumerate(logits.shape))
|
||||
rows_per_dev = local_shape[0] * local_shape[1]
|
||||
seq_per_dev = local_shape[1]
|
||||
else:
|
||||
loss_out = Tensor.invalids(rows, dtype=dtypes.float32, device=logits.device)
|
||||
max_out = Tensor.invalids(rows, dtype=dtypes.float32, device=logits.device)
|
||||
lse_out = Tensor.invalids(rows, dtype=dtypes.float32, device=logits.device)
|
||||
rows_per_dev = rows
|
||||
logits_flat = logits.reshape(rows, VOCAB)
|
||||
seq_per_dev = SEQ
|
||||
targets_flat = targets.reshape(-1).cast(dtypes.int32)
|
||||
fxn = functools.partial(_custom_fused_ce_loss_fwd, vocab=VOCAB, rows=rows_per_dev,
|
||||
fxn = functools.partial(_custom_fused_ce_loss_fwd, vocab=VOCAB, rows=rows_per_dev, seq=seq_per_dev,
|
||||
label_smoothing=label_smoothing)
|
||||
loss_out, max_out, lse_out, *_ = Tensor.custom_kernel(
|
||||
loss_out, max_out, lse_out, logits_flat, targets_flat,
|
||||
loss_out, max_out, lse_out, logits, targets_flat,
|
||||
fxn=fxn, grad_fxn=functools.partial(_fused_ce_loss_bwd, label_smoothing=label_smoothing))
|
||||
return loss_out.mean()
|
||||
|
|
|
|||
|
|
@ -105,8 +105,8 @@ class TestModularWraparound(unittest.TestCase):
|
|||
|
||||
class TestGraphRewrite(unittest.TestCase):
|
||||
def test_dedup(self):
|
||||
v1 = UOp(Ops.DEFINE_VAR, dtypes.float)
|
||||
v2 = UOp(Ops.DEFINE_VAR, dtypes.float)
|
||||
v1 = UOp.variable("v", 0, 1, dtypes.float)
|
||||
v2 = UOp.variable("v", 0, 1, dtypes.float)
|
||||
nout = graph_rewrite(v1+v2, PatternMatcher([]))
|
||||
self.assertIs(nout.src[0], nout.src[1])
|
||||
|
||||
|
|
@ -166,7 +166,7 @@ class TestGraphRewrite(unittest.TestCase):
|
|||
self.assertEqual(nout.arg, 3.0)
|
||||
|
||||
def test_depth_2_fold(self):
|
||||
v = UOp(Ops.DEFINE_VAR, dtypes.float)
|
||||
v = UOp.variable("v", 0, 1, dtypes.float)
|
||||
c1 = UOp.const(dtypes.float, 1.0)
|
||||
c2 = UOp.const(dtypes.float, 2.0)
|
||||
nout = graph_rewrite(v+c1+c2, simple_pm)
|
||||
|
|
@ -191,7 +191,7 @@ class TestGraphRewrite(unittest.TestCase):
|
|||
b = UOp.variable('b', 0, 1)
|
||||
c = UOp.variable('c', 0, 1)
|
||||
d = UOp.variable('d', 0, 1)
|
||||
outs = [2+a, 2+a+d+3+b+c+4, UOp(Ops.ADD, a.dtype, src=(a.const_like(2), a)), (4+d)+c+(2+a)+b]
|
||||
outs = [2+a, 2+a+d+3+b+c+4, a.const_like(2)+a, (4+d)+c+(2+a)+b]
|
||||
for out in outs:
|
||||
sink = graph_rewrite(out, sym)
|
||||
print(sink.render())
|
||||
|
|
@ -203,7 +203,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
def test_add_constant_fold(self):
|
||||
c1 = UOp.const(dtypes.float, 1.0)
|
||||
c2 = UOp.const(dtypes.float, 2.0)
|
||||
out = UOp(Ops.ADD, dtypes.float, (c1, c2))
|
||||
out = c1+c2
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len(uops), 2) # +1 for SINK
|
||||
out = uops[-2]
|
||||
|
|
@ -213,9 +213,9 @@ class TestUOpGraph(unittest.TestCase):
|
|||
def test_where_same_fold(self):
|
||||
v = UOp.variable('tmp', 0, 1)
|
||||
c0 = UOp.const(dtypes.weakint, 0)
|
||||
vc = UOp(Ops.CMPNE, dtypes.bool, (v, c0))
|
||||
vc = v != c0
|
||||
c1 = UOp.const(dtypes.float, 1.0)
|
||||
out = UOp(Ops.WHERE, dtypes.float, (vc, c1, c1))
|
||||
out = vc.where(c1, c1)
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len(uops), 2) # +1 for SINK
|
||||
out = uops[-2]
|
||||
|
|
@ -226,7 +226,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
bf = UOp.const(dtypes.bool, False)
|
||||
c1 = UOp.const(dtypes.float, 1.0)
|
||||
c2 = UOp.const(dtypes.float, 2.0)
|
||||
out = UOp(Ops.WHERE, dtypes.float, (bf, c1, c2))
|
||||
out = bf.where(c1, c2)
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len(uops), 2) # +1 for SINK
|
||||
out = uops[-2]
|
||||
|
|
@ -235,7 +235,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
|
||||
def test_const_cast(self):
|
||||
bf = UOp.const(dtypes.bool, False)
|
||||
out = UOp(Ops.CAST, dtypes.int, (bf,))
|
||||
out = bf.cast(dtypes.int)
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len(uops), 2) # +1 for SINK
|
||||
out = uops[-2]
|
||||
|
|
@ -244,7 +244,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
|
||||
def test_const_bitcast(self):
|
||||
bf = UOp.const(dtypes.float, 1.0)
|
||||
out = UOp(Ops.BITCAST, dtypes.uint32, (bf,))
|
||||
out = bf.bitcast(dtypes.uint32)
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len(uops), 2) # +1 for SINK
|
||||
out = uops[-2]
|
||||
|
|
@ -254,7 +254,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
@unittest.expectedFailure
|
||||
def test_const_shape_change_bitcast(self):
|
||||
bf = UOp.const(dtypes.uint8, 0x3F)
|
||||
out = UOp(Ops.BITCAST, dtypes.half, (bf,))
|
||||
out = bf.bitcast(dtypes.half)
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len(uops), 2) # +1 for SINK
|
||||
|
||||
|
|
@ -262,7 +262,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
def test_noop_vectorize_fold(self):
|
||||
d0 = UOp.param(0, dtypes.float.ptr())
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = UOp(Ops.LOAD, dtypes.float.vec(2), (d0, idx))
|
||||
ld = d0.load(idx, dtype=dtypes.float.vec(2))
|
||||
vec = UOp(Ops.STACK, dtypes.float.vec(2), (ld,))
|
||||
x = UOp(Ops.GEP, dtypes.float, (vec, ), arg=0)
|
||||
alu = UOp(Ops.SQRT, dtypes.float, (x, ))
|
||||
|
|
@ -278,7 +278,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
idx = UOp.const(dtypes.int, 0)
|
||||
def _test_vec(geps, count=4):
|
||||
vec = UOp(Ops.STACK, dtypes.float.vec(count), geps)
|
||||
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), vec))
|
||||
out = d0.index(idx).store(vec)
|
||||
uops = to_uops_list([out])
|
||||
if DEBUG >= 4:
|
||||
from tinygrad import Device
|
||||
|
|
@ -286,28 +286,28 @@ class TestUOpGraph(unittest.TestCase):
|
|||
return uops[-2].src[-1] # -2 to skip SINK
|
||||
|
||||
# possible
|
||||
val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1.index(idx),))
|
||||
xyzw = tuple(UOp(Ops.GEP, dtypes.float, (val,), (i,)) for i in range(4))
|
||||
val = d1.index(idx).load(dtype=dtypes.float.vec(4))
|
||||
xyzw = tuple(val.gep(i) for i in range(4))
|
||||
self.assertIs(_test_vec(xyzw).op, Ops.LOAD)
|
||||
|
||||
# unaligned
|
||||
val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1.index(idx),))
|
||||
wzyx = tuple(UOp(Ops.GEP, dtypes.float, (val,), (i,)) for i in reversed(range(4)))
|
||||
val = d1.index(idx).load(dtype=dtypes.float.vec(4))
|
||||
wzyx = tuple(val.gep(i) for i in reversed(range(4)))
|
||||
self.assertIs(_test_vec(wzyx).op, Ops.STACK)
|
||||
|
||||
# different_size
|
||||
val = UOp(Ops.LOAD, dtypes.float.vec(2), (d1.index(idx),))
|
||||
xy = tuple(UOp(Ops.GEP, dtypes.float, (val, ), (i,)) for i in range(2))
|
||||
val = d1.index(idx).load(dtype=dtypes.float.vec(2))
|
||||
xy = tuple(val.gep(i) for i in range(2))
|
||||
self.assertIs(_test_vec(xy+xy).op, Ops.STACK)
|
||||
val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1.index(idx),))
|
||||
xy = tuple(UOp(Ops.GEP, dtypes.float, (val, ), (i,)) for i in range(2))
|
||||
val = d1.index(idx).load(dtype=dtypes.float.vec(4))
|
||||
xy = tuple(val.gep(i) for i in range(2))
|
||||
self.assertIs(_test_vec(xy, count=2).op, Ops.STACK)
|
||||
|
||||
# different vals
|
||||
val1 = UOp(Ops.LOAD, dtypes.float.vec(2), (d1.index(idx),))
|
||||
val2 = UOp(Ops.LOAD, dtypes.float.vec(2), (d2.index(idx),))
|
||||
xy1 = tuple(UOp(Ops.GEP, dtypes.float, (val1, ), (i,)) for i in range(2))
|
||||
xy2 = tuple(UOp(Ops.GEP, dtypes.float, (val2, ), (i,)) for i in range(2))
|
||||
val1 = d1.index(idx).load(dtype=dtypes.float.vec(2))
|
||||
val2 = d2.index(idx).load(dtype=dtypes.float.vec(2))
|
||||
xy1 = tuple(val1.gep(i) for i in range(2))
|
||||
xy2 = tuple(val2.gep(i) for i in range(2))
|
||||
self.assertIs(_test_vec(xy1+xy2).op, Ops.STACK)
|
||||
|
||||
def test_gep_vec_const_fold(self):
|
||||
|
|
@ -323,7 +323,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
def test_wmma_vectorize_fold(self):
|
||||
for i in [2, 4, 8]:
|
||||
vec = UOp(Ops.STACK, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
|
||||
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i))
|
||||
var = UOp.variable("var", 0, 1, dtypes.half.vec(i))
|
||||
acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
|
||||
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
|
|
@ -331,7 +331,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
self.assertEqual(len(uops), 2) # +1 for SINK
|
||||
|
||||
for i in [2, 4, 8]:
|
||||
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i))
|
||||
var = UOp.variable("var", 0, 1, dtypes.half.vec(i))
|
||||
vec = UOp(Ops.STACK, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
|
||||
acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
|
||||
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||
|
|
@ -385,7 +385,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = d1.index(idx)
|
||||
alu = (ld<1).cast(dtypes.bool)
|
||||
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx, ptr=True), alu))
|
||||
out = d0.index(idx, ptr=True).store(alu)
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 0)
|
||||
|
||||
|
|
@ -395,7 +395,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = d1.index(idx)
|
||||
alu = ld.cast(dtypes.float).cast(dtypes.float)
|
||||
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx, ptr=True), alu))
|
||||
out = d0.index(idx, ptr=True).store(alu)
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1)
|
||||
|
||||
|
|
@ -403,8 +403,8 @@ class TestUOpGraph(unittest.TestCase):
|
|||
v = UOp.variable("tmp", 0, 1, dtypes.int)
|
||||
c2 = UOp.const(dtypes.int, 2)
|
||||
c4 = UOp.const(dtypes.int, 4)
|
||||
vc = UOp(Ops.ADD, dtypes.int, (v, c2))
|
||||
out = UOp(Ops.ADD, dtypes.int, (vc, c4))
|
||||
vc = v+c2
|
||||
out = vc+c4
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len(uops), 4) # +1 for SINK
|
||||
out = uops[-2] # -2 to skip SINK
|
||||
|
|
@ -535,19 +535,19 @@ class TestUOpGraph(unittest.TestCase):
|
|||
idx = UOp.const(dtypes.int, 0)
|
||||
ld0 = glbl1.index(UOp.invalid())
|
||||
ld1 = glbl2.index(idx.valid(UOp.const(dtypes.bool, True)))
|
||||
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx, ptr=True), ld1+ld0))])
|
||||
uops = to_uops_list([glbl0.index(idx, ptr=True).store(ld1+ld0)])
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assertEqual(len([u for u in uops if u.op is Ops.LOAD]), 1)
|
||||
|
||||
def test_fold_gated_load_local(self):
|
||||
glbl0 = UOp.param(0, dtypes.int.ptr(16))
|
||||
smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, addrspace=AddrSpace.LOCAL), (), "temp")
|
||||
lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")
|
||||
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx, ptr=True), glbl0.index(lidx, ptr=True).load()))
|
||||
barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
|
||||
lidx = UOp.special(16, "lidx0", dtypes.int)
|
||||
st = smem.index(lidx, ptr=True).store(glbl0.index(lidx, ptr=True).load())
|
||||
barrier = st.barrier()
|
||||
ld0 = smem.after(barrier).index(UOp.invalid())
|
||||
ld1 = smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True)))
|
||||
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx, ptr=True), ld1+ld0))])
|
||||
uops = to_uops_list([glbl0.index(lidx, ptr=True).store(ld1+ld0)])
|
||||
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assertEqual(len([u for u in uops if u.op is Ops.LOAD]), 2)
|
||||
|
|
|
|||
|
|
@ -145,12 +145,12 @@ class CStyleLanguage(Renderer):
|
|||
string_rewrite = base_rewrite
|
||||
extra_matcher = extra_pm
|
||||
|
||||
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
|
||||
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[UOp,bool]]], uops:list[UOp], prefix=None) -> str:
|
||||
tmp = ""
|
||||
if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs):
|
||||
if any(isinstance(u.dtype, ImageDType) for _,(u,_) in bufs):
|
||||
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
|
||||
buftypes = [(name, self.render_dtype(dtype, mutable)+self.buffer_suffix if isinstance(dtype, (ImageDType, PtrDType)) else
|
||||
self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
|
||||
buftypes = [(name, self.render_dtype(u.dtype, mutable)+self.buffer_suffix if isinstance(u.dtype, (ImageDType, PtrDType)) else
|
||||
self.arg_int_prefix if u.dtype == dtypes.int else None) for name,(u,mutable) in bufs]
|
||||
local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]
|
||||
launch_bounds = prod([d.vmax for d in local_dims])
|
||||
prg = ''.join([f"{self.kernel_typedef.format(launch_bounds=launch_bounds)} {function_name}(",] +
|
||||
|
|
@ -173,14 +173,14 @@ class CStyleLanguage(Renderer):
|
|||
return self.type_map.get(scalar:=dt.scalar(), scalar.name)
|
||||
|
||||
def __getitem__(self, key): return self.r[key] # hacky helper
|
||||
def _render(self, uops:list[UOp]) -> tuple[str, list[str], list[tuple[str,tuple[DType,bool]]]]:
|
||||
def _render(self, uops:list[UOp]) -> tuple[str, list[str], list[tuple[str,tuple[UOp,bool]]]]:
|
||||
r: dict[UOp, str] = {}
|
||||
self.r = r
|
||||
|
||||
child_count = Counter(v for ru in uops for v in ru.src)
|
||||
# find which PARAMs are stored to with a single toposort
|
||||
writable_params = {u for u in UOp.sink(*[u.src[0] for u in uops if u.op is Ops.STORE]).toposort(lambda u: u.op != Ops.END) if u.op is Ops.PARAM}
|
||||
bufs: dict[UOp, tuple[str, tuple[DType, bool]]] = {}
|
||||
bufs: dict[UOp, tuple[str, tuple[UOp, bool]]] = {}
|
||||
kernel = []
|
||||
depth = 1
|
||||
c: defaultdict[str, int] = defaultdict(int)
|
||||
|
|
@ -197,8 +197,7 @@ class CStyleLanguage(Renderer):
|
|||
if u.op is not Ops.PARAM: r[u] = u.arg[0]
|
||||
elif isinstance(u.dtype, ImageDType): r[u] = f"data{u.arg.slot}_{u.dtype.shape[0]}x{u.dtype.shape[1]}"
|
||||
else: r[u] = f"data{u.arg.slot}_{sz}" if (sz:=u.max_numel()) > 0 else f"data{u.arg.slot}"
|
||||
bufs[u] = (r[u], ((u.dtype if isinstance(u.dtype, (ImageDType, PtrDType)) else u.dtype.ptr(u.max_numel(), u.addrspace))
|
||||
if u.op is Ops.PARAM else u.dtype, u in writable_params))
|
||||
bufs[u] = (r[u], (u, u in writable_params))
|
||||
continue
|
||||
|
||||
# naming
|
||||
|
|
@ -287,7 +286,7 @@ class ClangRenderer(CStyleLanguage):
|
|||
AMX_SET(1);\n return data0;\n}}"""]
|
||||
return prefix
|
||||
def _render_body(self, function_name, kernel, bufs, uops, pref=None) -> str: return super().render_kernel(function_name, kernel, bufs, uops, pref)
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]]) -> str: return ""
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[UOp,bool]]]) -> str: return ""
|
||||
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
|
||||
defines = '\n'.join(self._render_defines(uops))
|
||||
|
|
|
|||
|
|
@ -180,12 +180,12 @@ class LLVMRenderer(Renderer):
|
|||
r[u] = f"%{'local' if u.op is Ops.DEFINE_LOCAL else 'reg'}_{str(u.arg).replace('(', '').replace(')', '').replace(',', '_').replace(' ', '')}"
|
||||
size = u.max_numel()
|
||||
if u.op is Ops.DEFINE_REG:
|
||||
kernel.append(f" {r[u]} = alloca [{size} x {ldt(u.dtype)}]")
|
||||
kernel.append(f" {r[u]} = alloca [{size} x {ldt(u.dtype.base)}]")
|
||||
elif self.has_local:
|
||||
local_args.append(f"@{r[u][1:]} = internal unnamed_addr addrspace(3) global [{size} x {ldt(u.dtype)}] undef, align 16")
|
||||
kernel.append(f" {r[u]} = addrspacecast [{size} x {ldt(u.dtype)}] addrspace(3)* @{r[u][1:]} to [{size} x {ldt(u.dtype)}]*")
|
||||
else:
|
||||
kernel.append(f" {r[u]} = alloca [{size} x {ldt(u.dtype)}], align 16")
|
||||
kernel.append(f" {r[u]} = alloca [{size} x {ldt(u.dtype.base)}], align 16")
|
||||
elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
|
||||
elif u.op is Ops.CAST and (ldt(u.dtype) == ldt(u.src[0].dtype) or isinstance(u.dtype, PtrDType)):
|
||||
r[u] = r[u.src[0]] # cast from signed to unsigned of the same size is a noop, or pointer cast
|
||||
|
|
|
|||
|
|
@ -117,8 +117,8 @@ class WGSLRenderer(CStyleLanguage):
|
|||
prg += "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
|
||||
prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
|
||||
prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" +
|
||||
f"{'var<storage,read_write>' if isinstance(dtype, PtrDType) else 'var<uniform>'}" +
|
||||
f"{name}:{f'array<{self.buf_map(dtype.base)}>' if isinstance(dtype,PtrDType) else self.buf_map(dtype)};" for name,(dtype,_) in bufs])
|
||||
f"{'var<storage,read_write>' if isinstance(u.dtype, PtrDType) else 'var<uniform>'}" +
|
||||
f"{name}:{f'array<{self.buf_map(u.dtype.base)}>' if isinstance(u.dtype,PtrDType) else self.buf_map(u.dtype)};" for name,(u,_) in bufs])
|
||||
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
|
||||
return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||
import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys, subprocess, struct
|
||||
assert sys.platform != 'win32'
|
||||
from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
from tinygrad.helpers import getenv, round_up, mv_address, to_mv, cpu_objdump, system, DEBUG, suppress_finalizing, Target
|
||||
from tinygrad.renderer.cstyle import ClangRenderer
|
||||
|
|
@ -53,18 +53,19 @@ class DSPRenderer(ClangRenderer):
|
|||
'void* HAP_mmap(void *addr, int len, int prot, int flags, int fd, long offset);', 'int HAP_munmap(void *addr, int len);',
|
||||
'unsigned long long HAP_perf_get_time_us(void);'] + super()._render_defines(uops)
|
||||
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]]) -> str:
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[UOp,bool]]]) -> str:
|
||||
msrc = ['int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
|
||||
'struct dcvs_v2_req req = {.type=7, .dcvs_enable=0, .set_latency=1, .latency=100, .set_dcvs_params=1, .target_corner = 6 /* TURBO */};',
|
||||
'HAP_power_set((void*)handle, (void*)&req);']
|
||||
msrc += ['if ((sc>>24) != 2) return 0;']
|
||||
msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
|
||||
msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0].dtype, PtrDType)]
|
||||
msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs)
|
||||
if isinstance(b[1][0].dtype, PtrDType)]
|
||||
msrc += ["unsigned long long start = HAP_perf_get_time_us();"]
|
||||
msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0], PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"]
|
||||
msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0].dtype, PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"]
|
||||
msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"]
|
||||
msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0].dtype, PtrDType)]
|
||||
msrc += ["return 0; }"]
|
||||
return '\n'.join(msrc)
|
||||
|
||||
|
|
@ -273,22 +274,23 @@ return (void*)syscall((long)addr, length, prot, flags, fd, offset, 222); }}'''
|
|||
class MockDSPRenderer(DSPRenderer):
|
||||
def __init__(self, target:Target): self.target, self.compiler = target, DSPCompiler(mock=True)
|
||||
def _render_defines(self, uops) -> list[str]: return ClangRenderer._render_defines(self, uops)
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]]) -> str:
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[UOp,bool]]]) -> str:
|
||||
# https://gpages.juszkiewicz.com.pl/syscalls-table/syscalls.html
|
||||
# control register 21 is HEX_REG_QEMU_INSN_CNT, 0x6a15c000 loads it
|
||||
msrc = [mockdsp_boilerplate, 'void _start(void) {']
|
||||
for i,b in enumerate(bufs):
|
||||
if isinstance(b[1][0], PtrDType):
|
||||
sz = b[1][0].size*b[1][0].itemsize
|
||||
if isinstance(b[1][0].dtype, PtrDType):
|
||||
sz = b[1][0].dtype.size*b[1][0].dtype.itemsize
|
||||
# for loop for big reads
|
||||
msrc.append(f"void *buf{i} = mmap2(0, {sz}, 3, 0x21, -1, 0); for(int rd = 0; rd < {sz}; rd += read(0, buf{i}+rd, {sz}-rd));")
|
||||
else:
|
||||
msrc.append(f"unsigned int val{i}; read(0, &val{i}, 4);")
|
||||
msrc.append("unsigned int st = inscount();")
|
||||
msrc.append(f"{function_name}({', '.join([(f'(void*)buf{i}' if isinstance(b[1][0], PtrDType) else f'val{i}') for i,b in enumerate(bufs)])});")
|
||||
params = [(f'(void*)buf{i}' if isinstance(b[1][0].dtype, PtrDType) else f'val{i}') for i,b in enumerate(bufs)]
|
||||
msrc.append(f"{function_name}({', '.join(params)});")
|
||||
msrc.append("unsigned int et = inscount() - st; write(1, &et, sizeof(et));")
|
||||
for i,b in enumerate(bufs):
|
||||
if isinstance(b[1][0], PtrDType): msrc.append(f"write(1, buf{i}, {b[1][0].size*b[1][0].itemsize});")
|
||||
if isinstance(b[1][0].dtype, PtrDType): msrc.append(f"write(1, buf{i}, {b[1][0].dtype.size*b[1][0].dtype.itemsize});")
|
||||
msrc.append('exit(0); }')
|
||||
return '\n'.join(msrc)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue