mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
cleanups + fix nir
This commit is contained in:
parent
396d3f441a
commit
ecf49474eb
5 changed files with 5 additions and 8 deletions
|
|
@ -44,7 +44,7 @@ base_rewrite = PatternMatcher([
|
|||
# default const render
|
||||
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
|
||||
# new load/store
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
|
||||
lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('bidx'),)), lambda ctx,bidx: f"(*{ctx[bidx]})"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var("bidx"), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
|
||||
|
|
|
|||
|
|
@ -76,14 +76,14 @@ base_rewrite = PatternMatcher([
|
|||
# memory load/store
|
||||
(UPat(Ops.INDEX, name="x"), lambda ctx,x:
|
||||
f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX).or_casted("idx"), UPat.var("alt"), UPat.var("mask")), allow_any_len=True, name="x"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var("idx"), UPat.var("alt"), UPat.var("mask")), name="x"),
|
||||
lambda ctx,x,idx,alt,mask:
|
||||
f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
|
||||
f" br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n"
|
||||
f" {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n"
|
||||
f" br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n"
|
||||
f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('idx'),), allow_any_len=True, name="x"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"),
|
||||
lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
|
||||
(UPat(Ops.STORE, name="x"), lambda ctx,x: f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),
|
||||
|
||||
|
|
|
|||
|
|
@ -270,7 +270,7 @@ class IR3Renderer(NIRRenderer, OpenCLRenderer):
|
|||
def_rewrite = PatternMatcher([
|
||||
(UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2)), allow_any_len=True), UPat.var("val"))),
|
||||
lambda ctx,img,coord,val: nstore_img(ctx.b, ctx.r[img], ctx.r[coord], ctx.r[val], val.dtype)),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2)), UPat.var("gate")), UPat.var("alt"))),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))), UPat.var("alt"), UPat.var("gate"))),
|
||||
lambda ctx,img,coord,alt,gate: if_phi(ctx.b, ctx.r[gate], lambda: ctx.nload_img(img, coord), lambda: ctx.r[alt])),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))),)), nload_img),
|
||||
]) + NIRRenderer.def_rewrite
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ ptx_matcher = PatternMatcher([
|
|||
(UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
|
||||
lambda x: UOp(x.op, dtypes.void, (x.src[0], x.src[1].cast(dtypes.uint8)))),
|
||||
# indexing on PTX is in uint64, we do the math while it's still in the graph
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx")), name="op", allow_any_len=True), lambda buf,idx,op:
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx")), name="op"), lambda buf,idx,op:
|
||||
UOp(Ops.INDEX, dtype=dtypes.int64, src=(buf, buf.cast(dtypes.int64)+idx.cast(dtypes.int64)*buf.dtype.itemsize)+op.src[2:]) \
|
||||
if op.dtype != dtypes.int64 and buf.dtype.addrspace != AddrSpace.REG else None),
|
||||
# load/store use pointer arithmetic, and the cast does nothing
|
||||
|
|
|
|||
|
|
@ -237,9 +237,6 @@ tensor_spec = PatternMatcher([
|
|||
# ***** UOp spec in linearized programs *****
|
||||
|
||||
program_spec = PatternMatcher([
|
||||
# LOAD (idx, alt_value), LOAD can have an alt value, but only if the index has a gate
|
||||
(UPat().index(UPat(), UPat(dtype=dtypes.bool)).or_casted().load(UPat()), lambda: True),
|
||||
|
||||
# END closes ranges
|
||||
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue