This commit is contained in:
George Hotz 2025-03-20 23:27:06 +08:00
commit 61c02ca634

View file

@ -36,13 +36,13 @@ def multi_mul(a0, a1, b0, b1, c0, c1, d0, d1, acc=None):
if simp_m1.op is Ops.GEP and simp_m1.arg == simp_m1.arg[0:4]*32:
scalar_m1 = simp_m1.src[0].gep(simp_m1.arg[0:4]).bitcast(dtypes.uint)
if acc is not None:
return UOp(Ops.CUSTOM, dtypes.int.vec(32), (acc, m0, scalar_m1), "__builtin_HEXAGON_V6_vrmpybus_acc_128B({0}, {1}, {2})")
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (acc, m0, scalar_m1), "__builtin_HEXAGON_V6_vrmpybus_acc_128B({0}, {1}, {2})")
else:
return UOp(Ops.CUSTOM, dtypes.int.vec(32), (m0, scalar_m1), "__builtin_HEXAGON_V6_vrmpybus_128B({0}, {1})")
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (m0, scalar_m1), "__builtin_HEXAGON_V6_vrmpybus_128B({0}, {1})")
if acc is not None:
return UOp(Ops.CUSTOM, dtypes.int.vec(32), (acc, m0, m1), "__builtin_HEXAGON_V6_vrmpybusv_acc_128B({0}, {1}, {2})")
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (acc, m0, m1), "__builtin_HEXAGON_V6_vrmpybusv_acc_128B({0}, {1}, {2})")
else:
return UOp(Ops.CUSTOM, dtypes.int.vec(32), (m0, m1), "__builtin_HEXAGON_V6_vrmpybusv_128B({0}, {1})")
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (m0, m1), "__builtin_HEXAGON_V6_vrmpybusv_128B({0}, {1})")
dsp_pm = gep_pushing+PatternMatcher([
# GEP on REDUCE
@ -75,17 +75,17 @@ def add_to_mul(c:UOp, x:UOp):
elif c.arg.startswith("__builtin_HEXAGON_V6_vrmpybusv_128B"):
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (x, c.src[0], c.src[1]), "__builtin_HEXAGON_V6_vrmpybusv_acc_128B({0}, {1}, {2})")
elif 'acc' in c.arg and x.op is not Ops.CUSTOM:
return c.src[0] + c.replace(src=(x, c.src[1], c.src[2]))
return c.replace(src=(x+c.src[0], c.src[1], c.src[2]))
else:
return None
dsp_pm_late = PatternMatcher([
# prefetch L1
(UPat(Ops.LOAD, dtype=dtypes.uchar.vec(8), name="ld"),
lambda ld: ld.replace(src=ld.src+(UOp(Ops.CUSTOM, dtypes.void, src=(ld.src[0].src[0].src[0].index(ld.src[0].src[0].src[1]+8),),
lambda ld: ld.replace(src=ld.src+(UOp(Ops.CUSTOM, dtypes.void, src=(ld.src[0].src[0].src[0].index(ld.src[0].src[0].src[1]+16),),
arg="__builtin_HEXAGON_Y2_dcfetch({0});"),)) if ld.src[-1].op is not Ops.CUSTOM else None),
(UPat(Ops.CUSTOM, dtype=dtypes.int.vec(32), name="c")+UPat.var("x"), add_to_mul),
(UPat(Ops.CUSTOMI, dtype=dtypes.int.vec(32), name="c")+UPat.var("x"), add_to_mul),
#(UPat(Ops.BITCAST, src=(UPat(Ops.LOAD, name="ld"),), name="bc"),
# lambda ld, bc: ld.src[0].src[0].cast(bc.dtype.ptr(ld.src[0].dtype.size)).load(dtype=bc.dtype)),