mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
ish
This commit is contained in:
parent
05f04bbcc3
commit
155e97045d
1 changed files with 45 additions and 1 deletions
|
|
@ -202,6 +202,45 @@ pm_rangeify = pm_mops+PatternMatcher([
|
|||
lambda c,a: UOp(Ops.STORE, src=a.src+c.src[1:]) if c.tag == 1 else None),
|
||||
])
|
||||
|
||||
# 4. remove bufferize
|
||||
|
||||
def add_store(x:UOp):
|
||||
rngs = x.src[1:]
|
||||
shape = tuple([r.vmax+1 for r in rngs])
|
||||
assert prod(shape) > 0, f"no zero sized buffers {shape}"
|
||||
buf = UOp.new_buffer(x.device, prod(shape), x.dtype)
|
||||
return buf.reshape(shape).index(*rngs, dtype=x.dtype.ptr(size=prod(shape))).store(x.src[0], *rngs)
|
||||
|
||||
def add_load(x:UOp, b:UOp, idx:UOp):
|
||||
if isinstance(x.dtype, PtrDType): return None
|
||||
return x.replace(dtype=x.dtype.ptr(b.size)).load()
|
||||
|
||||
def add_load_on_store(x:UOp, st:UOp):
|
||||
rngs = x.src[1:]
|
||||
shape = tuple([r.vmax+1 for r in rngs])
|
||||
b = st.src[0].src[0]
|
||||
assert b.op is Ops.BUFFER
|
||||
return b.shrink(((0,prod(shape)),)).reshape(shape).index(*rngs, dtype=x.dtype.ptr(size=b.size)).load(st)
|
||||
|
||||
pm_add_buffers = pm_mops+PatternMatcher([
|
||||
(UPat(Ops.BUFFERIZE, name="x"), add_store),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.BUFFER, name="b"), UPat(name="idx")), name="x"), add_load),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.STORE, name="st"),), allow_any_len=True, name="x"), add_load_on_store),
|
||||
])
|
||||
|
||||
# 5. create pointers
|
||||
|
||||
def debuf(ctx, b:UOp):
|
||||
ret = UOp(Ops.DEFINE_GLOBAL, b.dtype.ptr(b.arg), arg=ctx[0])
|
||||
ctx[0] += 1
|
||||
return ret
|
||||
|
||||
pm_debuf = PatternMatcher([
|
||||
(UPat(Ops.BUFFER, name="b"), debuf),
|
||||
# HACK: consts shouldn't have srcs by here
|
||||
(UPat(Ops.CONST, name="x"), lambda x: x.replace(src=()) if len(x.src) else None),
|
||||
])
|
||||
|
||||
@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}", replay=True)
|
||||
def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
tensor_map = {sink:sink}
|
||||
|
|
@ -212,8 +251,13 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
|||
tensor_map = graph_rewrite_map(tensor_map[sink], pm_rangeify, ctx=RangeifyContext(), bottom_up=True, input_map=tensor_map, name="rangeify")
|
||||
if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Rangeify Graph")
|
||||
|
||||
rsink = tensor_map[sink]
|
||||
from tinygrad.codegen.devectorizer import pm_reduce, ReduceContext
|
||||
rsink = graph_rewrite(rsink, pm_reduce, ctx=ReduceContext(), name="remove reduce")
|
||||
rsink = graph_rewrite(rsink, pm_add_buffers, name="buffer", bottom_up=True)
|
||||
rsink = graph_rewrite(rsink, pm_debuf, ctx=[0], name="debuf", bottom_up=True)
|
||||
from tinygrad.codegen import rewrites_for_linearizer, apply_rewrites
|
||||
rsink = apply_rewrites(tensor_map[sink], rewrites_for_linearizer)
|
||||
rsink = apply_rewrites(rsink, rewrites_for_linearizer)
|
||||
from tinygrad.renderer.cstyle import CStyleLanguage
|
||||
src = CStyleLanguage().render(rsink.arg.lst)
|
||||
print(src)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue