Compare commits

...

4 commits

Author SHA1 Message Date
George Hotz
5e433685b1 reorder 2025-07-27 19:21:08 -07:00
George Hotz
96c019a59c full gen 2025-07-27 18:46:24 -07:00
George Hotz
938fe98daf more reduce into 2025-07-27 18:34:52 -07:00
George Hotz
3ec3c27b83 hack on lowerer to make loops 2025-07-27 11:55:26 -07:00
7 changed files with 110 additions and 23 deletions

View file

@ -29,6 +29,8 @@ def hl_spec_kernel3():
Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1).view(ShapeTracker.from_shape((BK*BN,)))
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0).view(ShapeTracker.from_shape((nbIterWaveM * TM,)))
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), arg=1).view(ShapeTracker.from_shape((nbIterWaveN * TN,)))
acc = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM * nbIterWaveN * TN, AddrSpace.REG), arg=2) \
.view(ShapeTracker.from_shape((nbIterWaveM * TM * nbIterWaveN * TN,)))
# shape buffers. TODO: permutes
full_shape = (N//BM, nbIterWaveM, BM//(nbIterWaveM * TM), TM, N//BN, nbIterWaveN, BN//(nbIterWaveN * TN), TN, N//BK, BK)
@ -40,9 +42,14 @@ def hl_spec_kernel3():
A_col = A_col.reshape((1, nbIterWaveM, 1, TM, 1, 1, 1, 1, 1, 1)).expand(full_shape)
B_row = B_row.reshape((1, 1, 1, 1, 1, nbIterWaveN, 1, TN, 1, 1)).expand(full_shape)
#out = (a.load() * b.load()).r(Ops.ADD, (8, 9))
out = (As.load(As.store(a.load())) * Bs.load(Bs.store(b.load()))).r(Ops.ADD, (8, 9))
#out = (A_col.load(A_col.store(As.load(As.store(a.load())))) * B_row.load(B_row.store(Bs.load(Bs.store(b.load()))))).r(Ops.ADD, (8, 9))
acc = acc.reshape((1, nbIterWaveM, 1, TM, 1, nbIterWaveN, 1, TN, 1, 1)).expand(full_shape[:-2]+(1,1))
#out = a.load() * b.load()
#out = As.load(As.store(a.load())) * Bs.load(Bs.store(b.load()))
out = A_col.load(A_col.store(As.load(As.store(a.load())))) * B_row.load(B_row.store(Bs.load(Bs.store(b.load()))))
#out = out.r(Ops.ADD, (8, 9))
out = UOp(Ops.REDUCE_INTO, out.dtype, (acc, out), Ops.ADD)
axis_types = (
AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST,
@ -51,6 +58,7 @@ def hl_spec_kernel3():
from tinygrad.opt.kernel import axis_colors
shape = '_'.join([colored(str(s), axis_colors[at]) for s,at in zip(full_shape, axis_types)])
print(shape)
sink = c.store(out).sink(arg=KernelInfo(name="tg_"+shape, axis_types=axis_types))
sink = graph_rewrite(sink, merge_views)
return sink
@ -144,8 +152,8 @@ def hand_spec_kernel3():
# do the GEMM math
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 8)
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 9)
yt = UOp.range(dtypes.int, TM, 10)
yt = UOp.range(dtypes.int, TM, 9)
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 10)
xt = UOp.range(dtypes.int, TN, 11)
x = iterWaveN * TN + xt
y = iterWaveM * TM + yt
@ -155,8 +163,8 @@ def hand_spec_kernel3():
# store c_regs into c
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 12)
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 13)
yt = UOp.range(dtypes.int, TM, 14)
yt = UOp.range(dtypes.int, TM, 13)
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 14)
xt = UOp.range(dtypes.int, TN, 15)
xOut = blockIdx_x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave
yOut = blockIdx_y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave
@ -170,6 +178,7 @@ if __name__ == "__main__":
hprg = hl_spec_kernel3() if getenv("HL") else hand_spec_kernel3()
prg = get_program(hprg, Device.default.renderer)
print(prg.src)
exit(0)
hrunner = CompiledRunner(prg)
a = Tensor.randn(N, N).realize()

View file

@ -28,5 +28,15 @@ class TestDefineReg(unittest.TestCase):
@unittest.skipIf(getenv("PTX"), "ptx needs regs to be unrolled")
def test_simple_loop(self): self.test_simple(AxisType.LOOP)
def test_reduce_into(self):
  """Smoke test for lowering a REDUCE_INTO kernel.

  Builds an ADD REDUCE_INTO whose accumulator is a DEFINE_REG view of shape (N, N, 1)
  and whose input is a load of an (N, N, N) global view, stores the result to a second
  global buffer, and renders the program. There are no assertions: the test passes as
  long as get_program does not raise, and the generated source is printed for inspection.
  """
  N = 16
  out_view = ShapeTracker.from_shape((N,N,1))
  in_view = ShapeTracker.from_shape((N,N,N))
  # register view built with explicit strides (0,1,0)
  reg_view = ShapeTracker.from_shape((N,N,1), (0,1,0))
  dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N*N), arg=0).view(out_view)
  inp = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N*N), arg=1).view(in_view)
  acc_reg = UOp(Ops.DEFINE_REG, dtypes.float.ptr(N, AddrSpace.REG), arg=0).view(reg_view)
  reduced = UOp(Ops.REDUCE_INTO, dtypes.float, src=(acc_reg, inp.load()), arg=Ops.ADD)
  kernel = dst.store(reduced).sink(arg=KernelInfo(name="regcopy"))
  program = get_program(kernel, Device.default.renderer)
  print(program.src)
if __name__ == '__main__':
unittest.main()

View file

@ -2,20 +2,34 @@
from dataclasses import dataclass
from typing import cast
from tinygrad.dtype import dtypes, PtrDType, AddrSpace
from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType
from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite
from tinygrad.helpers import prod, partition, flatten
# ***** indexing *****
@dataclass
class IndexContext:
# State handed to the lowering rules as their `ctx` argument.
# NOTE(review): this span comes from a rendered diff whose +/- markers and indentation were stripped,
# so it appears to interleave the pre- and post-change field lists — confirm against the real file.
# axis type per kernel dimension; the lowering rules pass it to shape_to_idx
axis_types: list[AxisType]
# index UOps currently in scope; the lowering rules look them up by dimension position
idxs: list[UOp]
# NOTE(review): presumably the removed side of the diff — get_index constructs
# IndexContext(axis_types, []) with two positional arguments and the field is commented out below
ridxs: list[UOp]
# offset added to the RANGE ids allocated by shape_to_idx; the lowering rules advance it after allocating
start: int = 0
#ridxs: list[UOp]
def shape_to_idx(s, axis_types, start=0, allow_unroll=False):
  """Create one index UOp for every (dimension size, axis type) pair.

  Args:
    s: dimension sizes; paired positionally with ``axis_types`` (the shorter sequence wins).
    axis_types: the AxisType of each dimension.
    start: offset added to the position to form the ``arg`` of each RANGE.
    allow_unroll: when true, UPCAST/UNROLL dimensions become UNROLL UOps instead of RANGEs.

  Returns:
    A list of index UOps, in dimension order.
  """
  def _index_for(pos, dim, axis_type):
    if axis_type in (AxisType.UPCAST, AxisType.UNROLL) and allow_unroll:
      assert isinstance(dim, int), "needs to be int to upcast/unroll"
      # vector constant 0..dim-1, tagged with (position, size); note the tag uses pos, not start+pos
      lanes = UOp.const(dtypes.int.vec(dim), tuple(range(dim)))
      return UOp(Ops.UNROLL, dtypes.int, (lanes,), ((pos, dim),))
    # every other axis type is a RANGE
    return UOp(Ops.RANGE, dtypes.int, (sint_to_uop(dim),), start + pos)

  return [_index_for(pos, dim, axis_type) for pos, (dim, axis_type) in enumerate(zip(s, axis_types))]
def get_index(ast:UOp) -> IndexContext:
axis_types = ast.arg.axis_types if isinstance(ast.arg, KernelInfo) else ()
if len(ast.full_shape) != len(axis_types): axis_types = (AxisType.LOOP,)*len(ast.full_shape)
"""
# indexes
idxs = []
for i, (s, at) in enumerate(zip(ast.full_shape, axis_types)):
@ -31,34 +45,85 @@ def get_index(ast:UOp) -> IndexContext:
for i, (s, at) in enumerate(zip(ast.full_shape, axis_types)):
if at == AxisType.GROUP_REDUCE:
ridxs[i] = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(s),), 1000+i)
"""
return IndexContext(idxs, ridxs)
return IndexContext(axis_types, []) #idxs, ridxs)
# ***** lowering (given index) *****
def lower_reduce_into(ctx: IndexContext, x: UOp):
  """Lower a REDUCE_INTO: combine ``x.src[1]`` into the accumulator view ``x.src[0]`` with ``x.arg``.

  The reduce axes are the dimensions where the accumulator shape and the big operand's shape
  differ. Fresh index UOps are allocated for them (and for the leading accumulator dimensions),
  the combine expression is lowered under a child IndexContext, and the result is stored back
  through the accumulator's buffer.

  NOTE: the new index list is built as ``leading + reduce``, so the indices only line up with
  the dimensions when the reduce axes are the trailing ones.

  Args:
    ctx: current lowering context; ``ctx.start`` is advanced past the ids used by the sub-rewrite.
    x: the REDUCE_INTO UOp, with ``src=(accumulator_view, big_operand)`` and ``arg`` the ALU op.

  Returns:
    A load of the accumulator view that depends on the store.
  """
  acc_view, big = x.src[0], x.src[1]  # x.src[1] is the big (un-reduced) operand
  acc_shape = acc_view.shape
  reduce_axes = {i for i, (s0, s1) in enumerate(zip(acc_shape, big.shape)) if s0 != s1}
  reduce_idxs = shape_to_idx([s for i, s in enumerate(big.shape) if i in reduce_axes],
                             [at for i, at in enumerate(ctx.axis_types) if i in reduce_axes], ctx.start)
  print("reduce into", [r.arg for r in reduce_idxs])

  # Slice with an explicit end (len - n_red) instead of [:-n_red]: when there are no reduce axes,
  # [:-0] is an empty slice, which would drop every leading dimension and break the loop below.
  n_red = len(reduce_idxs)
  lead_shape = acc_shape[:len(acc_shape) - n_red]
  lead_types = ctx.axis_types[:len(ctx.axis_types) - n_red]
  new_idxs = shape_to_idx(lead_shape, lead_types, ctx.start + n_red) + reduce_idxs

  idx, valid = acc_view.arg.to_indexed_uops(new_idxs)
  used_idxs = [u for u in UOp.sink(idx, valid).toposort() if u in new_idxs]

  # Keep a fresh index where the accumulator address uses it, where it is a reduce index, or
  # where the parent context has no index for that dimension; otherwise reuse the parent's.
  real_new_idxs = []
  for i in range(len(acc_shape)):
    keep_new = new_idxs[i] in used_idxs or new_idxs[i] in reduce_idxs or len(ctx.idxs) <= i
    real_new_idxs.append(new_idxs[i] if keep_new else ctx.idxs[i])

  acc = acc_view.load(*reduce_idxs).alu(x.arg, big)
  lc = IndexContext(ctx.axis_types, tuple(real_new_idxs), ctx.start + len(new_idxs))
  from tinygrad.codegen.lowerer import pm_lowerer  # TODO: better way to do this?
  ret = graph_rewrite(acc, pm_lowerer, lc, name="subreduce", bottom_up=True)
  ctx.start = lc.start  # ids consumed by the sub-rewrite must not be reused by later rules
  red = acc_view.src[0].index(idx, valid).store(ret, *used_idxs, *reduce_idxs)
  return acc_view.load(red)
def lower_reduce_axis(ctx: IndexContext, x: UOp):
# Lower a REDUCE_AXIS into a REDUCE over RANGE indices, with UNROLL indices folded in via CONTRACT.
# NOTE(review): this span comes from a rendered diff with +/- markers and indentation stripped, so
# pre- and post-change statements appear interleaved — confirm against the real file.
# allocate fresh index UOps for the reduced axes, numbered from ctx.start
new_idx = shape_to_idx([s for i,s in enumerate(x.src[0].shape) if i in x.axis_arg],
[s for i,s in enumerate(ctx.axis_types) if i in x.axis_arg], ctx.start)
# copy of the context's indices with the reduced axes replaced by the fresh ones
full_new_idx = list(ctx.idxs)
for i,s in zip(x.axis_arg, new_idx):
#assert len(full_new_idx) == i, f"len(full_new_idx) = {len(full_new_idx)}, trying to place {i}"
#full_new_idx.append(s)
full_new_idx[i] = s
# lower the reduce input under a child context that sees the replaced indices
lc = IndexContext(ctx.axis_types, tuple(full_new_idx), ctx.start+len(new_idx))
from tinygrad.codegen.lowerer import pm_lowerer # TODO: better way to do this?
ret = graph_rewrite(x.src[0], pm_lowerer, lc, name="subreduce", bottom_up=True)
# carry the child's id counter back to the parent context
ctx.start = lc.start
# NOTE: always using ridxs is fine here
# NOTE(review): the ctx.ridxs assignment below is presumably the removed side of the diff; the line after it
# recomputes the same pair from full_new_idx
reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
reduce_range, reduce_expand = partition([full_new_idx[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand} for {x.axis_arg}"
# NOTE(review): as written this overwrites the rewritten `ret` from above; likely a removed line — confirm
ret = x.src[0]
# UNROLL indices carry (axis, size) pairs in their arg; gather them all and CONTRACT over them
if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
# REDUCE supports both "horizontal" reduction and range reduction. the horizontal elements are taken in the nearest group
return UOp(Ops.REDUCE, x.dtype, (ret,)+tuple(reduce_range), x.arg[0])
def lower_load(ctx: IndexContext, x: UOp, buf: UOp):
# Lower a view-based LOAD of `buf` into a LOAD of `buf.index(idx, valid)`.
# NOTE(review): this span comes from a rendered diff with +/- markers and indentation stripped. The first
# three statements look like the pre-change body and the rest like the post-change body — confirm
# against the real file.
# here DEFINE_LOCAL buffers are indexed with ctx.ridxs and every other buffer with ctx.idxs
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if buf.op is Ops.DEFINE_LOCAL else ctx.idxs)
# extra sources of a DEFINE_LOCAL load get .barrier() applied; for other buffers they pass through unchanged
barrier = tuple([y.barrier() if buf.op is Ops.DEFINE_LOCAL else y for y in x.src[1:]])
return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + barrier)
# debug output: the args of the index UOps held by the context
print("lower_load", [x.arg for x in ctx.idxs])
# here the view is always indexed with ctx.idxs, with no special case for local buffers
idx, valid = x.st_arg.to_indexed_uops(ctx.idxs)
# extra sources are forwarded unchanged (the barrier wrapping is commented out below)
return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + x.src[1:])
#barrier = tuple([y.barrier() if buf.op is Ops.DEFINE_LOCAL else y for y in x.src[1:]])
def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
# Lower a view-based STORE to `buf` into a STORE through `buf.index(idx, valid)`.
# NOTE(review): this span comes from a rendered diff with +/- markers and indentation stripped. The
# statements up to the first `return` look like the pre-change body and the rest like the post-change
# body — confirm against the real file (including whether the shape assert is kept).
assert x.src[1].shape == x.src[0].shape, f"shape mismatch on store {x.src[1].shape} != {x.src[0].shape}"
idx, valid = x.st_arg.to_indexed_uops(ctx.idxs)
if cast(PtrDType, buf.dtype).addrspace == AddrSpace.GLOBAL:
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
if oidx is not ridx: valid = valid * oidx.eq(0)
# the store ends the RANGEs that appear in its own index/valid expression
return buf.index(idx, valid).store(x.src[1], *[x for x in UOp.sink(idx, valid).toposort() if x.op is Ops.RANGE])
# allocate a fresh index UOp for every dimension of the store target, numbered from ctx.start
new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
idx, valid = x.st_arg.to_indexed_uops(new_idxs)
# the fresh indices that actually occur in the address/valid expression
used_idxs = [x for x in UOp.sink(idx, valid).toposort() if x in new_idxs]
# per dimension: take the fresh index if the address uses it or the parent context has none for that
# position; otherwise keep the parent's index
real_new_idxs = []
for i in range(len(x.src[0].shape)):
if new_idxs[i] in used_idxs or len(ctx.idxs) <= i:
real_new_idxs.append(new_idxs[i])
else:
real_new_idxs.append(ctx.idxs[i])
# debug output: number of dimensions and number of fresh indices in use
print("got", len(real_new_idxs), len(used_idxs))
# lower the stored value under a child context that sees the chosen indices
lc = IndexContext(ctx.axis_types, tuple(real_new_idxs), ctx.start+len(new_idxs))
from tinygrad.codegen.lowerer import pm_lowerer # TODO: better way to do this?
stored = graph_rewrite(x.src[1], pm_lowerer, lc, name="substore", bottom_up=True)
# carry the child's id counter back to the parent context
ctx.start = lc.start
# only the RANGE indices among the used ones are attached to the store
return buf.index(idx, valid).store(stored, *[x for x in used_idxs if x.op is Ops.RANGE])
def lower_const(ctx:IndexContext, view:UOp, c:UOp):
if all(x.mask is None for x in view.arg.views): return c
@ -73,6 +138,7 @@ pm_lowerer = PatternMatcher([
(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c"), UPat(Ops.CONST, arg=0)), lambda c,v: c.replace(src=()).view(v.arg)),
# reduce/view_const
(UPat(Ops.REDUCE_INTO, name="x"), lower_reduce_into),
(UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),), name="view"), lower_const),
# rewrite LOAD/STORE VIEW to LOAD/STORE with indexed

View file

@ -40,7 +40,7 @@ class Ops(FastEnum):
SPECIAL = auto()
# reduce
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() # noqa: E702
REDUCE_AXIS = auto(); REDUCE_INTO = auto(); REDUCE = auto(); ALLREDUCE = auto() # noqa: E702
# optimization helper ops
UNROLL = auto(); CONTRACT = auto(); GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702

View file

@ -138,6 +138,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def st(self) -> ShapeTracker|None:
if self.op in GroupOp.Block or self.op is Ops.INDEX: return None
from tinygrad.shape.shapetracker import ShapeTracker
if self.op is Ops.REDUCE_INTO: return self.src[0].st
# VIEW and MovementOps define a new ShapeTracker from the arg
if self.op is Ops.VIEW: return self.arg
if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)

View file

@ -198,7 +198,7 @@ spec = PatternMatcher([
# NOTE: for testing, we let sinks be anything
#(UPat(Ops.SINK, src=UPat(Ops.STORE)), lambda: True),
(UPat(Ops.SINK, dtypes.void), lambda: True),
(UPat((Ops.NOOP, Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
(UPat((Ops.NOOP, Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST, Ops.REDUCE_INTO)), lambda: True),
# PTX LOAD/STORE
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),

View file

@ -13,6 +13,7 @@ from tinygrad.dtype import dtypes
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_REG: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
Ops.REDUCE_INTO: "#FF6B6B",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",