Compare commits

...

5 commits

Author SHA1 Message Date
George Hotz
ece4604825 cleanups 2025-07-23 16:36:37 -07:00
George Hotz
0cc71fec59 bugfix to match speed 2025-07-23 16:24:38 -07:00
George Hotz
1e56fb559a gemm passes spec 2025-07-23 15:50:01 -07:00
George Hotz
a87bc6dacc matmul is correct 2025-07-23 15:35:57 -07:00
George Hotz
68db0e2338 write out kernel 3 in uops 2025-07-23 14:37:48 -07:00
6 changed files with 126 additions and 150 deletions

View file

@ -29,7 +29,7 @@ if __name__ == "__main__":
c = Tensor.zeros(N, N).contiguous().realize() c = Tensor.zeros(N, N).contiguous().realize()
GlobalCounters.reset() GlobalCounters.reset()
with Context(DEBUG=2, BEAM=4): with Context(DEBUG=2):
for _ in range(run_count): tc = (a@b).realize() for _ in range(run_count): tc = (a@b).realize()
GlobalCounters.reset() GlobalCounters.reset()

View file

@ -80,6 +80,8 @@ extern "C" __attribute__((global)) void kernel3_registers(float *a, float *b, fl
// Iteration over BK blocks. // Iteration over BK blocks.
for (int kId = 0; kId < N; kId += BK) { for (int kId = 0; kId < N; kId += BK) {
__syncthreads();
// We populate the Shared Memory with Ks row and columns // We populate the Shared Memory with Ks row and columns
for (int i = 0; i < nbReadsB; i++) { for (int i = 0; i < nbReadsB; i++) {
int index_x = BN * blockIdx.x + rBIdx; int index_x = BN * blockIdx.x + rBIdx;
@ -123,7 +125,6 @@ extern "C" __attribute__((global)) void kernel3_registers(float *a, float *b, fl
} }
} }
} }
__syncthreads();
} }
for (int iterWaveM = 0; iterWaveM < nbIterWaveM; iterWaveM++) { for (int iterWaveM = 0; iterWaveM < nbIterWaveM; iterWaveM++) {

View file

@ -1,168 +1,145 @@
from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes
from tinygrad.helpers import prod, unwrap
from tinygrad.uop.ops import UOp, Ops, KernelInfo from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.opt.kernel import AxisType
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
from tinygrad.uop.ops import graph_rewrite, PatternMatcher, UPat, Ops, UOp, GroupOp from tinygrad.uop.ops import Ops, UOp
from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
from tinygrad.schedule.kernelize import merge_views
from tinygrad.shape.view import View
from tinygrad.dtype import AddrSpace from tinygrad.dtype import AddrSpace
N = 4096 N = 4096
run_count = 5 run_count = 5
# change reduceop axes and input ShapeTrackers, view gets replaced with a reshape. def hand_spec_kernel3():
# src->r->view --> src->view->r BLOCK_SIZE = 256
def swizzle_reduceop(src:UOp, r:UOp, view:UOp):
if r.tag is not None: return None
# confirm the input is in order
# TODO: replace this with a UOp that allows for nothing else then remove this
permute = tuple(i for i in range(len(src.shape)) if i not in r.axis_arg)+r.axis_arg
assert permute == tuple(range(len(permute))), f"reduce axis must already be in order, {permute} isn't"
# append the reduce shape to each of the views BN = 128
reduce_count = len(r.axis_arg) BM = 128
prshape = prod(rshape:=src.shape[-reduce_count:]) BK = 8
rstrides = strides_for_shape(rshape)
nv = [View.create(v.shape[:-reduce_count]+rshape, tuple(x*prshape for x in v.strides[:-reduce_count])+rstrides, v.offset*prshape,
v.mask[:-reduce_count]+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in unwrap(view.st).views]
# no reshape required with shrinking REDUCE_AXIS
return UOp(Ops.REDUCE_AXIS, r.dtype, (src.view(ShapeTracker(tuple(nv))),),
(r.arg[0], tuple(range(len(view.shape)-reduce_count, len(view.shape)))))
early_view_left = merge_views+PatternMatcher([
# view before elementwise and buffer ops
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.VALID, Ops.STORE, Ops.LOAD}, name="e"),), name="view"),
lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src)) if e.tag is None else None),
# push a non contiguous ShapeTracker through reduceop
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
])
def hand_spec():
# Block Tile size . 128x128
# Thread Tile size . 4x4
# Wave Tile size . 128x32
# A wave is . 8x4
# ────── problem size and tiling params (mirror the C kernel) ───────────────────
BK = 8 # depth of K-tile
BN = BM = 128 # block-tile (output) sizes
# the real thread is 16x8 = 128 regs
TM = 4
nbIterWaveM = 2
TN = 4 TN = 4
nbIterWaveN = 4 TM = 4
# ────── shared-memory tile sizes (unchanged) ─────────────────────────────────── nbWaves = BLOCK_SIZE // 32
LDS_A_SZ = BK * BM # 1024 floats WN = 64
LDS_B_SZ = BK * BN # 1024 floats WM = BN * BM // nbWaves // WN
bC = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0) # output C nbWaveX = BN // WN
bA = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1) # input A nbWaveY = BM // WM
bB = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2) # input B
# TODO: this should not be a string, just a number threadIdx_x = UOp(Ops.SPECIAL, dtypes.int, arg=("lidx0", BLOCK_SIZE))
lAs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(LDS_A_SZ, addrspace=AddrSpace.LOCAL), arg="As") waveIndex = threadIdx_x // 32
lBs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(LDS_B_SZ, addrspace=AddrSpace.LOCAL), arg="Bs") waveIdx = waveIndex % nbWaveX
waveIdy = waveIndex // nbWaveX
indexInWave = threadIdx_x % 32
s0 = ShapeTracker.from_shape((N, N, N), (N, 0, 1)) nbThreadXPerWave = 8
s1 = ShapeTracker.from_shape((N, N, N), (0, 1, N)) nbThreadYPerWave = 4
s2 = ShapeTracker.from_shape((N, N, 1), (N, 1, 0))
ls0 = ShapeTracker.from_shape((BM, BK)) idxInWave = indexInWave % nbThreadXPerWave
ls1 = ShapeTracker.from_shape((BN, BK)) idyInWave = indexInWave // nbThreadXPerWave
buf_at = [AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.LOCAL, AxisType.LOCAL, AxisType.LOCAL, AxisType.UPCAST, AxisType.UPCAST] nbIterWaveN = WN // (nbThreadXPerWave * TN)
buf_bt = [AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.LOCAL, AxisType.LOCAL, AxisType.LOCAL, AxisType.UPCAST, AxisType.UPCAST] nbIterWaveM = WM // (nbThreadYPerWave * TM)
axis_types = buf_at + buf_bt + [AxisType.REDUCE, AxisType.UNROLL, AxisType.UNROLL, AxisType.UNROLL]
# 128 x 128 x 8 SUBWN = WN // nbIterWaveN
full_shape = (N//BM, 2, 2, 2, 2, 2, 2, 2, N//BN, 2, 2, 2, 2, 2, 2, 2, N//BK, 2, 2, 2) SUBWM = WM // nbIterWaveM
s0 = s0.reshape(full_shape) # Thread mapping to read BKxBN block from A
s1 = s1.reshape(full_shape) rAIdx = threadIdx_x % BK
s2 = s2.reshape(full_shape[:-4] + (1,)*4) rAIdy = threadIdx_x // BK
# Thread mapping to read BNxBK block from B
rBIdx = threadIdx_x % BN
rBIdy = threadIdx_x // BN
ls0 = ls0.reshape((1, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2)).expand(s0.shape) strideReadB = BLOCK_SIZE // BN
ls1 = ls1.reshape((1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2)).expand(s1.shape) strideReadA = BLOCK_SIZE // BK
assert ls0.real_size() == LDS_A_SZ nbReadsB = BN * BK // BLOCK_SIZE
assert ls1.real_size() == LDS_B_SZ nbReadsA = BM * BK // BLOCK_SIZE
# BK is a loop of 8 blockIdx_x = UOp(Ops.SPECIAL, dtypes.int, arg=("gidx0", N//BN))
# each loop reads 8 in A, 16 in B blockIdx_y = UOp(Ops.SPECIAL, dtypes.int, arg=("gidx1", N//BM))
print(ls0) a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0)
print(ls1) b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1)
c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2)
permaxis = [] junk = UOp.const(dtypes.float, 0)
for axis_order in [AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP, AxisType.UPCAST, AxisType.GROUP_REDUCE, AxisType.REDUCE, AxisType.UNROLL]: A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), src=(junk,), arg=0)
permaxis += [i for i,a in enumerate(axis_types) if a == axis_order] B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), src=(junk,), arg=1)
axis_types = [axis_types[x] for x in permaxis]
s0, s1, s2, ls0, ls1 = [x.permute(tuple(permaxis)) for x in [s0, s1, s2, ls0, ls1]]
print(axis_types)
lw0, lr0 = ls0, ls0 As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0)
lw1, lr1 = ls1, ls1 Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1)
# first round of permutes c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), src=(junk,), arg=2)
permaxis = (0, 1, 19, 18, 17, 12, 11, 10, 5, 4, 3, 2, 6, 7, 8, 9, 16, 13, 14, 15) kId_range = UOp.range(dtypes.int, N//BK, 0)
s0 = s0.permute(permaxis) kId = kId_range*BK
lw0 = lw0.permute(permaxis)
permaxis = (0, 1, 15, 14, 9, 8, 7, 6, 13, 19, 18, 17, 5, 4, 3, 2, 16, 12, 11, 10) # load from globals into locals
s1 = s1.permute(permaxis) i = UOp.range(dtypes.int, nbReadsB, 1)
lw1 = lw1.permute(permaxis) index_x = BN * blockIdx_x + rBIdx
index_y = rBIdy + i * strideReadB + kId
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
# second round of permutes i = UOp.range(dtypes.int, nbReadsA, 2)
#permaxis = (0, 1, 12, 11, 5, 4, 3, 2, 10, 6, 7, 8, 9, 13, 14, 15, 16, 17, 18, 19) index_x = rAIdx + kId
#lw0 = lw0.permute(permaxis) index_y = BM * blockIdx_y + rAIdy + i * strideReadA
#lr0 = lr0.permute(permaxis) As_store = As[(index_x % BK) * BM + index_y % BM].store(a[N * index_y + index_x].load(), i)
from tinygrad.opt.kernel import axis_colors, colored barrier = UOp(Ops.BARRIER, src=(As_store, Bs_store))
print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(s0.shape, s0.views[0].strides, axis_types)]))
print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(s1.shape, s1.views[0].strides, axis_types)]))
print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(s2.shape, s2.views[0].strides, axis_types)]))
print("lw")
print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(lw0.shape, lw0.views[0].strides, axis_types)]))
print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(lw1.shape, lw1.views[0].strides, axis_types)]))
print("lr")
print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(lr0.shape, lr0.views[0].strides, axis_types)]))
print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(lr1.shape, lr1.views[0].strides, axis_types)]))
# loads and stores k = UOp.range(dtypes.int, BK, 3)
bs0 = bA.view(s0).load()
bs1 = bB.view(s1).load()
bs0 = lAs.view(lr0).load(lAs.view(lw0).store(bs0))
bs1 = lBs.view(lr1).load(lBs.view(lw1).store(bs1))
mat = (bs0 * bs1).r(Ops.ADD, tuple([i for i,a in enumerate(axis_types) if a in (AxisType.REDUCE, AxisType.UNROLL)]), permute=False) # load from locals into registers
st = bC.view(s2).store(mat) iterWave = UOp.range(dtypes.int, nbIterWaveN, 4)
i = UOp.range(dtypes.int, TN, 5)
index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(barrier), iterWave, i)
ast = st.sink(arg=KernelInfo(axis_types=tuple(axis_types), name="tinygemm")) iterWave = UOp.range(dtypes.int, nbIterWaveM, 6)
ast = graph_rewrite(ast, merge_views) i = UOp.range(dtypes.int, TM, 7)
prg = get_program(ast, Device.default.renderer) index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
print(prg.src) A_col_store = A_col[iterWave*TM + i].store(As[k*BM + index].load(barrier), iterWave, i)
return prg
# do the GEMM math
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 8)
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 9)
yt = UOp.range(dtypes.int, TM, 10)
xt = UOp.range(dtypes.int, TN, 11)
x = iterWaveN * TN + xt
y = iterWaveM * TM + yt
c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
sink = c_regs_idx.store(c_regs_idx.load() + A_col[y].load(A_col_store) * B_row[x].load(B_row_store),
iterWaveM, iterWaveN, yt, xt, k, kId_range)
# store c_regs into c
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 12)
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 13)
xOut = blockIdx_x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave
yOut = blockIdx_y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave
yt = UOp.range(dtypes.int, TM, 14)
xt = UOp.range(dtypes.int, TN, 15)
indexC = N * (yOut + yt) + xOut + xt
sink = c[indexC].store(c_regs[TN * nbIterWaveN * (iterWaveM * TM + yt) + (iterWaveN * TN + xt)].load(sink), iterWaveM, iterWaveN, yt, xt)
return sink.sink(arg=KernelInfo(name="tinygemm"))
if __name__ == "__main__": if __name__ == "__main__":
hprg = hand_spec() hprg = hand_spec_kernel3()
hrunner = CompiledRunner(hprg) prg = get_program(hprg, Device.default.renderer)
print(prg.src)
hrunner = CompiledRunner(prg)
a = Tensor.randn(N, N).realize() a = Tensor.randn(N, N).realize()
b = Tensor.randn(N, N).realize() b = Tensor.randn(N, N).realize()
hc = Tensor.zeros(N, N).contiguous().realize() hc = Tensor.zeros(N, N).contiguous().realize()
GlobalCounters.reset() GlobalCounters.reset()
with Context(DEBUG=2, BEAM=4): with Context(DEBUG=2):
for _ in range(run_count): tc = (a@b).realize() for _ in range(run_count): tc = (a@b).realize()
GlobalCounters.reset() GlobalCounters.reset()
ei = ExecItem(hrunner, [hc.uop.buffer, a.uop.buffer, b.uop.buffer]) ei = ExecItem(hrunner, [a.uop.buffer, b.uop.buffer, hc.uop.buffer])
with Context(DEBUG=2): with Context(DEBUG=2):
for _ in range(run_count): ei.run(wait=True) for _ in range(run_count): ei.run(wait=True)
err = (hc-tc).square().mean().item() err = (hc-tc).square().mean().item()
print(f"hrunner {err}") print(f"hrunner {err}")
assert err < 1e-06 if err > 1e-06: raise RuntimeError("matmul is wrong!")

View file

@ -211,6 +211,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def sink(self, *srcs:UOp|None, **kwargs): return UOp(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs) def sink(self, *srcs:UOp|None, **kwargs): return UOp(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,)) def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx)) def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def __getitem__(self, idx): return self.index(idx)
def const_like(self, b:ConstLike): def const_like(self, b:ConstLike):
# constants can optionally have a DEVICE source # constants can optionally have a DEVICE source
return UOp.const(self.dtype, b, device=self._device, shape=self.shape if self.st is not None else None) return UOp.const(self.dtype, b, device=self._device, shape=self.shape if self.st is not None else None)

View file

@ -162,10 +162,8 @@ spec = PatternMatcher([
(UPat(Ops.LOAD, src=(UPat(Ops.STORE),)), lambda: True), (UPat(Ops.LOAD, src=(UPat(Ops.STORE),)), lambda: True),
# LOAD takes a <bufidx, alt?, barrier?> # LOAD takes a <bufidx, alt?, barrier?>
(UPat(Ops.LOAD, src=(index_pat,)), validate_index), (UPat(Ops.LOAD, src=(index_pat, UPat(Ops.IF, name="cond")), allow_any_len=True), lambda idx,cond: validate_index(idx,cond.src[0])),
(UPat(Ops.LOAD, src=(index_pat, UPat(Ops.BARRIER))), validate_index), (UPat(Ops.LOAD, src=(index_pat,), allow_any_len=True), validate_index),
(UPat(Ops.LOAD, src=(index_pat, UPat(Ops.IF, name="cond"))), lambda idx,cond: validate_index(idx,cond.src[0])),
(UPat(Ops.LOAD, src=(index_pat, UPat.var("alt")), name="ld"), lambda ld,alt,idx: ld.dtype == alt.dtype and validate_index(idx)),
# STORE takes a <bufidx, val, gate?> # STORE takes a <bufidx, val, gate?>
(UPat(Ops.STORE, src=(index_pat, UPat(name="val"), UPat(Ops.IF, name="gate")), allow_any_len=True), validate_store), (UPat(Ops.STORE, src=(index_pat, UPat(name="val"), UPat(Ops.IF, name="gate")), allow_any_len=True), validate_store),

View file

@ -437,7 +437,6 @@ sym = symbolic_flat+PatternMatcher([
((UPat.var('x', dtypes.uint64)&(UPat.var('y').where(UPat.const(dtypes.uint64, 0xFFFFFFFF), UPat.const(dtypes.uint64, 0)))).cast(dtypes.uint32), ((UPat.var('x', dtypes.uint64)&(UPat.var('y').where(UPat.const(dtypes.uint64, 0xFFFFFFFF), UPat.const(dtypes.uint64, 0)))).cast(dtypes.uint32),
lambda x,y: y.where(x.cast(dtypes.uint32), UOp.const(dtypes.uint32, 0))), lambda x,y: y.where(x.cast(dtypes.uint32), UOp.const(dtypes.uint32, 0))),
# ** self folding ** # ** self folding **
(UPat(Ops.DEFINE_REG, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
# x!=0 -> (bool)x # x!=0 -> (bool)x
(UPat.var("x")!=0, lambda x: x.cast(dtypes.bool.vec(x.dtype.count))), (UPat.var("x")!=0, lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
# ** where ** # ** where **