This commit is contained in:
George Hotz 2025-08-16 14:00:14 -07:00
commit 707d8d9d72
3 changed files with 55 additions and 23 deletions

View file

@ -89,7 +89,7 @@ class TestRangeify(unittest.TestCase):
w2 = Tensor.empty(12, 8, 3, 3)
x.conv2d(w1).contiguous().conv2d(w2).realize()
def test_transformer_ffn(self):
def test_ffn(self):
from tinygrad.apps.llm import TransformerBlock
from tinygrad import nn
blk = TransformerBlock(1024, 4096, 1, 1, 1e-5)

View file

@ -422,7 +422,10 @@ def td_elementwise(ctx, e:UOp):
out_rng = []
need_merge = False
for i,r in enumerate(zip(*rngs)):
if all_same(r):
r = [x for x in r if x is not UOp.const(dtypes.int, 0)]
if len(r) == 0:
out_rng.append(UOp.const(dtypes.int, 0))
elif all_same(r):
out_rng.append(r[0])
else:
out_rng.append(new_range(ctx, shps[0][i]))
@ -431,30 +434,53 @@ def td_elementwise(ctx, e:UOp):
new_src = []
for u in e.src:
assert u.op is Ops.INDEX
mm = []
for i,idx in enumerate(u.src[1:]):
if idx is not out_rng[i]: mm.append(UOp(Ops.MERGE, src=(idx, out_rng[i])))
new_src.append(UOp(Ops.MBLOCK, u.dtype, (u.src[0],)+tuple(mm)))
out = u.src[0]
for i,idx in list(enumerate(u.src[1:]))[::-1]:
if idx is not out_rng[i] and idx is not UOp.const(dtypes.int, 0):
out = UOp(Ops.MERGE, out.dtype, src=(out, idx, out_rng[i]))
new_src.append(out)
#mm = []
#for i,idx in enumerate(u.src[1:]):
# if idx is not out_rng[i] and idx is not UOp.const(dtypes.int, 0):
# mm.append(UOp(Ops.MERGE, src=(idx, out_rng[i])))
#new_src.append(UOp(Ops.MBLOCK, u.dtype, (u.src[0],)+tuple(mm)))
else:
new_src = list(e.src)
new_src = list([x.src[0] for x in e.src])
return e.replace(src=tuple(new_src)).index(*out_rng, arg=shps[0])
def td_shrink(idx:UOp, r:UOp):
  # Fold the SHRINK `r` into the INDEX `idx` it sits on: the result is a new index on idx.src[0].
  # Each index source of `idx` is walked together with its (start, end) pair from r.arg and its size from idx.arg.
  # NOTE(review): an index whose vmax reaches `end` is clamped to `end` itself, not `end-1` -- confirm that is intended.
  clamped, new_shape = [], []
  for rng, (start, end), dim in zip(idx.src[1:], r.arg, idx.arg):
    # only shrinks that keep the start of the axis are handled
    assert start == 0
    clamped.append(rng.minimum(end) if rng.vmax >= end else rng)
    new_shape.append(min(dim, end))
  return idx.src[0].index(*clamped, dtype=idx.dtype, arg=tuple(new_shape))
# Rewrite rules used by the "td rangeify" graph_rewrite pass: each entry pairs a UPat with the function applied on a match.
# NOTE(review): this listing comes from a diff view, so removed and added lines may be shown together
# (two EXPAND entries, two `.index(` continuation lines under REDUCE_AXIS) -- confirm against the real file.
pm_td_rangeify = PatternMatcher([
# BUFFER with no tag yet: tag it, index it with a new range over b.size, load, then index the load with the same range
(UPat(Ops.BUFFER, name="b"), lambda ctx, b:
b.replace(tag=1).index(nr:=new_range(ctx, b.size), dtype=b.dtype.ptr(size=b.size)).load().index(nr, arg=(b.size,)) if b.tag is None else None),
# CONST whose only source is a DEVICE: drop the sources and wrap it in an index with no ranges and an empty arg
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),), name="c"), lambda c: c.replace(src=()).index(arg=())),
# RESHAPE / SHRINK applied to an INDEX are handled by td_reshape / td_shrink
(UPat(Ops.RESHAPE, src=(UPat(Ops.INDEX, name="idx"),), name="r"), td_reshape),
(UPat(Ops.SHRINK, src=(UPat(Ops.INDEX, name="idx"),), name="r"), td_shrink),
# PERMUTE applied to an INDEX: pick the index sources and the arg entries in the order given by r.arg
(UPat(Ops.PERMUTE, src=(UPat(Ops.INDEX, name="idx"),), name="r"),
lambda r,idx: idx.src[0].index(*[idx.src[1+p] for p in r.arg], dtype=idx.dtype, arg=tuple(idx.arg[p] for p in r.arg))),
# EXPAND applied to an INDEX: keep an index source when its vmax+1 equals the expanded size, otherwise use a constant 0
(UPat(Ops.EXPAND, src=(UPat(Ops.INDEX, name="idx"),), name="r"),
lambda r,idx,ctx: idx.src[0].index(*[u if u.vmax+1==s else UOp.const(dtypes.int, 0) for u,s in zip(idx.src[1:], r.arg)],
dtype=idx.dtype, arg=r.arg)),
# 0s are already in place for EXPAND
(UPat(Ops.EXPAND, src=(UPat(Ops.INDEX, name="idx"),), name="r"), lambda r,idx: idx.replace(arg=r.arg)),
# any elementwise op whose sources are all INDEX is handled by td_elementwise
(UPat(GroupOp.Elementwise, src=UPat(Ops.INDEX), name="e"), td_elementwise),
# REDUCE_AXIS applied to an INDEX: build a REDUCE over the index sources whose position is in r.arg[1];
# the arg of the resulting index is 1 on those positions and the original idx.arg entry elsewhere
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.INDEX, name="idx"),), name="r"),
lambda r,idx: UOp(Ops.REDUCE, r.dtype, (idx.src[0],)+tuple([x for i,x in enumerate(idx.src[1:]) if i in r.arg[1]]),
r.arg[0]).index(*[x for i,x in enumerate(idx.src[1:]) if i not in r.arg[1]],
r.arg[0]).index(*[x if i not in r.arg[1] else UOp.const(dtypes.int, 0) for i,x in enumerate(idx.src[1:])],
arg=tuple([s if i not in r.arg[1] else 1 for i,s in enumerate(idx.arg)]))),
])
def remove_merge(m):
  # Lower a MERGE node: BUFFERIZE its first two sources (keeping m's dtype and device), then index that by the third source.
  buffered = UOp(Ops.BUFFERIZE, m.dtype, m.src[:2], arg=m.device)
  return buffered.index(m.src[2])
# Single-rule matcher: every MERGE node is passed to remove_merge (applied by the "remove merge" graph_rewrite pass).
no_merge = PatternMatcher([
(UPat(Ops.MERGE, name="m"), remove_merge),
])
@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}", replay=True)
def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
tensor_map = graph_rewrite_map(sink, earliest_rewrites, name="earliest")
@ -462,19 +488,22 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
ctx = RContext()
rsink = graph_rewrite(rsink, pm_td_rangeify, ctx=ctx, name="td rangeify")
# find IDIV on RANGE to split
reps = {}
for u in rsink.toposort():
if u.op is Ops.IDIV and u.src[0].op is Ops.RANGE and u.src[1].op is Ops.CONST:
r = u.src[0].vmax+1
c = u.src[1].arg
if r%c == 0:
reps[u.src[0]] = new_range(ctx, r//c)*c + new_range(ctx, c)
print(len(reps))
rsink = rsink.substitute(reps)
rsink = graph_rewrite(rsink, sym, name="symbolic")
# find MOD on RANGE to split
while 1:
reps = {}
for u in rsink.toposort():
if u.op is Ops.MOD and u.src[0].op is Ops.RANGE and u.src[1].op is Ops.CONST:
r = u.src[0].vmax+1
c = u.src[1].arg
if r%c == 0:
reps[u.src[0]] = new_range(ctx, r//c)*c + new_range(ctx, c)
print(len(reps))
if len(reps) == 0: break
rsink = rsink.substitute(reps)
rsink = graph_rewrite(rsink, sym, name="symbolic")
for i in range(0):
print("loop")
real_rngs = rsink.ranges.copy()
@ -499,6 +528,9 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
rew[u] = k
rsink = rsink.substitute(rew)
rsink = graph_rewrite(rsink, no_merge, name="remove merge")
rsink = graph_rewrite(rsink, pm_add_buffers, bottom_up=True, name="add buffers")
"""
realize_map = {}
graph_rewrite(tensor_map[sink], do_realize, ctx=realize_map, name="Input Graph")

View file

@ -192,7 +192,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@functools.cached_property
def ranges(self) -> dict[UOp, None]:
if self.op is Ops.RANGE: return {self:None}
if self.op is Ops.MBLOCK: return {x.src[1]:None for x in self.src[1:]}
#if self.op is Ops.MBLOCK: return {x.src[1]:None for x in self.src[1:]}
if self.op in {Ops.BUFFERIZE, Ops.REDUCE}:
ret = self.src[0].ranges.copy()
for s in self.src[1:]: