mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
no GEP in program
This commit is contained in:
parent
7b951e691e
commit
10c2a50e79
3 changed files with 12 additions and 7 deletions
|
|
@ -32,6 +32,8 @@ pm_index_is_shrink = PatternMatcher([
|
|||
# rewrite non-image INDEX to SHRINK
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).cast(name="x"), lambda buf,idx,x:
|
||||
UOp(Ops.SHRINK, dtype=buf.dtype.base, src=(buf, idx, UOp.const(dtypes.int, x.dtype.count)))),
|
||||
# rewrite GEP to INDEX
|
||||
(UPat(Ops.GEP, name="x"), lambda x: x.replace(op=Ops.INDEX, src=x.src+(UOp.const(dtypes.int, x.arg),), arg=None)),
|
||||
# remove all vec dtypes
|
||||
(UPat(GroupOp.All, name="x"), lambda x: x.replace(dtype=x.dtype.base.scalar().base)),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -8,6 +8,12 @@ from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, trunc
|
|||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.codegen.late.devectorizer import no_vectorized_alu
|
||||
|
||||
def render_index(ctx,buf,idx):
  """Render an INDEX/SHRINK of `buf` at `idx` as a source-code string.

  For a buffer whose addrspace is AddrSpace.REG the index must be a CONST and the
  result is an element access on the rendered buffer name: `name[i]` when
  `buf.max_numel()` exceeds `ctx.gep_arr_threshold`, otherwise `name.<c>` where
  `<c>` is the i-th character of 'xyzwabcd'. For any other address space the
  result is the parenthesized sum `(buf+idx)`.
  """
  if buf.addrspace == AddrSpace.REG:
    assert idx.op is Ops.CONST
    name = ctx[buf]
    # large register arrays are subscripted, small ones use component letters
    suffix = f"[{idx.arg}]" if buf.max_numel() > ctx.gep_arr_threshold else f".{'xyzwabcd'[idx.arg]}"
    return name + suffix
  base = ctx[buf]
  # NOTE(review): this tests idx.arg (not idx.op) against Ops.ADD -- confirm that is intended
  offset = strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]
  return f"({base}+{offset})"
|
||||
|
||||
base_rewrite = PatternMatcher([
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
|
||||
|
|
@ -44,10 +50,8 @@ base_rewrite = PatternMatcher([
|
|||
# default const render
|
||||
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
|
||||
# SHRINK/INDEX
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))), lambda ctx,buf,idx:
|
||||
f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
|
||||
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var('idx'), UPat.cvar())), lambda ctx,buf,idx:
|
||||
f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))), render_index),
|
||||
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var('idx'), UPat.cvar())), render_index),
|
||||
# new load/store
|
||||
(UPat(Ops.LOAD, src=(UPat.var('bidx'),), name="x"), lambda ctx,x,bidx: ctx.render_access(bidx, ctx.render_dtype_with_shape(x))),
|
||||
(UPat(Ops.LOAD, src=(UPat.var("bidx"), UPat.var("var"), UPat.var("gate")), name="x"),
|
||||
|
|
@ -58,8 +62,6 @@ base_rewrite = PatternMatcher([
|
|||
# TODO: look for left-associative
|
||||
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
|
||||
*([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR, Ops.OR, Ops.AND} else ctx[v] for v in x.src]), x.dtype)),
|
||||
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
|
||||
(f"[{x.arg[0]}]" if x.src[0].dtype.count > ctx.gep_arr_threshold else f".{'xyzwabcd'[x.arg[0]]}")),
|
||||
# custom passes through with format
|
||||
(UPat((Ops.CUSTOM, Ops.CUSTOMI), name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -48,7 +48,8 @@ from tinygrad.dtype import dtypes, AddrSpace
|
|||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
||||
**{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.SHAPED_WMMA: "#FF5B5B",
|
||||
Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
||||
Ops.INDEX: "#D8F9E4", Ops.STACK: "#D8F9E4",
|
||||
Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||
Ops.SLICE: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.GETADDR: "#9DB1F0", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6",
|
||||
Ops.CALL: "#00B7C8", Ops.FUNCTION: "#C07788", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.BINARY: "#404040",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue