test ops passes

This commit is contained in:
George Hotz 2026-06-14 12:58:18 -07:00
commit bdfcb1cb98

View file

@ -52,7 +52,8 @@ pm_remove_vec_dtypes = PatternMatcher([
def maybe_load(u:UOp): return u.load() if u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL, AddrSpace.REG) else u
pm_move_regs = PatternMatcher([
(UPat(GroupOp.ALU, name="x"), lambda x: x.replace(src=tuple([maybe_load(u) for u in x.src]))),
# BITCAST?
(UPat(GroupOp.Elementwise, name="x"), lambda x: x.replace(src=tuple([maybe_load(u) for u in x.src]))),
(UPat(Ops.STORE, name="x"), lambda x: x.replace(src=(x.src[0], maybe_load(x.src[1]))+x.src[2:])),
])
@ -86,8 +87,8 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# remove all weakints
sink = graph_rewrite(sink, pm_lower_weakints, name="lower weakints", bottom_up=True)
# symbolic
sink = graph_rewrite(sink, symbolic, name="post index symbolic")
# symbolic (note: this does POW decomp)
sink = graph_rewrite(sink, sym, name="post index symbolic")
# decompositions
supported_ops = tuple(ren.code_for_op.keys())