mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
4 commits
master
...
lowerer_ha
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5e433685b1 | ||
|
|
96c019a59c | ||
|
|
938fe98daf | ||
|
|
3ec3c27b83 |
7 changed files with 110 additions and 23 deletions
|
|
@ -29,6 +29,8 @@ def hl_spec_kernel3():
|
|||
Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1).view(ShapeTracker.from_shape((BK*BN,)))
|
||||
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0).view(ShapeTracker.from_shape((nbIterWaveM * TM,)))
|
||||
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), arg=1).view(ShapeTracker.from_shape((nbIterWaveN * TN,)))
|
||||
acc = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM * nbIterWaveN * TN, AddrSpace.REG), arg=2) \
|
||||
.view(ShapeTracker.from_shape((nbIterWaveM * TM * nbIterWaveN * TN,)))
|
||||
|
||||
# shape buffers. TODO: permutes
|
||||
full_shape = (N//BM, nbIterWaveM, BM//(nbIterWaveM * TM), TM, N//BN, nbIterWaveN, BN//(nbIterWaveN * TN), TN, N//BK, BK)
|
||||
|
|
@ -40,9 +42,14 @@ def hl_spec_kernel3():
|
|||
A_col = A_col.reshape((1, nbIterWaveM, 1, TM, 1, 1, 1, 1, 1, 1)).expand(full_shape)
|
||||
B_row = B_row.reshape((1, 1, 1, 1, 1, nbIterWaveN, 1, TN, 1, 1)).expand(full_shape)
|
||||
|
||||
#out = (a.load() * b.load()).r(Ops.ADD, (8, 9))
|
||||
out = (As.load(As.store(a.load())) * Bs.load(Bs.store(b.load()))).r(Ops.ADD, (8, 9))
|
||||
#out = (A_col.load(A_col.store(As.load(As.store(a.load())))) * B_row.load(B_row.store(Bs.load(Bs.store(b.load()))))).r(Ops.ADD, (8, 9))
|
||||
acc = acc.reshape((1, nbIterWaveM, 1, TM, 1, nbIterWaveN, 1, TN, 1, 1)).expand(full_shape[:-2]+(1,1))
|
||||
|
||||
#out = a.load() * b.load()
|
||||
#out = As.load(As.store(a.load())) * Bs.load(Bs.store(b.load()))
|
||||
out = A_col.load(A_col.store(As.load(As.store(a.load())))) * B_row.load(B_row.store(Bs.load(Bs.store(b.load()))))
|
||||
|
||||
#out = out.r(Ops.ADD, (8, 9))
|
||||
out = UOp(Ops.REDUCE_INTO, out.dtype, (acc, out), Ops.ADD)
|
||||
|
||||
axis_types = (
|
||||
AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST,
|
||||
|
|
@ -51,6 +58,7 @@ def hl_spec_kernel3():
|
|||
|
||||
from tinygrad.opt.kernel import axis_colors
|
||||
shape = '_'.join([colored(str(s), axis_colors[at]) for s,at in zip(full_shape, axis_types)])
|
||||
print(shape)
|
||||
sink = c.store(out).sink(arg=KernelInfo(name="tg_"+shape, axis_types=axis_types))
|
||||
sink = graph_rewrite(sink, merge_views)
|
||||
return sink
|
||||
|
|
@ -144,8 +152,8 @@ def hand_spec_kernel3():
|
|||
|
||||
# do the GEMM math
|
||||
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 8)
|
||||
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 9)
|
||||
yt = UOp.range(dtypes.int, TM, 10)
|
||||
yt = UOp.range(dtypes.int, TM, 9)
|
||||
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 10)
|
||||
xt = UOp.range(dtypes.int, TN, 11)
|
||||
x = iterWaveN * TN + xt
|
||||
y = iterWaveM * TM + yt
|
||||
|
|
@ -155,8 +163,8 @@ def hand_spec_kernel3():
|
|||
|
||||
# store c_regs into c
|
||||
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 12)
|
||||
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 13)
|
||||
yt = UOp.range(dtypes.int, TM, 14)
|
||||
yt = UOp.range(dtypes.int, TM, 13)
|
||||
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 14)
|
||||
xt = UOp.range(dtypes.int, TN, 15)
|
||||
xOut = blockIdx_x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave
|
||||
yOut = blockIdx_y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave
|
||||
|
|
@ -170,6 +178,7 @@ if __name__ == "__main__":
|
|||
hprg = hl_spec_kernel3() if getenv("HL") else hand_spec_kernel3()
|
||||
prg = get_program(hprg, Device.default.renderer)
|
||||
print(prg.src)
|
||||
exit(0)
|
||||
hrunner = CompiledRunner(prg)
|
||||
|
||||
a = Tensor.randn(N, N).realize()
|
||||
|
|
|
|||
|
|
@ -28,5 +28,15 @@ class TestDefineReg(unittest.TestCase):
|
|||
@unittest.skipIf(getenv("PTX"), "ptx needs regs to be unrolled")
|
||||
def test_simple_loop(self): self.test_simple(AxisType.LOOP)
|
||||
|
||||
def test_reduce_into(self):
|
||||
N = 16
|
||||
b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N*N), arg=0).view(ShapeTracker.from_shape((N,N,1)))
|
||||
a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N*N), arg=1).view(ShapeTracker.from_shape((N,N,N)))
|
||||
a_reg = UOp(Ops.DEFINE_REG, dtypes.float.ptr(N, AddrSpace.REG), arg=0).view(ShapeTracker.from_shape((N,N,1), (0,1,0)))
|
||||
ret = UOp(Ops.REDUCE_INTO, dtypes.float, src=(a_reg, a.load()), arg=Ops.ADD)
|
||||
sink = b.store(ret).sink(arg=KernelInfo(name="regcopy"))
|
||||
prg = get_program(sink, Device.default.renderer)
|
||||
print(prg.src)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -2,20 +2,34 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
from tinygrad.dtype import dtypes, PtrDType, AddrSpace
|
||||
from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType
|
||||
from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite
|
||||
from tinygrad.helpers import prod, partition, flatten
|
||||
|
||||
# ***** indexing *****
|
||||
|
||||
@dataclass
|
||||
class IndexContext:
|
||||
axis_types: list[AxisType]
|
||||
idxs: list[UOp]
|
||||
ridxs: list[UOp]
|
||||
start: int = 0
|
||||
#ridxs: list[UOp]
|
||||
|
||||
def shape_to_idx(s, axis_types, start=0, allow_unroll=False):
|
||||
idxs = []
|
||||
for i, (s, at) in enumerate(zip(s, axis_types)):
|
||||
if at in (AxisType.UPCAST, AxisType.UNROLL) and allow_unroll:
|
||||
assert isinstance(s, int), "needs to be int to upcast/unroll"
|
||||
idxs.append(UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(s), tuple(range(s))),), ((i,s),)))
|
||||
else:
|
||||
# all others are RANGES
|
||||
idxs.append(UOp(Ops.RANGE, dtypes.int, (sint_to_uop(s),), start+i))
|
||||
return idxs
|
||||
|
||||
def get_index(ast:UOp) -> IndexContext:
|
||||
axis_types = ast.arg.axis_types if isinstance(ast.arg, KernelInfo) else ()
|
||||
if len(ast.full_shape) != len(axis_types): axis_types = (AxisType.LOOP,)*len(ast.full_shape)
|
||||
|
||||
"""
|
||||
# indexes
|
||||
idxs = []
|
||||
for i, (s, at) in enumerate(zip(ast.full_shape, axis_types)):
|
||||
|
|
@ -31,34 +45,85 @@ def get_index(ast:UOp) -> IndexContext:
|
|||
for i, (s, at) in enumerate(zip(ast.full_shape, axis_types)):
|
||||
if at == AxisType.GROUP_REDUCE:
|
||||
ridxs[i] = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(s),), 1000+i)
|
||||
"""
|
||||
|
||||
return IndexContext(idxs, ridxs)
|
||||
return IndexContext(axis_types, []) #idxs, ridxs)
|
||||
|
||||
# ***** lowering (given index) *****
|
||||
|
||||
def lower_reduce_into(ctx: IndexContext, x: UOp):
|
||||
# x.src[1] is the big
|
||||
axis_arg = [i for i,(s0,s1) in enumerate(zip(x.src[0].shape, x.src[1].shape)) if s0 != s1]
|
||||
reduce_idxs = shape_to_idx([s for i,s in enumerate(x.src[1].shape) if i in axis_arg],
|
||||
[s for i,s in enumerate(ctx.axis_types) if i in axis_arg], ctx.start)
|
||||
print("reduce into", [x.arg for x in reduce_idxs])
|
||||
new_idxs = shape_to_idx(x.src[0].shape[:-len(reduce_idxs)], ctx.axis_types[:-len(reduce_idxs)], ctx.start+len(reduce_idxs)) + reduce_idxs
|
||||
|
||||
idx, valid = x.src[0].arg.to_indexed_uops(new_idxs)
|
||||
used_idxs = [x for x in UOp.sink(idx, valid).toposort() if x in new_idxs]
|
||||
|
||||
real_new_idxs = []
|
||||
for i in range(len(x.src[0].shape)):
|
||||
if new_idxs[i] in used_idxs or new_idxs[i] in reduce_idxs or len(ctx.idxs) <= i: real_new_idxs.append(new_idxs[i])
|
||||
else: real_new_idxs.append(ctx.idxs[i])
|
||||
non_replaced = [x for x in real_new_idxs if x not in ctx.idxs]
|
||||
|
||||
acc = x.src[0].load(*reduce_idxs).alu(x.arg, x.src[1])
|
||||
|
||||
lc = IndexContext(ctx.axis_types, tuple(real_new_idxs), ctx.start+len(new_idxs))
|
||||
from tinygrad.codegen.lowerer import pm_lowerer # TODO: better way to do this?
|
||||
ret = graph_rewrite(acc, pm_lowerer, lc, name="subreduce", bottom_up=True)
|
||||
ctx.start = lc.start
|
||||
|
||||
red = x.src[0].src[0].index(idx, valid).store(ret, *used_idxs, *reduce_idxs)
|
||||
return x.src[0].load(red)
|
||||
|
||||
def lower_reduce_axis(ctx: IndexContext, x: UOp):
|
||||
new_idx = shape_to_idx([s for i,s in enumerate(x.src[0].shape) if i in x.axis_arg],
|
||||
[s for i,s in enumerate(ctx.axis_types) if i in x.axis_arg], ctx.start)
|
||||
full_new_idx = list(ctx.idxs)
|
||||
for i,s in zip(x.axis_arg, new_idx):
|
||||
#assert len(full_new_idx) == i, f"len(full_new_idx) = {len(full_new_idx)}, trying to place {i}"
|
||||
#full_new_idx.append(s)
|
||||
full_new_idx[i] = s
|
||||
|
||||
lc = IndexContext(ctx.axis_types, tuple(full_new_idx), ctx.start+len(new_idx))
|
||||
from tinygrad.codegen.lowerer import pm_lowerer # TODO: better way to do this?
|
||||
ret = graph_rewrite(x.src[0], pm_lowerer, lc, name="subreduce", bottom_up=True)
|
||||
ctx.start = lc.start
|
||||
|
||||
# NOTE: always using ridxs is fine here
|
||||
reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
|
||||
reduce_range, reduce_expand = partition([full_new_idx[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
|
||||
assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand} for {x.axis_arg}"
|
||||
ret = x.src[0]
|
||||
if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
|
||||
ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
|
||||
# REDUCE supports both "horizontal" reduction and range reduction. the horizontal elements are taken in the nearest group
|
||||
return UOp(Ops.REDUCE, x.dtype, (ret,)+tuple(reduce_range), x.arg[0])
|
||||
|
||||
def lower_load(ctx: IndexContext, x: UOp, buf: UOp):
|
||||
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if buf.op is Ops.DEFINE_LOCAL else ctx.idxs)
|
||||
barrier = tuple([y.barrier() if buf.op is Ops.DEFINE_LOCAL else y for y in x.src[1:]])
|
||||
return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + barrier)
|
||||
print("lower_load", [x.arg for x in ctx.idxs])
|
||||
idx, valid = x.st_arg.to_indexed_uops(ctx.idxs)
|
||||
return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + x.src[1:])
|
||||
|
||||
#barrier = tuple([y.barrier() if buf.op is Ops.DEFINE_LOCAL else y for y in x.src[1:]])
|
||||
|
||||
def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
|
||||
assert x.src[1].shape == x.src[0].shape, f"shape mismatch on store {x.src[1].shape} != {x.src[0].shape}"
|
||||
idx, valid = x.st_arg.to_indexed_uops(ctx.idxs)
|
||||
if cast(PtrDType, buf.dtype).addrspace == AddrSpace.GLOBAL:
|
||||
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
|
||||
for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
|
||||
if oidx is not ridx: valid = valid * oidx.eq(0)
|
||||
return buf.index(idx, valid).store(x.src[1], *[x for x in UOp.sink(idx, valid).toposort() if x.op is Ops.RANGE])
|
||||
new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
|
||||
idx, valid = x.st_arg.to_indexed_uops(new_idxs)
|
||||
used_idxs = [x for x in UOp.sink(idx, valid).toposort() if x in new_idxs]
|
||||
real_new_idxs = []
|
||||
for i in range(len(x.src[0].shape)):
|
||||
if new_idxs[i] in used_idxs or len(ctx.idxs) <= i:
|
||||
real_new_idxs.append(new_idxs[i])
|
||||
else:
|
||||
real_new_idxs.append(ctx.idxs[i])
|
||||
print("got", len(real_new_idxs), len(used_idxs))
|
||||
lc = IndexContext(ctx.axis_types, tuple(real_new_idxs), ctx.start+len(new_idxs))
|
||||
from tinygrad.codegen.lowerer import pm_lowerer # TODO: better way to do this?
|
||||
stored = graph_rewrite(x.src[1], pm_lowerer, lc, name="substore", bottom_up=True)
|
||||
ctx.start = lc.start
|
||||
return buf.index(idx, valid).store(stored, *[x for x in used_idxs if x.op is Ops.RANGE])
|
||||
|
||||
def lower_const(ctx:IndexContext, view:UOp, c:UOp):
|
||||
if all(x.mask is None for x in view.arg.views): return c
|
||||
|
|
@ -73,6 +138,7 @@ pm_lowerer = PatternMatcher([
|
|||
(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c"), UPat(Ops.CONST, arg=0)), lambda c,v: c.replace(src=()).view(v.arg)),
|
||||
|
||||
# reduce/view_const
|
||||
(UPat(Ops.REDUCE_INTO, name="x"), lower_reduce_into),
|
||||
(UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
|
||||
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),), name="view"), lower_const),
|
||||
# rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class Ops(FastEnum):
|
|||
SPECIAL = auto()
|
||||
|
||||
# reduce
|
||||
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() # noqa: E702
|
||||
REDUCE_AXIS = auto(); REDUCE_INTO = auto(); REDUCE = auto(); ALLREDUCE = auto() # noqa: E702
|
||||
|
||||
# optimization helper ops
|
||||
UNROLL = auto(); CONTRACT = auto(); GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702
|
||||
|
|
|
|||
|
|
@ -138,6 +138,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
def st(self) -> ShapeTracker|None:
|
||||
if self.op in GroupOp.Block or self.op is Ops.INDEX: return None
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
if self.op is Ops.REDUCE_INTO: return self.src[0].st
|
||||
# VIEW and MovementOps define a new ShapeTracker from the arg
|
||||
if self.op is Ops.VIEW: return self.arg
|
||||
if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
|
||||
|
|
|
|||
|
|
@ -198,7 +198,7 @@ spec = PatternMatcher([
|
|||
# NOTE: for testing, we let sinks be anything
|
||||
#(UPat(Ops.SINK, src=UPat(Ops.STORE)), lambda: True),
|
||||
(UPat(Ops.SINK, dtypes.void), lambda: True),
|
||||
(UPat((Ops.NOOP, Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
|
||||
(UPat((Ops.NOOP, Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST, Ops.REDUCE_INTO)), lambda: True),
|
||||
|
||||
# PTX LOAD/STORE
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from tinygrad.dtype import dtypes
|
|||
|
||||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
||||
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_REG: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
|
||||
Ops.REDUCE_INTO: "#FF6B6B",
|
||||
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue