mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
prefetch
This commit is contained in:
parent
40767e8f92
commit
707d8d9d72
3 changed files with 55 additions and 23 deletions
|
|
@ -89,7 +89,7 @@ class TestRangeify(unittest.TestCase):
|
|||
w2 = Tensor.empty(12, 8, 3, 3)
|
||||
x.conv2d(w1).contiguous().conv2d(w2).realize()
|
||||
|
||||
def test_transformer_ffn(self):
|
||||
def test_ffn(self):
|
||||
from tinygrad.apps.llm import TransformerBlock
|
||||
from tinygrad import nn
|
||||
blk = TransformerBlock(1024, 4096, 1, 1, 1e-5)
|
||||
|
|
|
|||
|
|
@ -422,7 +422,10 @@ def td_elementwise(ctx, e:UOp):
|
|||
out_rng = []
|
||||
need_merge = False
|
||||
for i,r in enumerate(zip(*rngs)):
|
||||
if all_same(r):
|
||||
r = [x for x in r if x is not UOp.const(dtypes.int, 0)]
|
||||
if len(r) == 0:
|
||||
out_rng.append(UOp.const(dtypes.int, 0))
|
||||
elif all_same(r):
|
||||
out_rng.append(r[0])
|
||||
else:
|
||||
out_rng.append(new_range(ctx, shps[0][i]))
|
||||
|
|
@ -431,30 +434,53 @@ def td_elementwise(ctx, e:UOp):
|
|||
new_src = []
|
||||
for u in e.src:
|
||||
assert u.op is Ops.INDEX
|
||||
mm = []
|
||||
for i,idx in enumerate(u.src[1:]):
|
||||
if idx is not out_rng[i]: mm.append(UOp(Ops.MERGE, src=(idx, out_rng[i])))
|
||||
new_src.append(UOp(Ops.MBLOCK, u.dtype, (u.src[0],)+tuple(mm)))
|
||||
out = u.src[0]
|
||||
for i,idx in list(enumerate(u.src[1:]))[::-1]:
|
||||
if idx is not out_rng[i] and idx is not UOp.const(dtypes.int, 0):
|
||||
out = UOp(Ops.MERGE, out.dtype, src=(out, idx, out_rng[i]))
|
||||
new_src.append(out)
|
||||
#mm = []
|
||||
#for i,idx in enumerate(u.src[1:]):
|
||||
# if idx is not out_rng[i] and idx is not UOp.const(dtypes.int, 0):
|
||||
# mm.append(UOp(Ops.MERGE, src=(idx, out_rng[i])))
|
||||
#new_src.append(UOp(Ops.MBLOCK, u.dtype, (u.src[0],)+tuple(mm)))
|
||||
else:
|
||||
new_src = list(e.src)
|
||||
new_src = list([x.src[0] for x in e.src])
|
||||
return e.replace(src=tuple(new_src)).index(*out_rng, arg=shps[0])
|
||||
|
||||
def td_shrink(idx:UOp, r:UOp):
|
||||
ret = []
|
||||
shp = []
|
||||
for u,(s,e),shape in zip(idx.src[1:], r.arg, idx.arg):
|
||||
assert s == 0
|
||||
if u.vmax >= e: u = u.minimum(e)
|
||||
ret.append(u)
|
||||
shp.append(min(shape, e))
|
||||
return idx.src[0].index(*ret, dtype=idx.dtype, arg=tuple(shp))
|
||||
pm_td_rangeify = PatternMatcher([
|
||||
(UPat(Ops.BUFFER, name="b"), lambda ctx, b:
|
||||
b.replace(tag=1).index(nr:=new_range(ctx, b.size), dtype=b.dtype.ptr(size=b.size)).load().index(nr, arg=(b.size,)) if b.tag is None else None),
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),), name="c"), lambda c: c.replace(src=()).index(arg=())),
|
||||
(UPat(Ops.RESHAPE, src=(UPat(Ops.INDEX, name="idx"),), name="r"), td_reshape),
|
||||
(UPat(Ops.SHRINK, src=(UPat(Ops.INDEX, name="idx"),), name="r"), td_shrink),
|
||||
(UPat(Ops.PERMUTE, src=(UPat(Ops.INDEX, name="idx"),), name="r"),
|
||||
lambda r,idx: idx.src[0].index(*[idx.src[1+p] for p in r.arg], dtype=idx.dtype, arg=tuple(idx.arg[p] for p in r.arg))),
|
||||
(UPat(Ops.EXPAND, src=(UPat(Ops.INDEX, name="idx"),), name="r"),
|
||||
lambda r,idx,ctx: idx.src[0].index(*[u if u.vmax+1==s else UOp.const(dtypes.int, 0) for u,s in zip(idx.src[1:], r.arg)],
|
||||
dtype=idx.dtype, arg=r.arg)),
|
||||
# 0s are already in place for EXPAND
|
||||
(UPat(Ops.EXPAND, src=(UPat(Ops.INDEX, name="idx"),), name="r"), lambda r,idx: idx.replace(arg=r.arg)),
|
||||
(UPat(GroupOp.Elementwise, src=UPat(Ops.INDEX), name="e"), td_elementwise),
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.INDEX, name="idx"),), name="r"),
|
||||
lambda r,idx: UOp(Ops.REDUCE, r.dtype, (idx.src[0],)+tuple([x for i,x in enumerate(idx.src[1:]) if i in r.arg[1]]),
|
||||
r.arg[0]).index(*[x for i,x in enumerate(idx.src[1:]) if i not in r.arg[1]],
|
||||
r.arg[0]).index(*[x if i not in r.arg[1] else UOp.const(dtypes.int, 0) for i,x in enumerate(idx.src[1:])],
|
||||
arg=tuple([s if i not in r.arg[1] else 1 for i,s in enumerate(idx.arg)]))),
|
||||
])
|
||||
|
||||
def remove_merge(m):
|
||||
return UOp(Ops.BUFFERIZE, m.dtype, m.src[0:2], arg=m.device).index(m.src[2])
|
||||
|
||||
no_merge = PatternMatcher([
|
||||
(UPat(Ops.MERGE, name="m"), remove_merge),
|
||||
])
|
||||
|
||||
@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}", replay=True)
|
||||
def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
tensor_map = graph_rewrite_map(sink, earliest_rewrites, name="earliest")
|
||||
|
|
@ -462,19 +488,22 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
|||
|
||||
ctx = RContext()
|
||||
rsink = graph_rewrite(rsink, pm_td_rangeify, ctx=ctx, name="td rangeify")
|
||||
|
||||
# find IDIV on RANGE to split
|
||||
reps = {}
|
||||
for u in rsink.toposort():
|
||||
if u.op is Ops.IDIV and u.src[0].op is Ops.RANGE and u.src[1].op is Ops.CONST:
|
||||
r = u.src[0].vmax+1
|
||||
c = u.src[1].arg
|
||||
if r%c == 0:
|
||||
reps[u.src[0]] = new_range(ctx, r//c)*c + new_range(ctx, c)
|
||||
print(len(reps))
|
||||
rsink = rsink.substitute(reps)
|
||||
rsink = graph_rewrite(rsink, sym, name="symbolic")
|
||||
|
||||
# find MOD on RANGE to split
|
||||
while 1:
|
||||
reps = {}
|
||||
for u in rsink.toposort():
|
||||
if u.op is Ops.MOD and u.src[0].op is Ops.RANGE and u.src[1].op is Ops.CONST:
|
||||
r = u.src[0].vmax+1
|
||||
c = u.src[1].arg
|
||||
if r%c == 0:
|
||||
reps[u.src[0]] = new_range(ctx, r//c)*c + new_range(ctx, c)
|
||||
print(len(reps))
|
||||
if len(reps) == 0: break
|
||||
rsink = rsink.substitute(reps)
|
||||
rsink = graph_rewrite(rsink, sym, name="symbolic")
|
||||
|
||||
for i in range(0):
|
||||
print("loop")
|
||||
real_rngs = rsink.ranges.copy()
|
||||
|
|
@ -499,6 +528,9 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
|||
rew[u] = k
|
||||
rsink = rsink.substitute(rew)
|
||||
|
||||
rsink = graph_rewrite(rsink, no_merge, name="remove merge")
|
||||
rsink = graph_rewrite(rsink, pm_add_buffers, bottom_up=True, name="add buffers")
|
||||
|
||||
"""
|
||||
realize_map = {}
|
||||
graph_rewrite(tensor_map[sink], do_realize, ctx=realize_map, name="Input Graph")
|
||||
|
|
|
|||
|
|
@ -192,7 +192,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
@functools.cached_property
|
||||
def ranges(self) -> dict[UOp, None]:
|
||||
if self.op is Ops.RANGE: return {self:None}
|
||||
if self.op is Ops.MBLOCK: return {x.src[1]:None for x in self.src[1:]}
|
||||
#if self.op is Ops.MBLOCK: return {x.src[1]:None for x in self.src[1:]}
|
||||
if self.op in {Ops.BUFFERIZE, Ops.REDUCE}:
|
||||
ret = self.src[0].ranges.copy()
|
||||
for s in self.src[1:]:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue