This commit is contained in:
George Hotz 2025-03-21 20:36:18 +08:00
commit 8a477ba4e1

View file

@ -86,9 +86,12 @@ dsp_pm = PatternMatcher([
(UPat(name="acc") + UPat(dtype=dtypes.int.vec(32), name="a0")*UPat(name="a1") + UPat(name="b0")*UPat(name="b1") + \
UPat(name="c0")*UPat(name="c1") + UPat(name="d0")*UPat(name="d1"), multi_mul),
# REDUCE int4 -> 2xint2
# REDUCE int4 -> 2xint2, int8 -> 4xint2
(UPat(Ops.REDUCE, dtype=dtypes.int.vec(4), name="r"),
lambda r: UOp(Ops.CAT, r.dtype, (gep_on_reduce(r.gep((0,1)), r), gep_on_reduce(r.gep((2,3)), r)))),
(UPat(Ops.REDUCE, dtype=dtypes.int.vec(8), name="r"),
lambda r: UOp(Ops.CAT, r.dtype, (gep_on_reduce(r.gep((0,1)), r), gep_on_reduce(r.gep((2,3)), r),
gep_on_reduce(r.gep((4,5)), r), gep_on_reduce(r.gep((6,7)), r)))),
# build __builtin_HEXAGON_A2_vraddub
(UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a0")+UPat(Ops.CAST,name="a1")+UPat(Ops.CAST,name="a2")+UPat(Ops.CAST,name="a3")+ \