This commit is contained in:
George Hotz 2025-03-28 16:49:36 +08:00
commit 9bbd12dc65
3 changed files with 7 additions and 4 deletions

View file

@ -264,6 +264,8 @@ load_store_indexing = PatternMatcher([
(UPat(Ops.AND, name="valid"), simplify_valid),
# image load valid idx simplification
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
# index True is just Index
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat(Ops.CONST, arg=True))), lambda buf,start_idx: buf.index(start_idx)),
# delete_redundant_gates (after expand)
(UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
UPat.var("val"))), delete_redundant_gates),
@ -278,7 +280,8 @@ pm_render = PatternMatcher([
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
(UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
# give any loads that are masked an alt value
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), name="x"), lambda x: x.replace(src=x.src+(x.const_like(0),))),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) if len(x.src) == 1 or x.src[1].op is Ops.CUSTOM else None),
# gate any stores that aren't gated with ifs
#(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store"),
# lambda store,idx: UOp(Ops.STORE, src=store.src+(UOp(Ops.IF, src=(idx.src[2],)),))),

View file

@ -45,7 +45,7 @@ base_rewrite = PatternMatcher([
# new load/store
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True),
lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("gate"))).or_casted('bidx'), UPat.var("var"))),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("gate"))).or_casted('bidx'), UPat.var("var")), allow_any_len=True),
lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
(UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"),
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),

View file

@ -188,8 +188,8 @@ def add_to_mul(c:UOp, x:UOp):
def prefetch_l1(ld:UOp, idx:UOp):
if ld.src[-1].op is Ops.CUSTOM: return None
ranges = sorted([x for x in ld.src[0].src[0].toposort if x.op is Ops.RANGE], key=lambda x: x.arg)
x1 = UOp(Ops.CUSTOM, dtypes.void, src=(idx.index(UOp.const(dtypes.int, ld.dtype.count*2)),), arg="__builtin_HEXAGON_Y2_dcfetch({0});")
x2 = UOp(Ops.CUSTOM, dtypes.void, src=(idx.substitute({ranges[-1]: ranges[-1].src[0]}),), arg="__builtin_HEXAGON_Y2_dcfetch({0});")
x1 = UOp(Ops.CUSTOM, dtypes.void, src=(idx.src[0], idx.src[1]+UOp.const(dtypes.int, ld.dtype.count*2),), arg="__builtin_HEXAGON_Y2_dcfetch({0}+{1});")
x2 = UOp(Ops.CUSTOM, dtypes.void, src=(idx.src[0], idx.src[1].substitute({ranges[-1]: ranges[-1].src[0]}),), arg="__builtin_HEXAGON_Y2_dcfetch({0}+{1});")
return ld.replace(src=ld.src+(x1, x2))
def vectorize_shuffle(vec:UOp):