70 GFLOPS

This commit is contained in:
George Hotz 2025-03-21 20:31:14 +08:00
commit 264dd91b8a

View file

@ -60,10 +60,10 @@ def gep_on_reduce(gep, alu):
tuple(x.gep(gep.arg) if x.op is not Ops.RANGE else x for x in alu.src), alu.arg) if not isinstance(gep.dtype, PtrDType) and \
alu.dtype.count >= gep.dtype.count else None
def multi_add_int2(a0, a1, a2, a3):
arg = "__builtin_HEXAGON_A2_vraddub({0}, {1})"
print(a0.op, a1.op, a2.op, a3.op)
return None
def multi_add_int2(**aa):
r0 = UOp(Ops.VECTORIZE, dtypes.uchar.vec(8), tuple(x.src[0].gep(0) for x in aa.values()))
r1 = UOp(Ops.VECTORIZE, dtypes.uchar.vec(8), tuple(x.src[0].gep(1) for x in aa.values()))
return UOp(Ops.CUSTOMI, dtypes.int.vec(2), (r0.bitcast(dtypes.int64), r1.bitcast(dtypes.int64)), arg="__builtin_HEXAGON_A2_vraddub({0}, {1})")
dsp_pm = PatternMatcher([
# GEP on REDUCE
@ -90,7 +90,9 @@ dsp_pm = PatternMatcher([
(UPat(Ops.REDUCE, dtype=dtypes.int.vec(4), name="r"),
lambda r: UOp(Ops.CAT, r.dtype, (gep_on_reduce(r.gep((0,1)), r), gep_on_reduce(r.gep((2,3)), r)))),
(UPat(dtype=dtypes.int.vec(2), name="a0") + UPat(name="a1") + UPat(name="a2") + UPat(name="a3"), multi_add_int2),
# build __builtin_HEXAGON_A2_vraddub
(UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a0")+UPat(Ops.CAST,name="a1")+UPat(Ops.CAST,name="a2")+UPat(Ops.CAST,name="a3")+ \
UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a4")+UPat(Ops.CAST,name="a5")+UPat(Ops.CAST,name="a6")+UPat(Ops.CAST,name="a7"), multi_add_int2),
])+gep_pushing
def add_to_mul(c:UOp, x:UOp):
@ -116,6 +118,10 @@ dsp_pm_late = PatternMatcher([
(UPat(Ops.LOAD, dtype=(dtypes.uchar.vec(4), dtypes.uchar.vec(8)), name="ld"), prefetch_l1),
(UPat(Ops.CUSTOMI, dtype=dtypes.int.vec(32), name="c")+UPat.var("x"), add_to_mul),
# add acc to __builtin_HEXAGON_A2_vraddub (must be after the reduce expansion)
(UPat(Ops.CUSTOMI, name="cu", arg="__builtin_HEXAGON_A2_vraddub({0}, {1})") + UPat.var("x"),
lambda x,cu: cu.replace(dtype=dtypes.int64, src=(x.bitcast(dtypes.int64), cu.src[0], cu.src[1]), arg="__builtin_HEXAGON_A2_vraddub_acc({0}, {1}, {2})").bitcast(dtypes.int.vec(2))),
#(UPat(Ops.BITCAST, src=(UPat(Ops.LOAD, name="ld"),), name="bc"),
# lambda ld, bc: ld.src[0].src[0].cast(bc.dtype.ptr(ld.src[0].dtype.size)).load(dtype=bc.dtype)),
(UPat(Ops.GEP, name="x"), lambda x: UOp(Ops.CUSTOM, x.dtype, x.src+x.src,