mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
replace INDEX with SLICE
This commit is contained in:
parent
156a4438d9
commit
d0f956341c
7 changed files with 30 additions and 15 deletions
|
|
@ -115,7 +115,7 @@ pm_linearize_cleanups = PatternMatcher([
|
|||
# if statements are not allowed in the graph
|
||||
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError, "if not allowed in graph")),
|
||||
# gated STORE becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
|
||||
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX).or_casted(), UPat(), UPat(name="gate", dtype=dtypes.bool))),
|
||||
(UPat(Ops.STORE, name="u", src=(UPat(Ops.SLICE), UPat(), UPat(name="gate", dtype=dtypes.bool))),
|
||||
lambda u, gate: ((st:=u.replace(src=u.src[0:2])), [mif:=UOp(Ops.IF, src=(gate, u.src[0])), st, UOp(Ops.ENDIF, src=(mif,))]))
|
||||
])
|
||||
|
||||
|
|
|
|||
|
|
@ -283,6 +283,12 @@ pm_render = PatternMatcher([
|
|||
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.STACK, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
|
||||
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
|
||||
(UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x),
|
||||
# rewrite non-image INDEX to SLICE
|
||||
(UPat(Ops.INDEX, name="x"), lambda x: None if isinstance(x.src[0].dtype, ImageDType) else \
|
||||
UOp(Ops.SLICE, dtype=x.dtype, src=x.src, arg=0 if x.dtype.count == 1 else x.dtype.count)),
|
||||
# rewrite CAST on SLICE to just SLICE
|
||||
(UPat(Ops.SLICE, name="bv").cast(name="x"),
|
||||
lambda bv,x: bv.replace(dtype=x.dtype, arg=0 if x.dtype.count == 1 else x.dtype.count))
|
||||
])
|
||||
|
||||
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***
|
||||
|
|
|
|||
|
|
@ -43,8 +43,10 @@ base_rewrite = PatternMatcher([
|
|||
(UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, str(x.arg))})"),
|
||||
# default const render
|
||||
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
|
||||
# slice is ptr arithmetic
|
||||
(UPat(Ops.SLICE, src=(UPat.var("buf"), UPat.var('idx')), name="x"),
|
||||
lambda ctx,buf,idx,x: ctx.render_cast(x.dtype, f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})")),
|
||||
# new load/store
|
||||
(UPat.var("buf").index(UPat.var('idx')), lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('bidx'),)), lambda ctx,bidx: f"(*{ctx[bidx]})"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var("bidx"), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
|
||||
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var"))), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
|
||||
|
|
|
|||
|
|
@ -101,16 +101,18 @@ class PythonProgram:
|
|||
if arg[0] == 'g': values[i] = [idxs[2-int(arg[-1])]] * warp_size
|
||||
elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp]
|
||||
elif uop is Ops.CONST: values[i] = [arg] * warp_size
|
||||
elif uop is Ops.INDEX:
|
||||
elif uop is Ops.SLICE:
|
||||
assert len(src_values) == 2, "non-image index must be 2 srcs"
|
||||
ret:list = []
|
||||
if isinstance(src_dtypes[0], ImageDType):
|
||||
assert len(src_values) == 3, "image index must be 3 srcs"
|
||||
for m,oy,ox in zip(*src_values):
|
||||
if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
|
||||
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
|
||||
else:
|
||||
assert len(src_values) == 2, "non-image index must be 2 srcs"
|
||||
for m,o in zip(*src_values): ret.append((m,o))
|
||||
for m,o in zip(*src_values): ret.append((m,o))
|
||||
values[i] = ret
|
||||
elif uop is Ops.INDEX:
|
||||
assert isinstance(src_dtypes[0], ImageDType), "only image INDEX is supported"
|
||||
ret:list = []
|
||||
assert len(src_values) == 3, "image index must be 3 srcs"
|
||||
for m,oy,ox in zip(*src_values):
|
||||
if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
|
||||
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
|
||||
values[i] = ret
|
||||
elif uop is Ops.CAST and isinstance(dtype, PtrDType):
|
||||
values[i] = src_values[0]
|
||||
|
|
|
|||
|
|
@ -267,6 +267,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
case Ops.BINARY: return (len(self.arg),)
|
||||
case Ops.BUFFER: return (self.arg,)
|
||||
case Ops.SLICE:
|
||||
if self.arg == 0: return ()
|
||||
# HACK: SLICE is used inside kernels, so we set the shape to () if it's on an INDEX
|
||||
if self.src[0].op is Ops.INDEX: return ()
|
||||
return (self.arg,)
|
||||
|
|
|
|||
|
|
@ -99,11 +99,11 @@ spec_shared = PatternMatcher([
|
|||
(UPat(Ops.INS), lambda: True),
|
||||
|
||||
# LOAD(idx) / STORE(idx, val) with gates on the LOAD/STORE
|
||||
(UPat(Ops.INDEX, name="uidx").or_casted().load(), validate_index),
|
||||
(UPat(Ops.INDEX, name="uidx").or_casted().load(UPat.var("alt"), UPat.var("gate", dtype=dtypes.bool), name="load"),
|
||||
(UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().load(), validate_index),
|
||||
(UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().load(UPat.var("alt"), UPat.var("gate", dtype=dtypes.bool), name="load"),
|
||||
lambda uidx,gate,alt,load: validate_index(uidx, gate) if alt.dtype == load.dtype else False),
|
||||
(UPat(Ops.INDEX, name="uidx").or_casted().store(UPat()), validate_index),
|
||||
(UPat(Ops.INDEX, name="uidx").or_casted().store(UPat(), UPat.var("gate", dtype=dtypes.bool)), validate_index),
|
||||
(UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().store(UPat()), validate_index),
|
||||
(UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().store(UPat(), UPat.var("gate", dtype=dtypes.bool)), validate_index),
|
||||
|
||||
# STORE in tensor graph: store a value into a target
|
||||
(UPat(Ops.STORE, dtypes.void, (UPat(name="x"), UPat())), lambda x: True),
|
||||
|
|
@ -200,6 +200,10 @@ spec_program = PatternMatcher([
|
|||
# weakint is not allowed in programs
|
||||
(UPat(GroupOp.All, dtypes.weakint), lambda: False),
|
||||
|
||||
# buffer view in program, Image only for ImageDType
|
||||
(UPat(Ops.SLICE), lambda: True),
|
||||
(UPat(Ops.INDEX, name="idx"), lambda idx: isinstance(idx.src[0].dtype, ImageDType)),
|
||||
|
||||
# movement ops are not allowed in programs
|
||||
(UPat(GroupOp.Movement), lambda: False),
|
||||
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
|||
Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||
Ops.SLICE: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.GETADDR: "#9DB1F0", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6",
|
||||
Ops.SLICE: "#a2c148", Ops.BUFFER: "#B0BDFF", Ops.GETADDR: "#9DB1F0", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6",
|
||||
Ops.CALL: "#00B7C8", Ops.FUNCTION: "#C07788", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.BINARY: "#404040",
|
||||
Ops.LINEAR: "#7DF4FF",
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue