mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
36 commits
master
...
move_gates
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b9be9fbc77 |
||
|
|
2ccefa11ec | ||
|
|
95b0a651c2 | ||
|
|
76606eb386 | ||
|
|
e74bf441f0 | ||
|
|
661eb76309 | ||
|
|
a0c04a5e35 | ||
|
|
13e0fbaba6 | ||
|
|
58a09b22ac |
||
|
|
d09ea1d620 | ||
|
|
7a00223bd3 | ||
|
|
5053148502 | ||
|
|
ecf49474eb | ||
|
|
396d3f441a |
||
|
|
6573c103f9 | ||
|
|
fc2a289f61 | ||
|
|
5736eee2f2 | ||
|
|
651279c7ff |
||
|
|
0821bef6b4 | ||
|
|
437205ae03 | ||
|
|
cfefef479b | ||
|
|
5d9431ecb9 | ||
|
|
0f3b12fcd8 | ||
|
|
60c8542320 | ||
|
|
4ec5487ad8 | ||
|
|
995a787d6c | ||
|
|
1b17762030 | ||
|
|
c0f443cf47 |
||
|
|
ff1258feef | ||
|
|
51b13466dd | ||
|
|
416878db9e | ||
|
|
e00b3b4065 | ||
|
|
d810bd2b41 |
||
|
|
09ec34437d | ||
|
|
36383298be | ||
|
|
8f397f5c7c |
12 changed files with 104 additions and 87 deletions
|
|
@ -16,6 +16,7 @@ from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcend
|
|||
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
|
||||
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
|
||||
ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images
|
||||
from tinygrad.codegen.late.gater import pm_move_gates_from_index
|
||||
from tinygrad.codegen.opt.postrange import apply_opts
|
||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges
|
||||
|
|
@ -76,8 +77,13 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
|
||||
if DEVECTORIZE >= 0: sink = graph_rewrite(sink, pm_devectorize, ctx=ren, name="devectorize")
|
||||
|
||||
# lower the index dtype to a concrete int
|
||||
# lower the index dtype to a concrete int. this needs to happen while gates are still present
|
||||
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
|
||||
|
||||
# move the gates from index onto the loads and stores
|
||||
sink = graph_rewrite(sink, pm_move_gates_from_index, name="move gates from index")
|
||||
|
||||
# a final symbolic
|
||||
sink = graph_rewrite(sink, symbolic, name="post index symbolic")
|
||||
|
||||
# optional pre matcher
|
||||
|
|
@ -106,8 +112,8 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
pm_linearize_cleanups = PatternMatcher([
|
||||
# if statements are not allowed in the graph
|
||||
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError, "if not allowed in graph")),
|
||||
# gated INDEX becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
|
||||
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat(name="gate", dtype=dtypes.bool))).or_casted(), UPat())),
|
||||
# gated STORE becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
|
||||
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX).or_casted(), UPat(), UPat(name="gate", dtype=dtypes.bool))),
|
||||
lambda u, gate: (u, [mif:=UOp(Ops.IF, src=(gate, u.src[0])), u, UOp(Ops.ENDIF, src=(mif,))]))
|
||||
])
|
||||
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
|
|||
return drop_stmt
|
||||
|
||||
def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
|
||||
start_idx = start_idx.simplify() # if you don't do this, uop_given_valid may simplify things and this might inf loop
|
||||
idx = uop_given_valid(valid, start_idx)
|
||||
if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx.valid(valid), ptr=True)
|
||||
|
||||
|
|
@ -281,18 +282,13 @@ pm_render = PatternMatcher([
|
|||
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.STACK, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
|
||||
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
|
||||
(UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x),
|
||||
# give any loads that are masked an alt value
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
|
||||
lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:])
|
||||
if len(x.src) == 1 or x.src[1].op in (Ops.CUSTOM, Ops.STORE, Ops.BARRIER) else None),
|
||||
# Where after gated load becomes alt value
|
||||
# NOTE: if a is CAST and a.src[0].dtype == l.dtype, use a.src[0] to avoid roundtrip cast (e.g. uint->float->uint)
|
||||
(UPat.var("c").where(UPat(Ops.LOAD, src=(UPat().index(UPat(), UPat.var("c")).or_casted(),), allow_any_len=True, name="l").or_casted(),
|
||||
UPat.var("a")), lambda c,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype))+
|
||||
l.src[2:]).cast(a.dtype)),
|
||||
(UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat().index(UPat(), UPat.var("c", dtype=dtypes.bool).logical_not()).or_casted(),),
|
||||
allow_any_len=True, name="l").or_casted()), lambda c,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype
|
||||
else a.cast(l.dtype))+l.src[2:]).cast(a.dtype)),
|
||||
(UPat.var("gate").where(UPat(Ops.LOAD, src=(UPat(), UPat(), UPat.var("gate")), name="l").or_casted(), UPat.var("a")), lambda gate,l,a:
|
||||
l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype), l.src[2])).cast(a.dtype)),
|
||||
(UPat.var("gate").where(UPat.var("a"), UPat(Ops.LOAD,
|
||||
src=(UPat(), UPat(), UPat.var("gate", dtype=dtypes.bool).logical_not()), name="l").or_casted()), lambda gate,l,a:
|
||||
l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype), l.src[2])).cast(a.dtype)),
|
||||
])
|
||||
|
||||
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***
|
||||
|
|
@ -368,7 +364,7 @@ pm_add_loads = PatternMatcher([
|
|||
# add loads to non ptr index
|
||||
(UPat(Ops.INDEX, name="idx"), add_load),
|
||||
# remove loads from stores
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.LOAD), UPat(name="val")), name="s"), lambda s,val: s.replace(src=(s.src[0].src[0], val))),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.LOAD),), allow_any_len=True, name="s"), lambda s: s.replace(src=(s.src[0].src[0],)+s.src[1:])),
|
||||
])
|
||||
|
||||
# make images
|
||||
|
|
|
|||
13
tinygrad/codegen/late/gater.py
Normal file
13
tinygrad/codegen/late/gater.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
# this transforms Invalid into gated load/stores
|
||||
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat
|
||||
from tinygrad.dtype import Invalid, dtypes
|
||||
|
||||
pm_move_gates_from_index = PatternMatcher([
|
||||
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast").load(name="l"),
|
||||
lambda buf,gate,idx,cast,l: buf.index(idx, ptr=True).cast(cast.dtype).load(l.const_like(0), gate, dtype=l.dtype)),
|
||||
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast").store(UPat.var("data")),
|
||||
lambda buf,gate,idx,cast,data: buf.index(idx, ptr=True).cast(cast.dtype).store(data, gate)),
|
||||
# remove hanging weakint casts
|
||||
(UPat.var("buf").index(UPat.var("idx", dtypes.ints).cast()), lambda buf,idx: buf.index(idx, ptr=True)),
|
||||
])
|
||||
|
|
@ -44,12 +44,11 @@ base_rewrite = PatternMatcher([
|
|||
# default const render
|
||||
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
|
||||
# new load/store
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
|
||||
lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("gate"))).or_casted("bidx"), UPat.var("var"))),
|
||||
lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('bidx'),)), lambda ctx,bidx: f"(*{ctx[bidx]})"),
|
||||
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var"))), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var("bidx"), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
|
||||
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
|
||||
# alu/gep
|
||||
# TODO: look for left-associative
|
||||
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
|
||||
|
|
@ -302,12 +301,12 @@ class OpenCLRenderer(CStyleLanguage):
|
|||
(UPat(Ops.CONST, dtypes.bfloat16, name="x"),
|
||||
lambda ctx,x: f"{(struct.unpack('I', struct.pack('f', float_to_bf16(x.arg)))[0] >> 16)}u"),
|
||||
# load/store image (OpenCL)
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2)), UPat.var("gate")), UPat.var("var"))),
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var"), UPat.var("gate"))),
|
||||
lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
|
||||
lambda ctx,buf,idx: f"read_imagef({ctx[buf]}, smp, {ctx[idx]})"),
|
||||
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2)), allow_any_len=True),
|
||||
UPat.var("var", dtypes.float.vec(4)))),
|
||||
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),
|
||||
UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
|
||||
lambda ctx,buf,idx,var: f"write_imagef({ctx[buf]}, {ctx[idx]}, {ctx[var]});"),
|
||||
]) + base_rewrite
|
||||
|
||||
|
|
|
|||
|
|
@ -76,14 +76,14 @@ base_rewrite = PatternMatcher([
|
|||
# memory load/store
|
||||
(UPat(Ops.INDEX, name="x"), lambda ctx,x:
|
||||
f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("mask"))).or_casted("idx"), UPat.var("alt")), allow_any_len=True, name="x"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var("idx"), UPat.var("alt"), UPat.var("mask")), name="x"),
|
||||
lambda ctx,x,idx,alt,mask:
|
||||
f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
|
||||
f" br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n"
|
||||
f" {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n"
|
||||
f" br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n"
|
||||
f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('idx'),), allow_any_len=True, name="x"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"),
|
||||
lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
|
||||
(UPat(Ops.STORE, name="x"), lambda ctx,x: f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),
|
||||
|
||||
|
|
|
|||
|
|
@ -128,8 +128,8 @@ class NIRRenderer(Renderer):
|
|||
# load/store bool -> uint8
|
||||
(UPat(Ops.LOAD, dtypes.bool, name="x"),
|
||||
lambda x: x.replace(dtype=dtypes.uint8, src=x.src[0:1]+((x.src[1].cast(dtypes.uint8),) if len(x.src)>=2 else ())+x.src[2:]).cast(dtypes.bool)),
|
||||
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dtypes.bool)), name="x"),
|
||||
lambda x: x.replace(src=(x.src[0], x.src[1].cast(dtypes.uint8)))),
|
||||
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
|
||||
lambda x: x.replace(src=(x.src[0], x.src[1].cast(dtypes.uint8))+x.src[2:])),
|
||||
# NIR requires shift amount to be 32 bit: https://docs.mesa3d.org/nir/alu.html#nir-alu-op-ishl
|
||||
(UPat((Ops.SHL, Ops.SHR), name="x"), lambda x: x.replace(src=(x.src[0], x.src[1].cast(dtypes.uint))) if x.src[1].dtype.bitsize != 32 else None),
|
||||
# OpConvertFToU is undefined if Result Type is not wide enough, cast through int32
|
||||
|
|
@ -146,12 +146,12 @@ class NIRRenderer(Renderer):
|
|||
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 8)),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 4)),
|
||||
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off")), allow_any_len=True), UPat.var("val"))),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off"))), UPat.var("val")), allow_any_len=True),
|
||||
lambda ctx,buf,off,val: nstore(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"), UPat.var("gate"))), UPat.var("alt")), allow_any_len=True, name="x"),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))), UPat.var("alt"), UPat.var("gate")), name="x"),
|
||||
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
|
||||
lambda: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x.dtype), lambda: ctx.r[alt])),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))),), allow_any_len=True, name="x"),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))),), name="x"),
|
||||
lambda ctx,x,buf,off: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), x.dtype)),
|
||||
(UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.dtype.count}", *[ctx.r[src] for src in x.src])),
|
||||
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype.scalar()][x.op], *[ctx.r[src] for src in x.src])),
|
||||
|
|
@ -268,9 +268,9 @@ class IR3Renderer(NIRRenderer, OpenCLRenderer):
|
|||
return _nload_img(ctx.b, ctx.r[img], ctx.r[coord], img.dtype)
|
||||
|
||||
def_rewrite = PatternMatcher([
|
||||
(UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2)), allow_any_len=True), UPat.var("val"))),
|
||||
(UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))), UPat.var("val")), allow_any_len=True),
|
||||
lambda ctx,img,coord,val: nstore_img(ctx.b, ctx.r[img], ctx.r[coord], ctx.r[val], val.dtype)),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2)), UPat.var("gate")), UPat.var("alt"))),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))), UPat.var("alt"), UPat.var("gate"))),
|
||||
lambda ctx,img,coord,alt,gate: if_phi(ctx.b, ctx.r[gate], lambda: ctx.nload_img(img, coord), lambda: ctx.r[alt])),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))),)), nload_img),
|
||||
]) + NIRRenderer.def_rewrite
|
||||
|
|
|
|||
|
|
@ -48,10 +48,10 @@ ptx_matcher = PatternMatcher([
|
|||
# load/store bool -> uint8
|
||||
(UPat(Ops.LOAD, dtypes.bool, src=(UPat(dtype=dtypes.int64),), name="x", allow_any_len=True),
|
||||
lambda x: UOp(x.op, dtypes.uint8, x.src[0:1] + ((x.src[1].cast(dtypes.uint8),) if len(x.src) >= 2 else ()) + x.src[2:]).cast(dtypes.bool)),
|
||||
(UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x"),
|
||||
lambda x: UOp(x.op, dtypes.void, (x.src[0], x.src[1].cast(dtypes.uint8)))),
|
||||
(UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
|
||||
lambda x: UOp(x.op, dtypes.void, (x.src[0], x.src[1].cast(dtypes.uint8))+x.src[2:])),
|
||||
# indexing on PTX is in uint64, we do the math while it's still in the graph
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx")), name="op", allow_any_len=True), lambda buf,idx,op:
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx")), name="op"), lambda buf,idx,op:
|
||||
UOp(Ops.INDEX, dtype=dtypes.int64, src=(buf, buf.cast(dtypes.int64)+idx.cast(dtypes.int64)*buf.dtype.itemsize)+op.src[2:]) \
|
||||
if op.dtype != dtypes.int64 and buf.dtype.addrspace != AddrSpace.REG else None),
|
||||
# load/store use pointer arithmetic, and the cast does nothing
|
||||
|
|
@ -102,11 +102,11 @@ string_rewrite = PatternMatcher([
|
|||
(UPat(Ops.CAST, name="x", src=(UPat.var("a"),)),
|
||||
lambda ctx, x, a: f"cvt{modifier(x.dtype, a.dtype)}.{ctx.cast_types[x.dtype]}.{ctx.cast_types[a.dtype]} {ctx.r[x]}, {ctx.r[a]};"),
|
||||
# store / gated load / load
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc")), allow_any_len=True), UPat.var("var"))),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))), UPat.var("var")), allow_any_len=True),
|
||||
lambda ctx, loc, var, buf: f"st.{mem_type(buf)}" + \
|
||||
f"{f'.v{cnt}' if ((cnt:=var.dtype.count)>1) else ''}.{ctx.mem_types[var.dtype.scalar()]} " + \
|
||||
f"[{ctx.r[loc]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.dtype.count > 1 else ctx.r[var]};"),
|
||||
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"), UPat.var("gate"))), UPat.var("alt")), allow_any_len=True),
|
||||
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))), UPat.var("alt"), UPat.var("gate")), allow_any_len=True),
|
||||
lambda ctx, x, loc, alt, gate, buf: flatten([
|
||||
[f"mov.{ctx.mem_types[x.dtype.scalar()]} {v}, {render_val(0, x.dtype.scalar())};" for v in ctx.r[x]],
|
||||
[f"@{ctx.r[gate]} ld.{mem_type(buf)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"]
|
||||
|
|
|
|||
|
|
@ -10,21 +10,20 @@ def sign_extend(val:UOp, sext_am:int):
|
|||
| val.bitcast(dtypes.uint32)).bitcast(dtypes.int)
|
||||
|
||||
# store for char: buf[idx/4] <- (var << (idx%4)*8))
|
||||
def packed_store(bidx:UOp, var:UOp):
|
||||
def packed_store(bidx:UOp, var:UOp, gate:UOp|None=None):
|
||||
elems, mask = 4//var.dtype.itemsize, _mask(var.dtype)
|
||||
shift_am, div_idx = (bidx.src[1].cast(dtypes.uint32) % elems) * (8*var.dtype.itemsize), bidx.src[1] // elems
|
||||
new_v, wmask = (var & mask).cast(dtypes.uint32) << shift_am, ((mask << shift_am) ^ 0xFFFFFFFF).cast(dtypes.uint32)
|
||||
# preserve valid condition (bidx.src[2]) if it exists for gated stores
|
||||
idx_src = (bidx.src[0], div_idx) if len(bidx.src) == 2 else (bidx.src[0], div_idx, bidx.src[2])
|
||||
buf = UOp.load(UOp(Ops.INDEX, bidx.dtype, idx_src), dtype=dtypes.uint32)
|
||||
return UOp.store(UOp(Ops.INDEX, bidx.dtype, idx_src), (buf & wmask) | new_v)
|
||||
idx = UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx))
|
||||
buf = UOp.load(idx, *((UOp.const(dtypes.uint32, 0), gate) if gate is not None else ()), dtype=dtypes.uint32)
|
||||
return UOp.store(idx, (buf & wmask) | new_v, *((gate,) if gate is not None else ()))
|
||||
|
||||
# load for char: sign_extend(buf[idx/4] >> ((idx%4)*8))
|
||||
def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None):
|
||||
def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None, gate:UOp|None=None):
|
||||
elems, mask = 4//dtype.itemsize, _mask(dtype)
|
||||
shift_am, div_idx = (bidx.src[1].cast(dtypes.uint32) % elems) * (8*dtype.itemsize), bidx.src[1] // elems
|
||||
idx = UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx, bidx.src[2]) if var is not None else (bidx.src[0], div_idx))
|
||||
load = UOp.load(idx, *([var] if var is not None else root.src[1:]), dtype=dtypes.uint32, arg=root.arg)
|
||||
idx = UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx))
|
||||
load = UOp.load(idx, *((var, gate) if var is not None and gate is not None else root.src[1:]), dtype=dtypes.uint32, arg=root.arg)
|
||||
val = (load.cast(dtypes.uint32) >> shift_am) & mask
|
||||
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
|
||||
|
||||
|
|
@ -41,10 +40,12 @@ wgsl_matcher = PatternMatcher([
|
|||
(UPat((Ops.CMPLT, Ops.XOR), src=(UPat(name="a", dtype=dtypes.bool), UPat.var("b")), name="c"),
|
||||
lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
|
||||
# TODO: load alt value doesnt have to be a const
|
||||
(UPat.load(UPat.var("b"), UPat.cvar("c"), allow_any_len=True, name="l"),
|
||||
lambda l,b,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if is_packed(l.dtype, b.dtype) else None),
|
||||
(UPat.load(UPat.var("b"), name='l', allow_any_len=True), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype, b.dtype) else None),
|
||||
(UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True),
|
||||
(UPat.load(UPat.var("b"), UPat.cvar("c"), UPat.var("gate"), name="l"),
|
||||
lambda l,b,c,gate: packed_load(l,b,l.dtype,c.cast(dtypes.uint32),gate) if is_packed(l.dtype, b.dtype) else None),
|
||||
(UPat.load(UPat.var("b"), name='l'), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype, b.dtype) else None),
|
||||
(UPat.store(UPat.var("bidx"), UPat.var("var"), UPat.var("gate")),
|
||||
lambda bidx,var,gate: packed_store(bidx,var,gate) if is_packed(var.dtype, bidx.dtype) else None),
|
||||
(UPat.store(UPat.var("bidx"), UPat.var("var")),
|
||||
lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype, bidx.dtype) else None),
|
||||
(UPat.var("a") << UPat.var("b"),lambda a,b:(a.bitcast(dtypes.uint32)<<b.cast(dtypes.uint32)).bitcast(a.dtype) if b.dtype!=dtypes.uint32 else None),
|
||||
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
|
||||
|
|
@ -82,14 +83,14 @@ class WGSLRenderer(CStyleLanguage):
|
|||
if x.src[0].dtype == dtypes.half else f"((i32({ctx[x.src[0]]}&0xFFFF)<<16)>>16)"),
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"),
|
||||
# TODO: load alt value doesnt have to be a const
|
||||
(UPat.load(UPat.var("b"), UPat.cvar("v"), allow_any_len=True),
|
||||
lambda ctx,b,v: f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[b.src[2]]})"),
|
||||
(UPat.load(UPat.var("b"), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.dtype)),
|
||||
(UPat.load(UPat.var("b"), UPat.cvar("v"), UPat.var("gate")),
|
||||
lambda ctx,b,v,gate: f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[gate]})"),
|
||||
(UPat.load(UPat.var("b")), lambda ctx, b: ctx.render_load(ctx[b], b.dtype)),
|
||||
(UPat.store(UPat.var("b"), UPat.var("v"), allow_any_len=True),lambda ctx,b,v:\
|
||||
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
|
||||
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
|
||||
else f"{ctx[b]} = {ctx[v]};"),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx")), allow_any_len=True),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"))),
|
||||
lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
|
||||
]) + base_rewrite
|
||||
|
||||
|
|
|
|||
|
|
@ -18,8 +18,8 @@ def _load(m, i, dtype: DType):
|
|||
return from_storage_scalar(m[i], dtype)
|
||||
|
||||
def load(inp, j, dtype: DType):
|
||||
if len(inp) == 2: return [_load(m, x+j if x is not None else None, dtype) if gate else default for (m,x,gate),default in zip(*inp)]
|
||||
return [_load(m, x+j if x is not None else None, dtype) for m,x,_ in inp[0]]
|
||||
if len(inp) >= 3: return [_load(m, x+j if x is not None else None, dtype) if gate else default for (m,x),default,gate in zip(*inp[:3])]
|
||||
return [_load(m, x+j if x is not None else None, dtype) for m,x in inp[0]]
|
||||
|
||||
def _store(m, i, v, dtype: DType):
|
||||
if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
|
||||
|
|
@ -67,8 +67,9 @@ class PythonProgram:
|
|||
continue
|
||||
assert dtype is not None, f"{uop} is missing a dtype"
|
||||
if uop is Ops.STORE:
|
||||
store_gate = src_values[2] if len(src_values) >= 3 else [True] * warp_size
|
||||
for j,val in enumerate(src_values[1] if src_dtypes[1].count > 1 else [src_values[1]]):
|
||||
for (m,o,g),v in zip(src_values[0], val):
|
||||
for (m,o),v,g in zip(src_values[0], val, store_gate):
|
||||
if g: _store(m, o+j, v, src_dtypes[1].scalar())
|
||||
i += 1
|
||||
continue
|
||||
|
|
@ -91,6 +92,7 @@ class PythonProgram:
|
|||
elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp]
|
||||
elif uop is Ops.CONST: values[i] = [arg] * warp_size
|
||||
elif uop is Ops.INDEX:
|
||||
if len(src_values) != 2: raise RuntimeError("gates must be on LOAD/STORE, not INDEX")
|
||||
ret:list = []
|
||||
if isinstance(src_dtypes[0], ImageDType):
|
||||
for m,ox,oy in zip(src_values[0], src_values[1][0], src_values[1][1]):
|
||||
|
|
@ -98,7 +100,7 @@ class PythonProgram:
|
|||
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
|
||||
else:
|
||||
for m,o in zip(src_values[0], src_values[1]): ret.append((m,o))
|
||||
values[i] = [(m,o,g) for (m,o),g in zip(ret, src_values[2] if len(src_values) == 3 else [True]*len(ret))] # set the gate last
|
||||
values[i] = ret
|
||||
elif uop is Ops.CAST and isinstance(dtype, PtrDType):
|
||||
values[i] = src_values[0]
|
||||
elif uop is Ops.RANGE:
|
||||
|
|
|
|||
|
|
@ -418,13 +418,15 @@ def f2f_clamp(val:UOp, dt:DType) -> UOp:
|
|||
# FIXME: CMPLT of nan is undefined
|
||||
return val.ne(val).where(val, (val < -mx).where(-sat, (mx < val).where(sat, val)))
|
||||
|
||||
def f2f_load(x: UOp, fr:DType, to:DType) -> UOp:
|
||||
if (n:=x.dtype.count) == 1: return f2f(x.replace(dtype=f2f_dt[fr]), fr, to)
|
||||
return UOp.vectorize(*(f2f(x.replace(dtype=f2f_dt[fr], src=(reindex(x.src[0].src[0], i, 1),)), fr, to) for i in range(n)))
|
||||
def f2f_load(x:UOp, fr:DType, to:DType) -> UOp:
|
||||
if (n:=x.dtype.count) == 1:
|
||||
return f2f(x.replace(src=(x.src[0],)+((x.src[1].cast(f2f_dt[fr]), x.src[2]) if len(x.src) >= 3 else ()), dtype=f2f_dt[fr]), fr, to)
|
||||
return UOp.vectorize(*(f2f(x.replace(dtype=f2f_dt[fr], src=(reindex(x.src[0].src[0], i, 1),) + \
|
||||
((x.src[1].gep(i).cast(f2f_dt[fr]), x.src[2]) if len(x.src) >= 3 else ())), fr, to) for i in range(n)))
|
||||
|
||||
def f2f_store(st, idx, val, fr:DType, to:DType):
|
||||
if (n:=val.dtype.count) == 1: return st.replace(src=(idx, f2f(val.bitcast(f2f_dt[to]), to, fr)))
|
||||
return UOp.group(*(st.replace(src=(reindex(idx, i, 1), f2f(val.gep(i).bitcast(f2f_dt[to]), to, fr))) for i in range(n)))
|
||||
def f2f_store(st:UOp, idx, val, fr:DType, to:DType):
|
||||
if (n:=val.dtype.count) == 1: return st.replace(src=(idx, f2f(val.bitcast(f2f_dt[to]), to, fr))+st.src[2:])
|
||||
return UOp.group(*(st.replace(src=(reindex(idx, i, 1), f2f(val.gep(i).bitcast(f2f_dt[to]), to, fr))+st.src[2:]) for i in range(n)))
|
||||
|
||||
# ***** decomposition patterns *****
|
||||
|
||||
|
|
@ -509,8 +511,8 @@ pm_long_decomp = PatternMatcher([
|
|||
(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x:
|
||||
x.replace(dtype=l2i_dt[x.dtype.base].ptr(x.dtype.size * 2)) if hasattr(x.dtype, 'size') and x.dtype.base in l2i_dt else None),
|
||||
(UPat(Ops.INDEX, tuple(l2i_dt.keys()), name='x'), lambda x: reindex(x, x.tag).replace(dtype=l2i_dt[x.dtype])),
|
||||
(UPat(Ops.STORE, src=(UPat.var('idx'), UPat.var('val', tuple(l2i_dt.keys()))), name='st'), lambda st,idx,val:
|
||||
st.replace(src=(reindex(idx, 0), val.rtag(0))).group(st.replace(src=(reindex(idx, 1), val.rtag(1)))) if val.tag is None else None),
|
||||
(UPat(Ops.STORE, src=(UPat.var('idx'), UPat.var('val', tuple(l2i_dt.keys()))), allow_any_len=True, name='st'), lambda st,idx,val:
|
||||
st.replace(src=(reindex(idx, 0), val.rtag(0))+st.src[2:]).group(st.replace(src=(reindex(idx, 1), val.rtag(1)))) if val.tag is None else None),
|
||||
(UPat(GroupOp.Comparison, src=(UPat.var('a', tuple(l2i_dt.keys())), UPat.var('b', tuple(l2i_dt.keys()))), name="x"), lambda a,b,x:
|
||||
l2i(x.op, dt:=l2i_dt[a.dtype], a.rtag(0).cast(dt), a.rtag(1).cast(dt), b.rtag(0).cast(dt), b.rtag(1).cast(dt))),
|
||||
(UPat(Ops.CAST, tuple(l2i_dt.keys()), src=(UPat.var('a'),), name="x"), lambda a,x:
|
||||
|
|
@ -522,7 +524,10 @@ pm_long_decomp = PatternMatcher([
|
|||
(UPat((*(GroupOp.ALU - GroupOp.Comparison), Ops.BITCAST), tuple(l2i_dt.keys()), name="x"), lambda x:
|
||||
l2i(x.op, l2i_dt[x.dtype], *flatten((a.rtag(0).cast(dt:=l2i_dt[x.src[-1].dtype]), a.rtag(1).cast(dt))
|
||||
if a.dtype in l2i_dt else (a,) for a in x.src))[x.tag] if x.tag is not None else None),
|
||||
(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), name='x'), lambda x,idx: x.replace(dtype=l2i_dt[x.dtype],src=(reindex(idx, x.tag),))),
|
||||
(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'), UPat.var('alt'), UPat.var('gate')), name='x'), lambda x,idx,alt,gate:
|
||||
x.replace(dtype=l2i_dt[x.dtype], src=(reindex(idx, x.tag), alt.cast(l2i_dt[x.dtype]), gate))),
|
||||
(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), name='x'), lambda x,idx:
|
||||
x.replace(dtype=l2i_dt[x.dtype], src=(reindex(idx, x.tag),))),
|
||||
(UPat(Ops.CONST, tuple(l2i_dt.keys()), name='x'), lambda x:
|
||||
UOp.const(dt:=l2i_dt[x.dtype], truncate[dt]((x.arg >> 32) if x.tag == 1 else (x.arg & 0xFFFFFFFF))))
|
||||
])
|
||||
|
|
@ -546,9 +551,9 @@ pm_float_decomp = PatternMatcher([
|
|||
(UPat(GroupOp.All-{Ops.BITCAST}, dtypes.floats, name="x"), lambda ctx,x:
|
||||
x.replace(dtype=ctx[1].vec(x.dtype.count), src=tuple(s.cast(ctx[1]) if s.dtype == ctx[0] else s for s in x.src))
|
||||
if x.dtype.scalar() == ctx[0] else None),
|
||||
(UPat(Ops.STORE, src=(UPat.var("idx"), UPat(Ops.BITCAST, dtypes.floats, name="val")), name='st'), lambda ctx,st,idx,val:
|
||||
st.replace(src=(idx, val.replace(dtype=f2f_dt[ctx[0]]))) if val.dtype == ctx[0] and idx.tag == ctx[0] else None),
|
||||
(UPat(Ops.STORE, src=(UPat.var("idx"), UPat.var("val", dtypes.floats)), name='st'), lambda ctx,st,idx,val:
|
||||
(UPat(Ops.STORE, src=(UPat.var("idx"), UPat(Ops.BITCAST, dtypes.floats, name="val")), allow_any_len=True, name='st'), lambda ctx,st,idx,val:
|
||||
st.replace(src=(idx, val.replace(dtype=f2f_dt[ctx[0]]))+st.src[2:]) if val.dtype == ctx[0] and idx.tag == ctx[0] else None),
|
||||
(UPat(Ops.STORE, src=(UPat.var("idx"), UPat.var("val", dtypes.floats)), allow_any_len=True, name='st'), lambda ctx,st,idx,val:
|
||||
f2f_store(st, idx, val, *ctx) if val.dtype.scalar() == ctx[1] and (idx:=idx.src[0] if idx.op == Ops.CAST else idx).tag == ctx[0] else None),
|
||||
])
|
||||
|
||||
|
|
|
|||
|
|
@ -463,8 +463,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
i = (i,)
|
||||
return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
|
||||
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
|
||||
def store(self, src:UOp|ConstType, **kwargs):
|
||||
return UOp(Ops.STORE, dtypes.void, (self, self.const_like(src) if not isinstance(src, UOp) else src), **kwargs)
|
||||
def store(self, src:UOp|ConstType, gate:UOp|None=None, **kwargs):
|
||||
srcs = (self, self.const_like(src) if not isinstance(src, UOp) else src) + ((gate,) if gate is not None else ())
|
||||
return UOp(Ops.STORE, dtypes.void, srcs, **kwargs)
|
||||
def end(self, *src:UOp): return UOp(Ops.END, src=(self,)+src) if len(src) else self
|
||||
def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs) if len(src) else self
|
||||
def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)
|
||||
|
|
@ -1100,7 +1101,7 @@ class UPat(OpMixin):
|
|||
def __init__(self, op:Ops|tuple[Ops, ...]|set[Ops]|None=None, dtype:DType|tuple[DType, ...]|set[DType]|None=None,
|
||||
src:tuple[UPat, ...]|list[UPat]|UPat|None=None, arg:Any=None,
|
||||
name:str|None=None, allow_any_len:bool=False, custom_early_reject:set[Ops]|None=None, location=None, is_any:bool=False):
|
||||
assert op is None or isinstance(op, (Ops, tuple, set)), "op must be Ops or tuple of Ops"
|
||||
assert op is None or isinstance(op, (Ops, tuple, set)), f"op must be Ops or tuple of Ops, not {op!r}"
|
||||
self.op: tuple[Ops, ...]|None = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
|
||||
self.match_dtype: tuple[DType, ...]|None = (dtype,) if isinstance(dtype, DType) else (tuple(dtype) if isinstance(dtype, set) else dtype)
|
||||
self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
|
||||
|
|
@ -1543,12 +1544,6 @@ pm_lower_index_dtype = PatternMatcher([
|
|||
(UPat(Ops.DEFINE_VAR, dtype=dtypes.weakint, name="u"), lambda u: u.replace(dtype=dtypes.int).cast(dtypes.weakint)),
|
||||
(UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.weakint), UPat.cvar("val").cast(dtypes.weakint))),
|
||||
lambda var,val: var.bind(val).cast(dtypes.weakint)),
|
||||
# lower Invalid
|
||||
(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond, ptr=True)),
|
||||
# remove hanging casts
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))),
|
||||
lambda buf,idx,valid: buf.index(idx, valid, ptr=True)),
|
||||
(UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"),
|
||||
lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))),
|
||||
# vectorized indexes (ie. images) must be int
|
||||
|
|
|
|||
|
|
@ -174,10 +174,12 @@ shared_codegen_spec = PatternMatcher([
|
|||
(UPat(Ops.STACK, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
|
||||
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
|
||||
|
||||
# LOAD(idx) / STORE(idx, val)
|
||||
(UPat().index(UPat()).or_casted().load(), lambda: True),
|
||||
(UPat().index(UPat(), UPat(dtype=dtypes.bool)).or_casted().load(), lambda: True), # gated load (alt added in program_spec)
|
||||
(UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True),
|
||||
# LOAD(idx) / STORE(idx, val) with gates on the LOAD/STORE
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).or_casted().load(), validate_index),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).or_casted().load(UPat.var("alt"), UPat.var("gate", dtype=dtypes.bool), name="load"),
|
||||
lambda buf,idx,gate,alt,load: validate_index(buf, idx, gate) if alt.dtype == load.dtype else False),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).or_casted().store(UPat()), validate_index),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).or_casted().store(UPat(), UPat.var("gate", dtype=dtypes.bool)), validate_index),
|
||||
|
||||
# CUSTOM (inline and non inline)
|
||||
(UPat((Ops.CUSTOMI, Ops.CUSTOM)), lambda: True),
|
||||
|
|
@ -185,9 +187,8 @@ shared_codegen_spec = PatternMatcher([
|
|||
# assembly instruction
|
||||
(UPat(Ops.INS), lambda: True),
|
||||
|
||||
# INDEX (2-arg and 3-arg with bool gate)
|
||||
(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf").index(UPat.var("idx")), validate_index),
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf"), UPat.var("idx"), UPat.var("gate", dtype=dtypes.bool))), validate_index),
|
||||
# INDEX is just address calculation. OOB validation is on LOAD/STORE where the gate is available.
|
||||
(UPat(GroupOp.Defines|{Ops.AFTER}).index(UPat()), lambda: True),
|
||||
|
||||
# SPECIAL
|
||||
(UPat(Ops.SPECIAL, src=(UPat.var("x", (dtypes.weakint, dtypes.int32)),), name="s"), lambda s,x: s.dtype == x.dtype and isinstance(s.arg, str)),
|
||||
|
|
@ -236,9 +237,6 @@ tensor_spec = PatternMatcher([
|
|||
# ***** UOp spec in linearized programs *****
|
||||
|
||||
program_spec = PatternMatcher([
|
||||
# LOAD (idx, alt_value), LOAD can have an alt value, but only if the index has a gate
|
||||
(UPat().index(UPat(), UPat(dtype=dtypes.bool)).or_casted().load(UPat()), lambda: True),
|
||||
|
||||
# END closes ranges
|
||||
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
|
||||
|
||||
|
|
@ -293,7 +291,9 @@ full_spec = PatternMatcher([
|
|||
|
||||
# temp VECTORIZE/INDEX during rewrite have the wrong dtype
|
||||
(UPat(Ops.STACK), lambda: True),
|
||||
(UPat(Ops.INDEX), lambda: True),
|
||||
|
||||
# no more bool in index
|
||||
(UPat(Ops.INDEX, name="idx"), lambda idx: not any([dtypes.is_bool(x.dtype) for x in idx.src[1:]])),
|
||||
|
||||
# all loads/stores
|
||||
(UPat((Ops.LOAD, Ops.STORE)), lambda: True),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue