This commit is contained in:
George Hotz 2025-08-15 10:11:28 -07:00
commit 155e97045d

View file

@ -202,6 +202,45 @@ pm_rangeify = pm_mops+PatternMatcher([
lambda c,a: UOp(Ops.STORE, src=a.src+c.src[1:]) if c.tag == 1 else None),
])
# 4. remove bufferize
def add_store(x:UOp):
  # Rewrite body for a BUFFERIZE node: create a new buffer and store x.src[0] into it.
  # x.src[0] is the value being stored; every remaining source is used as an index/range.
  ranges = x.src[1:]
  # one axis per range, each of extent vmax+1
  shape = tuple(rng.vmax+1 for rng in ranges)
  numel = prod(shape)
  assert numel > 0, f"no zero sized buffers {shape}"
  new_buf = UOp.new_buffer(x.device, numel, x.dtype)
  shaped = new_buf.reshape(shape)
  # index the reshaped buffer with the same ranges; the INDEX gets a pointer dtype sized to the whole buffer
  indexed = shaped.index(*ranges, dtype=x.dtype.ptr(size=numel))
  return indexed.store(x.src[0], *ranges)
def add_load(x:UOp, b:UOp, idx:UOp):
  # Rewrite body for an INDEX (x) whose first source is the BUFFER b.
  # idx is bound by the matching pattern but is not needed here.
  if not isinstance(x.dtype, PtrDType):
    # give the INDEX a pointer dtype sized to the buffer, then load through it
    pointer_dtype = x.dtype.ptr(b.size)
    return x.replace(dtype=pointer_dtype).load()
  # the INDEX already carries a pointer dtype: return None
  return None
def add_load_on_store(x:UOp, st:UOp):
  # Rewrite body for an INDEX (x) whose first source is the STORE st:
  # index the buffer that st writes to with x's remaining sources, and load from it, passing st to load().
  ranges = x.src[1:]
  shape = tuple(rng.vmax+1 for rng in ranges)
  # the first source of the store's first source must be the BUFFER itself
  target = st.src[0].src[0]
  assert target.op is Ops.BUFFER
  # view the first prod(shape) elements of the buffer with the shape implied by the ranges
  view = target.shrink(((0, prod(shape)),)).reshape(shape)
  indexed = view.index(*ranges, dtype=x.dtype.ptr(size=target.size))
  return indexed.load(st)
# Rewrite rules that replace BUFFERIZE nodes with explicit stores/loads on buffers.
# The pm_mops rules (defined elsewhere) are combined with the three rules listed here.
pm_add_buffers = pm_mops+PatternMatcher([
  # any BUFFERIZE node -> add_store
  (UPat(Ops.BUFFERIZE, name="x"), add_store),
  # INDEX with exactly two listed sources, a BUFFER (b) and one more node (idx) -> add_load
  (UPat(Ops.INDEX, src=(UPat(Ops.BUFFER, name="b"), UPat(name="idx")), name="x"), add_load),
  # INDEX whose first source is a STORE (st); allow_any_len=True is set so more sources may follow -> add_load_on_store
  (UPat(Ops.INDEX, src=(UPat(Ops.STORE, name="st"),), allow_any_len=True, name="x"), add_load_on_store),
])
# 5. create pointers
def debuf(ctx, b:UOp):
  # Rewrite body for a BUFFER node: build a DEFINE_GLOBAL whose dtype is a pointer of size b.arg.
  # ctx[0] is a running counter used as the DEFINE_GLOBAL arg; it is advanced after each use.
  slot = ctx[0]
  global_def = UOp(Ops.DEFINE_GLOBAL, b.dtype.ptr(b.arg), arg=slot)
  ctx[0] = slot + 1
  return global_def
# Rewrite rules for the "debuf" pass; get_kernelize_map runs them with ctx=[0] as the counter used by debuf.
pm_debuf = PatternMatcher([
  # every BUFFER node is handed to debuf, which returns a DEFINE_GLOBAL
  (UPat(Ops.BUFFER, name="b"), debuf),
  # HACK: consts shouldn't have srcs by here
  # a CONST that still has sources is replaced by a copy with src=(); one without sources yields None
  (UPat(Ops.CONST, name="x"), lambda x: x.replace(src=()) if len(x.src) else None),
])
@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}", replay=True)
def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
tensor_map = {sink:sink}
@ -212,8 +251,13 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
tensor_map = graph_rewrite_map(tensor_map[sink], pm_rangeify, ctx=RangeifyContext(), bottom_up=True, input_map=tensor_map, name="rangeify")
if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Rangeify Graph")
rsink = tensor_map[sink]
from tinygrad.codegen.devectorizer import pm_reduce, ReduceContext
rsink = graph_rewrite(rsink, pm_reduce, ctx=ReduceContext(), name="remove reduce")
rsink = graph_rewrite(rsink, pm_add_buffers, name="buffer", bottom_up=True)
rsink = graph_rewrite(rsink, pm_debuf, ctx=[0], name="debuf", bottom_up=True)
from tinygrad.codegen import rewrites_for_linearizer, apply_rewrites
rsink = apply_rewrites(tensor_map[sink], rewrites_for_linearizer)
rsink = apply_rewrites(rsink, rewrites_for_linearizer)
from tinygrad.renderer.cstyle import CStyleLanguage
src = CStyleLanguage().render(rsink.arg.lst)
print(src)