This commit is contained in:
George Hotz 2025-03-25 13:41:34 +08:00
commit ccd18a803c
3 changed files with 5 additions and 2 deletions

View file

@@ -132,6 +132,7 @@ if __name__ == "__main__":
#k.apply_opt(Opt(OptOps.PADTO, 4, 4))
k.apply_opt(Opt(OptOps.UNROLL, 1, 0))
k.apply_opt(Opt(OptOps.UPCAST, 2, 128))
k.apply_opt(Opt(OptOps.UPCAST, 1, 2))
elif knum == 5:
k.apply_opt(Opt(OptOps.UNROLL, 1, 0))
k.apply_opt(Opt(OptOps.UPCAST, 2, 0))

View file

@@ -196,7 +196,8 @@ gep_pushing = PatternMatcher([
symbolic = symbolic_simple+PatternMatcher([
# ** COMMUTATIVE flipping (only for ints) **
(UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
# NOTE: this can break merging vector math by only flipping some of them
#(UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
# ** boolean algebra **
(UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
# ** combine terms **

View file

@@ -163,7 +163,8 @@ def prefetch_l1(ld:UOp):
def vectorize_shuffle(x:UOp):
if not all(s.op in {Ops.GEP, Ops.CONST} for s in x.src): return None
gepped = dedup([s.src[0] for s in x.src if s.op is Ops.GEP])
if len(gepped) < 2: return None
if len(gepped) != 3: return None
if not all(x.dtype.scalar() is dtypes.uchar and x.dtype.count == 128 for x in gepped): return None
arg = []
for s in x.src:
if s.op is Ops.GEP: