Compare commits

...

36 commits

Author SHA1 Message Date
George Hotz
b9be9fbc77
Merge branch 'master' into move_gates_to_load_store 2026-05-06 10:06:31 -07:00
George Hotz
2ccefa11ec does this pass? 2026-05-06 09:30:58 -07:00
George Hotz
95b0a651c2 fix decomp 2026-05-05 15:46:42 -07:00
George Hotz
76606eb386 push 2026-05-05 15:34:05 -07:00
George Hotz
e74bf441f0 Revert "remove legacy stuff"
This reverts commit a0c04a5e35.
2026-05-05 15:33:43 -07:00
George Hotz
661eb76309 fix f2f 2026-05-05 15:24:37 -07:00
George Hotz
a0c04a5e35 remove legacy stuff 2026-05-05 15:14:20 -07:00
George Hotz
13e0fbaba6 fix webgpu and some edge cases 2026-05-05 15:07:28 -07:00
George Hotz
58a09b22ac
Merge branch 'master' into move_gates_to_load_store 2026-05-05 14:58:32 -07:00
George Hotz
d09ea1d620 fix nir 2026-05-05 14:56:34 -07:00
George Hotz
7a00223bd3 Fix webgpu 2026-05-05 14:50:53 -07:00
George Hotz
5053148502 nir fix 2026-05-05 14:48:49 -07:00
George Hotz
ecf49474eb cleanups + fix nir 2026-05-05 14:04:07 -07:00
George Hotz
396d3f441a
Merge branch 'master' into move_gates_to_load_store 2026-05-05 13:55:08 -07:00
George Hotz
6573c103f9 fix wrong load alt dtypes 2026-05-05 13:53:38 -07:00
George Hotz
fc2a289f61 fix nir 2026-05-04 20:27:25 -07:00
George Hotz
5736eee2f2 oops, inverted 2026-05-04 20:22:30 -07:00
George Hotz
651279c7ff
Merge branch 'master' into move_gates_to_load_store 2026-05-04 20:19:00 -07:00
George Hotz
0821bef6b4 fix gated load 2026-05-04 20:17:12 -07:00
George Hotz
437205ae03 flip order, this is simpler 2026-05-04 20:02:41 -07:00
George Hotz
cfefef479b add dtype 2026-05-04 19:56:53 -07:00
George Hotz
5d9431ecb9 fix ptx 2026-05-04 19:44:01 -07:00
George Hotz
0f3b12fcd8 fixes 2026-05-04 19:30:25 -07:00
George Hotz
60c8542320 work 2026-05-04 19:21:22 -07:00
George Hotz
4ec5487ad8 fix renderers 2026-05-04 19:04:49 -07:00
George Hotz
995a787d6c fix llvm crash 2026-05-04 18:52:22 -07:00
George Hotz
1b17762030 fix python 2026-05-04 18:44:53 -07:00
George Hotz
c0f443cf47
Merge branch 'master' into move_gates_to_load_store 2026-05-04 18:35:26 -07:00
George Hotz
ff1258feef fix tests 2026-05-04 17:33:01 -07:00
George Hotz
51b13466dd fix amd 2026-05-04 17:23:43 -07:00
George Hotz
416878db9e i hate ai 2026-05-04 17:14:53 -07:00
George Hotz
e00b3b4065 fix 2026-05-04 17:09:59 -07:00
George Hotz
d810bd2b41
Merge branch 'master' into move_gates_to_load_store 2026-05-04 17:05:21 -07:00
George Hotz
09ec34437d fix oob validation 2026-05-04 16:55:32 -07:00
George Hotz
36383298be move gates to load/store 2026-05-04 14:56:37 -07:00
George Hotz
8f397f5c7c move load gates 2026-05-04 14:45:42 -07:00
12 changed files with 104 additions and 87 deletions

View file

@ -16,6 +16,7 @@ from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcend
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images
from tinygrad.codegen.late.gater import pm_move_gates_from_index
from tinygrad.codegen.opt.postrange import apply_opts
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges
@ -76,8 +77,13 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
if DEVECTORIZE >= 0: sink = graph_rewrite(sink, pm_devectorize, ctx=ren, name="devectorize")
# lower the index dtype to a concrete int
# lower the index dtype to a concrete int. this needs to happen while gates are still present
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
# move the gates from index onto the loads and stores
sink = graph_rewrite(sink, pm_move_gates_from_index, name="move gates from index")
# a final symbolic
sink = graph_rewrite(sink, symbolic, name="post index symbolic")
# optional pre matcher
@ -106,8 +112,8 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
pm_linearize_cleanups = PatternMatcher([
# if statements are not allowed in the graph
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError, "if not allowed in graph")),
# gated INDEX becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat(name="gate", dtype=dtypes.bool))).or_casted(), UPat())),
# gated STORE becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX).or_casted(), UPat(), UPat(name="gate", dtype=dtypes.bool))),
lambda u, gate: (u, [mif:=UOp(Ops.IF, src=(gate, u.src[0])), u, UOp(Ops.ENDIF, src=(mif,))]))
])

View file

@ -37,6 +37,7 @@ def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
return drop_stmt
def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
start_idx = start_idx.simplify() # if you don't do this, uop_given_valid may simplify things and this might inf loop
idx = uop_given_valid(valid, start_idx)
if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx.valid(valid), ptr=True)
@ -281,18 +282,13 @@ pm_render = PatternMatcher([
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.STACK, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
(UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x),
# give any loads that are masked an alt value
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:])
if len(x.src) == 1 or x.src[1].op in (Ops.CUSTOM, Ops.STORE, Ops.BARRIER) else None),
# Where after gated load becomes alt value
# NOTE: if a is CAST and a.src[0].dtype == l.dtype, use a.src[0] to avoid roundtrip cast (e.g. uint->float->uint)
(UPat.var("c").where(UPat(Ops.LOAD, src=(UPat().index(UPat(), UPat.var("c")).or_casted(),), allow_any_len=True, name="l").or_casted(),
UPat.var("a")), lambda c,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype))+
l.src[2:]).cast(a.dtype)),
(UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat().index(UPat(), UPat.var("c", dtype=dtypes.bool).logical_not()).or_casted(),),
allow_any_len=True, name="l").or_casted()), lambda c,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype
else a.cast(l.dtype))+l.src[2:]).cast(a.dtype)),
(UPat.var("gate").where(UPat(Ops.LOAD, src=(UPat(), UPat(), UPat.var("gate")), name="l").or_casted(), UPat.var("a")), lambda gate,l,a:
l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype), l.src[2])).cast(a.dtype)),
(UPat.var("gate").where(UPat.var("a"), UPat(Ops.LOAD,
src=(UPat(), UPat(), UPat.var("gate", dtype=dtypes.bool).logical_not()), name="l").or_casted()), lambda gate,l,a:
l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype), l.src[2])).cast(a.dtype)),
])
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***
@ -368,7 +364,7 @@ pm_add_loads = PatternMatcher([
# add loads to non ptr index
(UPat(Ops.INDEX, name="idx"), add_load),
# remove loads from stores
(UPat(Ops.STORE, src=(UPat(Ops.LOAD), UPat(name="val")), name="s"), lambda s,val: s.replace(src=(s.src[0].src[0], val))),
(UPat(Ops.STORE, src=(UPat(Ops.LOAD),), allow_any_len=True, name="s"), lambda s: s.replace(src=(s.src[0].src[0],)+s.src[1:])),
])
# make images

View file

@ -0,0 +1,13 @@
# this transforms Invalid into gated load/stores
from tinygrad.uop.ops import PatternMatcher, UPat
from tinygrad.dtype import Invalid, dtypes
pm_move_gates_from_index = PatternMatcher([
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast").load(name="l"),
lambda buf,gate,idx,cast,l: buf.index(idx, ptr=True).cast(cast.dtype).load(l.const_like(0), gate, dtype=l.dtype)),
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast").store(UPat.var("data")),
lambda buf,gate,idx,cast,data: buf.index(idx, ptr=True).cast(cast.dtype).store(data, gate)),
# remove hanging weakint casts
(UPat.var("buf").index(UPat.var("idx", dtypes.ints).cast()), lambda buf,idx: buf.index(idx, ptr=True)),
])

View file

@ -44,12 +44,11 @@ base_rewrite = PatternMatcher([
# default const render
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
# new load/store
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("gate"))).or_casted("bidx"), UPat.var("var"))),
lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
(UPat(Ops.LOAD, src=(UPat.var('bidx'),)), lambda ctx,bidx: f"(*{ctx[bidx]})"),
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var"))), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
(UPat(Ops.LOAD, src=(UPat.var("bidx"), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
# alu/gep
# TODO: look for left-associative
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
@ -302,12 +301,12 @@ class OpenCLRenderer(CStyleLanguage):
(UPat(Ops.CONST, dtypes.bfloat16, name="x"),
lambda ctx,x: f"{(struct.unpack('I', struct.pack('f', float_to_bf16(x.arg)))[0] >> 16)}u"),
# load/store image (OpenCL)
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2)), UPat.var("gate")), UPat.var("var"))),
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var"), UPat.var("gate"))),
lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
lambda ctx,buf,idx: f"read_imagef({ctx[buf]}, smp, {ctx[idx]})"),
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2)), allow_any_len=True),
UPat.var("var", dtypes.float.vec(4)))),
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),
UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
lambda ctx,buf,idx,var: f"write_imagef({ctx[buf]}, {ctx[idx]}, {ctx[var]});"),
]) + base_rewrite

View file

@ -76,14 +76,14 @@ base_rewrite = PatternMatcher([
# memory load/store
(UPat(Ops.INDEX, name="x"), lambda ctx,x:
f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("mask"))).or_casted("idx"), UPat.var("alt")), allow_any_len=True, name="x"),
(UPat(Ops.LOAD, src=(UPat.var("idx"), UPat.var("alt"), UPat.var("mask")), name="x"),
lambda ctx,x,idx,alt,mask:
f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
f" br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n"
f" {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n"
f" br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n"
f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
(UPat(Ops.LOAD, src=(UPat.var('idx'),), allow_any_len=True, name="x"),
(UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"),
lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
(UPat(Ops.STORE, name="x"), lambda ctx,x: f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),

View file

@ -128,8 +128,8 @@ class NIRRenderer(Renderer):
# load/store bool -> uint8
(UPat(Ops.LOAD, dtypes.bool, name="x"),
lambda x: x.replace(dtype=dtypes.uint8, src=x.src[0:1]+((x.src[1].cast(dtypes.uint8),) if len(x.src)>=2 else ())+x.src[2:]).cast(dtypes.bool)),
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dtypes.bool)), name="x"),
lambda x: x.replace(src=(x.src[0], x.src[1].cast(dtypes.uint8)))),
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
lambda x: x.replace(src=(x.src[0], x.src[1].cast(dtypes.uint8))+x.src[2:])),
# NIR requires shift amount to be 32 bit: https://docs.mesa3d.org/nir/alu.html#nir-alu-op-ishl
(UPat((Ops.SHL, Ops.SHR), name="x"), lambda x: x.replace(src=(x.src[0], x.src[1].cast(dtypes.uint))) if x.src[1].dtype.bitsize != 32 else None),
# OpConvertFToU is undefined if Result Type is not wide enough, cast through int32
@ -146,12 +146,12 @@ class NIRRenderer(Renderer):
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 8)),
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 4)),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off")), allow_any_len=True), UPat.var("val"))),
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off"))), UPat.var("val")), allow_any_len=True),
lambda ctx,buf,off,val: nstore(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"), UPat.var("gate"))), UPat.var("alt")), allow_any_len=True, name="x"),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))), UPat.var("alt"), UPat.var("gate")), name="x"),
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
lambda: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x.dtype), lambda: ctx.r[alt])),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))),), allow_any_len=True, name="x"),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))),), name="x"),
lambda ctx,x,buf,off: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), x.dtype)),
(UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.dtype.count}", *[ctx.r[src] for src in x.src])),
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype.scalar()][x.op], *[ctx.r[src] for src in x.src])),
@ -268,9 +268,9 @@ class IR3Renderer(NIRRenderer, OpenCLRenderer):
return _nload_img(ctx.b, ctx.r[img], ctx.r[coord], img.dtype)
def_rewrite = PatternMatcher([
(UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2)), allow_any_len=True), UPat.var("val"))),
(UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))), UPat.var("val")), allow_any_len=True),
lambda ctx,img,coord,val: nstore_img(ctx.b, ctx.r[img], ctx.r[coord], ctx.r[val], val.dtype)),
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2)), UPat.var("gate")), UPat.var("alt"))),
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))), UPat.var("alt"), UPat.var("gate"))),
lambda ctx,img,coord,alt,gate: if_phi(ctx.b, ctx.r[gate], lambda: ctx.nload_img(img, coord), lambda: ctx.r[alt])),
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))),)), nload_img),
]) + NIRRenderer.def_rewrite

View file

@ -48,10 +48,10 @@ ptx_matcher = PatternMatcher([
# load/store bool -> uint8
(UPat(Ops.LOAD, dtypes.bool, src=(UPat(dtype=dtypes.int64),), name="x", allow_any_len=True),
lambda x: UOp(x.op, dtypes.uint8, x.src[0:1] + ((x.src[1].cast(dtypes.uint8),) if len(x.src) >= 2 else ()) + x.src[2:]).cast(dtypes.bool)),
(UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x"),
lambda x: UOp(x.op, dtypes.void, (x.src[0], x.src[1].cast(dtypes.uint8)))),
(UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
lambda x: UOp(x.op, dtypes.void, (x.src[0], x.src[1].cast(dtypes.uint8))+x.src[2:])),
# indexing on PTX is in uint64, we do the math while it's still in the graph
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx")), name="op", allow_any_len=True), lambda buf,idx,op:
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx")), name="op"), lambda buf,idx,op:
UOp(Ops.INDEX, dtype=dtypes.int64, src=(buf, buf.cast(dtypes.int64)+idx.cast(dtypes.int64)*buf.dtype.itemsize)+op.src[2:]) \
if op.dtype != dtypes.int64 and buf.dtype.addrspace != AddrSpace.REG else None),
# load/store use pointer arithmetic, and the cast does nothing
@ -102,11 +102,11 @@ string_rewrite = PatternMatcher([
(UPat(Ops.CAST, name="x", src=(UPat.var("a"),)),
lambda ctx, x, a: f"cvt{modifier(x.dtype, a.dtype)}.{ctx.cast_types[x.dtype]}.{ctx.cast_types[a.dtype]} {ctx.r[x]}, {ctx.r[a]};"),
# store / gated load / load
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc")), allow_any_len=True), UPat.var("var"))),
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))), UPat.var("var")), allow_any_len=True),
lambda ctx, loc, var, buf: f"st.{mem_type(buf)}" + \
f"{f'.v{cnt}' if ((cnt:=var.dtype.count)>1) else ''}.{ctx.mem_types[var.dtype.scalar()]} " + \
f"[{ctx.r[loc]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.dtype.count > 1 else ctx.r[var]};"),
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"), UPat.var("gate"))), UPat.var("alt")), allow_any_len=True),
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))), UPat.var("alt"), UPat.var("gate")), allow_any_len=True),
lambda ctx, x, loc, alt, gate, buf: flatten([
[f"mov.{ctx.mem_types[x.dtype.scalar()]} {v}, {render_val(0, x.dtype.scalar())};" for v in ctx.r[x]],
[f"@{ctx.r[gate]} ld.{mem_type(buf)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"]

View file

@ -10,21 +10,20 @@ def sign_extend(val:UOp, sext_am:int):
| val.bitcast(dtypes.uint32)).bitcast(dtypes.int)
# store for char: buf[idx/4] <- (var << (idx%4)*8))
def packed_store(bidx:UOp, var:UOp):
def packed_store(bidx:UOp, var:UOp, gate:UOp|None=None):
elems, mask = 4//var.dtype.itemsize, _mask(var.dtype)
shift_am, div_idx = (bidx.src[1].cast(dtypes.uint32) % elems) * (8*var.dtype.itemsize), bidx.src[1] // elems
new_v, wmask = (var & mask).cast(dtypes.uint32) << shift_am, ((mask << shift_am) ^ 0xFFFFFFFF).cast(dtypes.uint32)
# preserve valid condition (bidx.src[2]) if it exists for gated stores
idx_src = (bidx.src[0], div_idx) if len(bidx.src) == 2 else (bidx.src[0], div_idx, bidx.src[2])
buf = UOp.load(UOp(Ops.INDEX, bidx.dtype, idx_src), dtype=dtypes.uint32)
return UOp.store(UOp(Ops.INDEX, bidx.dtype, idx_src), (buf & wmask) | new_v)
idx = UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx))
buf = UOp.load(idx, *((UOp.const(dtypes.uint32, 0), gate) if gate is not None else ()), dtype=dtypes.uint32)
return UOp.store(idx, (buf & wmask) | new_v, *((gate,) if gate is not None else ()))
# load for char: sign_extend(buf[idx/4] >> ((idx%4)*8))
def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None):
def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None, gate:UOp|None=None):
elems, mask = 4//dtype.itemsize, _mask(dtype)
shift_am, div_idx = (bidx.src[1].cast(dtypes.uint32) % elems) * (8*dtype.itemsize), bidx.src[1] // elems
idx = UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx, bidx.src[2]) if var is not None else (bidx.src[0], div_idx))
load = UOp.load(idx, *([var] if var is not None else root.src[1:]), dtype=dtypes.uint32, arg=root.arg)
idx = UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx))
load = UOp.load(idx, *((var, gate) if var is not None and gate is not None else root.src[1:]), dtype=dtypes.uint32, arg=root.arg)
val = (load.cast(dtypes.uint32) >> shift_am) & mask
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
@ -41,10 +40,12 @@ wgsl_matcher = PatternMatcher([
(UPat((Ops.CMPLT, Ops.XOR), src=(UPat(name="a", dtype=dtypes.bool), UPat.var("b")), name="c"),
lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
# TODO: load alt value doesnt have to be a const
(UPat.load(UPat.var("b"), UPat.cvar("c"), allow_any_len=True, name="l"),
lambda l,b,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if is_packed(l.dtype, b.dtype) else None),
(UPat.load(UPat.var("b"), name='l', allow_any_len=True), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype, b.dtype) else None),
(UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True),
(UPat.load(UPat.var("b"), UPat.cvar("c"), UPat.var("gate"), name="l"),
lambda l,b,c,gate: packed_load(l,b,l.dtype,c.cast(dtypes.uint32),gate) if is_packed(l.dtype, b.dtype) else None),
(UPat.load(UPat.var("b"), name='l'), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype, b.dtype) else None),
(UPat.store(UPat.var("bidx"), UPat.var("var"), UPat.var("gate")),
lambda bidx,var,gate: packed_store(bidx,var,gate) if is_packed(var.dtype, bidx.dtype) else None),
(UPat.store(UPat.var("bidx"), UPat.var("var")),
lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype, bidx.dtype) else None),
(UPat.var("a") << UPat.var("b"),lambda a,b:(a.bitcast(dtypes.uint32)<<b.cast(dtypes.uint32)).bitcast(a.dtype) if b.dtype!=dtypes.uint32 else None),
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
@ -82,14 +83,14 @@ class WGSLRenderer(CStyleLanguage):
if x.src[0].dtype == dtypes.half else f"((i32({ctx[x.src[0]]}&0xFFFF)<<16)>>16)"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"),
# TODO: load alt value doesnt have to be a const
(UPat.load(UPat.var("b"), UPat.cvar("v"), allow_any_len=True),
lambda ctx,b,v: f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[b.src[2]]})"),
(UPat.load(UPat.var("b"), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.dtype)),
(UPat.load(UPat.var("b"), UPat.cvar("v"), UPat.var("gate")),
lambda ctx,b,v,gate: f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[gate]})"),
(UPat.load(UPat.var("b")), lambda ctx, b: ctx.render_load(ctx[b], b.dtype)),
(UPat.store(UPat.var("b"), UPat.var("v"), allow_any_len=True),lambda ctx,b,v:\
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
else f"{ctx[b]} = {ctx[v]};"),
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx")), allow_any_len=True),
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"))),
lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
]) + base_rewrite

View file

@ -18,8 +18,8 @@ def _load(m, i, dtype: DType):
return from_storage_scalar(m[i], dtype)
def load(inp, j, dtype: DType):
if len(inp) == 2: return [_load(m, x+j if x is not None else None, dtype) if gate else default for (m,x,gate),default in zip(*inp)]
return [_load(m, x+j if x is not None else None, dtype) for m,x,_ in inp[0]]
if len(inp) >= 3: return [_load(m, x+j if x is not None else None, dtype) if gate else default for (m,x),default,gate in zip(*inp[:3])]
return [_load(m, x+j if x is not None else None, dtype) for m,x in inp[0]]
def _store(m, i, v, dtype: DType):
if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
@ -67,8 +67,9 @@ class PythonProgram:
continue
assert dtype is not None, f"{uop} is missing a dtype"
if uop is Ops.STORE:
store_gate = src_values[2] if len(src_values) >= 3 else [True] * warp_size
for j,val in enumerate(src_values[1] if src_dtypes[1].count > 1 else [src_values[1]]):
for (m,o,g),v in zip(src_values[0], val):
for (m,o),v,g in zip(src_values[0], val, store_gate):
if g: _store(m, o+j, v, src_dtypes[1].scalar())
i += 1
continue
@ -91,6 +92,7 @@ class PythonProgram:
elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp]
elif uop is Ops.CONST: values[i] = [arg] * warp_size
elif uop is Ops.INDEX:
if len(src_values) != 2: raise RuntimeError("gates must be on LOAD/STORE, not INDEX")
ret:list = []
if isinstance(src_dtypes[0], ImageDType):
for m,ox,oy in zip(src_values[0], src_values[1][0], src_values[1][1]):
@ -98,7 +100,7 @@ class PythonProgram:
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
else:
for m,o in zip(src_values[0], src_values[1]): ret.append((m,o))
values[i] = [(m,o,g) for (m,o),g in zip(ret, src_values[2] if len(src_values) == 3 else [True]*len(ret))] # set the gate last
values[i] = ret
elif uop is Ops.CAST and isinstance(dtype, PtrDType):
values[i] = src_values[0]
elif uop is Ops.RANGE:

View file

@ -418,13 +418,15 @@ def f2f_clamp(val:UOp, dt:DType) -> UOp:
# FIXME: CMPLT of nan is undefined
return val.ne(val).where(val, (val < -mx).where(-sat, (mx < val).where(sat, val)))
def f2f_load(x: UOp, fr:DType, to:DType) -> UOp:
if (n:=x.dtype.count) == 1: return f2f(x.replace(dtype=f2f_dt[fr]), fr, to)
return UOp.vectorize(*(f2f(x.replace(dtype=f2f_dt[fr], src=(reindex(x.src[0].src[0], i, 1),)), fr, to) for i in range(n)))
def f2f_load(x:UOp, fr:DType, to:DType) -> UOp:
if (n:=x.dtype.count) == 1:
return f2f(x.replace(src=(x.src[0],)+((x.src[1].cast(f2f_dt[fr]), x.src[2]) if len(x.src) >= 3 else ()), dtype=f2f_dt[fr]), fr, to)
return UOp.vectorize(*(f2f(x.replace(dtype=f2f_dt[fr], src=(reindex(x.src[0].src[0], i, 1),) + \
((x.src[1].gep(i).cast(f2f_dt[fr]), x.src[2]) if len(x.src) >= 3 else ())), fr, to) for i in range(n)))
def f2f_store(st, idx, val, fr:DType, to:DType):
if (n:=val.dtype.count) == 1: return st.replace(src=(idx, f2f(val.bitcast(f2f_dt[to]), to, fr)))
return UOp.group(*(st.replace(src=(reindex(idx, i, 1), f2f(val.gep(i).bitcast(f2f_dt[to]), to, fr))) for i in range(n)))
def f2f_store(st:UOp, idx, val, fr:DType, to:DType):
if (n:=val.dtype.count) == 1: return st.replace(src=(idx, f2f(val.bitcast(f2f_dt[to]), to, fr))+st.src[2:])
return UOp.group(*(st.replace(src=(reindex(idx, i, 1), f2f(val.gep(i).bitcast(f2f_dt[to]), to, fr))+st.src[2:]) for i in range(n)))
# ***** decomposition patterns *****
@ -509,8 +511,8 @@ pm_long_decomp = PatternMatcher([
(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x:
x.replace(dtype=l2i_dt[x.dtype.base].ptr(x.dtype.size * 2)) if hasattr(x.dtype, 'size') and x.dtype.base in l2i_dt else None),
(UPat(Ops.INDEX, tuple(l2i_dt.keys()), name='x'), lambda x: reindex(x, x.tag).replace(dtype=l2i_dt[x.dtype])),
(UPat(Ops.STORE, src=(UPat.var('idx'), UPat.var('val', tuple(l2i_dt.keys()))), name='st'), lambda st,idx,val:
st.replace(src=(reindex(idx, 0), val.rtag(0))).group(st.replace(src=(reindex(idx, 1), val.rtag(1)))) if val.tag is None else None),
(UPat(Ops.STORE, src=(UPat.var('idx'), UPat.var('val', tuple(l2i_dt.keys()))), allow_any_len=True, name='st'), lambda st,idx,val:
st.replace(src=(reindex(idx, 0), val.rtag(0))+st.src[2:]).group(st.replace(src=(reindex(idx, 1), val.rtag(1)))) if val.tag is None else None),
(UPat(GroupOp.Comparison, src=(UPat.var('a', tuple(l2i_dt.keys())), UPat.var('b', tuple(l2i_dt.keys()))), name="x"), lambda a,b,x:
l2i(x.op, dt:=l2i_dt[a.dtype], a.rtag(0).cast(dt), a.rtag(1).cast(dt), b.rtag(0).cast(dt), b.rtag(1).cast(dt))),
(UPat(Ops.CAST, tuple(l2i_dt.keys()), src=(UPat.var('a'),), name="x"), lambda a,x:
@ -522,7 +524,10 @@ pm_long_decomp = PatternMatcher([
(UPat((*(GroupOp.ALU - GroupOp.Comparison), Ops.BITCAST), tuple(l2i_dt.keys()), name="x"), lambda x:
l2i(x.op, l2i_dt[x.dtype], *flatten((a.rtag(0).cast(dt:=l2i_dt[x.src[-1].dtype]), a.rtag(1).cast(dt))
if a.dtype in l2i_dt else (a,) for a in x.src))[x.tag] if x.tag is not None else None),
(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), name='x'), lambda x,idx: x.replace(dtype=l2i_dt[x.dtype],src=(reindex(idx, x.tag),))),
(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'), UPat.var('alt'), UPat.var('gate')), name='x'), lambda x,idx,alt,gate:
x.replace(dtype=l2i_dt[x.dtype], src=(reindex(idx, x.tag), alt.cast(l2i_dt[x.dtype]), gate))),
(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), name='x'), lambda x,idx:
x.replace(dtype=l2i_dt[x.dtype], src=(reindex(idx, x.tag),))),
(UPat(Ops.CONST, tuple(l2i_dt.keys()), name='x'), lambda x:
UOp.const(dt:=l2i_dt[x.dtype], truncate[dt]((x.arg >> 32) if x.tag == 1 else (x.arg & 0xFFFFFFFF))))
])
@ -546,9 +551,9 @@ pm_float_decomp = PatternMatcher([
(UPat(GroupOp.All-{Ops.BITCAST}, dtypes.floats, name="x"), lambda ctx,x:
x.replace(dtype=ctx[1].vec(x.dtype.count), src=tuple(s.cast(ctx[1]) if s.dtype == ctx[0] else s for s in x.src))
if x.dtype.scalar() == ctx[0] else None),
(UPat(Ops.STORE, src=(UPat.var("idx"), UPat(Ops.BITCAST, dtypes.floats, name="val")), name='st'), lambda ctx,st,idx,val:
st.replace(src=(idx, val.replace(dtype=f2f_dt[ctx[0]]))) if val.dtype == ctx[0] and idx.tag == ctx[0] else None),
(UPat(Ops.STORE, src=(UPat.var("idx"), UPat.var("val", dtypes.floats)), name='st'), lambda ctx,st,idx,val:
(UPat(Ops.STORE, src=(UPat.var("idx"), UPat(Ops.BITCAST, dtypes.floats, name="val")), allow_any_len=True, name='st'), lambda ctx,st,idx,val:
st.replace(src=(idx, val.replace(dtype=f2f_dt[ctx[0]]))+st.src[2:]) if val.dtype == ctx[0] and idx.tag == ctx[0] else None),
(UPat(Ops.STORE, src=(UPat.var("idx"), UPat.var("val", dtypes.floats)), allow_any_len=True, name='st'), lambda ctx,st,idx,val:
f2f_store(st, idx, val, *ctx) if val.dtype.scalar() == ctx[1] and (idx:=idx.src[0] if idx.op == Ops.CAST else idx).tag == ctx[0] else None),
])

View file

@ -463,8 +463,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
i = (i,)
return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
def store(self, src:UOp|ConstType, **kwargs):
return UOp(Ops.STORE, dtypes.void, (self, self.const_like(src) if not isinstance(src, UOp) else src), **kwargs)
def store(self, src:UOp|ConstType, gate:UOp|None=None, **kwargs):
srcs = (self, self.const_like(src) if not isinstance(src, UOp) else src) + ((gate,) if gate is not None else ())
return UOp(Ops.STORE, dtypes.void, srcs, **kwargs)
def end(self, *src:UOp): return UOp(Ops.END, src=(self,)+src) if len(src) else self
def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs) if len(src) else self
def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)
@ -1100,7 +1101,7 @@ class UPat(OpMixin):
def __init__(self, op:Ops|tuple[Ops, ...]|set[Ops]|None=None, dtype:DType|tuple[DType, ...]|set[DType]|None=None,
src:tuple[UPat, ...]|list[UPat]|UPat|None=None, arg:Any=None,
name:str|None=None, allow_any_len:bool=False, custom_early_reject:set[Ops]|None=None, location=None, is_any:bool=False):
assert op is None or isinstance(op, (Ops, tuple, set)), "op must be Ops or tuple of Ops"
assert op is None or isinstance(op, (Ops, tuple, set)), f"op must be Ops or tuple of Ops, not {op!r}"
self.op: tuple[Ops, ...]|None = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
self.match_dtype: tuple[DType, ...]|None = (dtype,) if isinstance(dtype, DType) else (tuple(dtype) if isinstance(dtype, set) else dtype)
self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
@ -1543,12 +1544,6 @@ pm_lower_index_dtype = PatternMatcher([
(UPat(Ops.DEFINE_VAR, dtype=dtypes.weakint, name="u"), lambda u: u.replace(dtype=dtypes.int).cast(dtypes.weakint)),
(UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.weakint), UPat.cvar("val").cast(dtypes.weakint))),
lambda var,val: var.bind(val).cast(dtypes.weakint)),
# lower Invalid
(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond, ptr=True)),
# remove hanging casts
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))),
lambda buf,idx,valid: buf.index(idx, valid, ptr=True)),
(UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"),
lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))),
# vectorized indexes (ie. images) must be int

View file

@ -174,10 +174,12 @@ shared_codegen_spec = PatternMatcher([
(UPat(Ops.STACK, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
# LOAD(idx) / STORE(idx, val)
(UPat().index(UPat()).or_casted().load(), lambda: True),
(UPat().index(UPat(), UPat(dtype=dtypes.bool)).or_casted().load(), lambda: True), # gated load (alt added in program_spec)
(UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True),
# LOAD(idx) / STORE(idx, val) with gates on the LOAD/STORE
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).or_casted().load(), validate_index),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).or_casted().load(UPat.var("alt"), UPat.var("gate", dtype=dtypes.bool), name="load"),
lambda buf,idx,gate,alt,load: validate_index(buf, idx, gate) if alt.dtype == load.dtype else False),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).or_casted().store(UPat()), validate_index),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).or_casted().store(UPat(), UPat.var("gate", dtype=dtypes.bool)), validate_index),
# CUSTOM (inline and non inline)
(UPat((Ops.CUSTOMI, Ops.CUSTOM)), lambda: True),
@ -185,9 +187,8 @@ shared_codegen_spec = PatternMatcher([
# assembly instruction
(UPat(Ops.INS), lambda: True),
# INDEX (2-arg and 3-arg with bool gate)
(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf").index(UPat.var("idx")), validate_index),
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf"), UPat.var("idx"), UPat.var("gate", dtype=dtypes.bool))), validate_index),
# INDEX is just address calculation. OOB validation is on LOAD/STORE where the gate is available.
(UPat(GroupOp.Defines|{Ops.AFTER}).index(UPat()), lambda: True),
# SPECIAL
(UPat(Ops.SPECIAL, src=(UPat.var("x", (dtypes.weakint, dtypes.int32)),), name="s"), lambda s,x: s.dtype == x.dtype and isinstance(s.arg, str)),
@ -236,9 +237,6 @@ tensor_spec = PatternMatcher([
# ***** UOp spec in linearized programs *****
program_spec = PatternMatcher([
# LOAD (idx, alt_value), LOAD can have an alt value, but only if the index has a gate
(UPat().index(UPat(), UPat(dtype=dtypes.bool)).or_casted().load(UPat()), lambda: True),
# END closes ranges
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
@ -293,7 +291,9 @@ full_spec = PatternMatcher([
# temp VECTORIZE/INDEX during rewrite have the wrong dtype
(UPat(Ops.STACK), lambda: True),
(UPat(Ops.INDEX), lambda: True),
# no more bool in index
(UPat(Ops.INDEX, name="idx"), lambda idx: not any([dtypes.is_bool(x.dtype) for x in idx.src[1:]])),
# all loads/stores
(UPat((Ops.LOAD, Ops.STORE)), lambda: True),