correctness

This commit is contained in:
George Hotz 2025-04-01 17:01:46 +08:00
commit 661431ee75

View file

@ -106,14 +106,39 @@ def multi_add_int2(**aa):
del aa['acc']
else:
acc = None
r0 = UOp(Ops.VECTORIZE, dtypes.uchar.vec(8), tuple(x.src[0].gep(0) for x in aa.values()))
r1 = UOp(Ops.VECTORIZE, dtypes.uchar.vec(8), tuple(x.src[0].gep(1) for x in aa.values()))
eles = []
for k in sorted(aa.keys()): eles.append(aa[k].src[0].gep(0))
for k in sorted(aa.keys()): eles.append(aa[k].src[0].gep(1))
r0 = UOp(Ops.VECTORIZE, dtypes.uchar.vec(8), tuple(eles[0:4]+eles[8:12]))
r1 = UOp(Ops.VECTORIZE, dtypes.uchar.vec(8), tuple(eles[4:8]+eles[12:16]))
"""
unsigned long long precast0 = __builtin_HEXAGON_A2_vraddub((*((unsigned long long*)&val1)), (*((unsigned long long*)&val2)));
acc4 = (acc4+(*((int2*)&precast0)));
int2 cast2 = __builtin_convertvector((unsigned_char2){val1[0],val2[0]}, int2);
int2 cast4 = __builtin_convertvector((unsigned_char2){val1[1],val2[1]}, int2);
int2 cast6 = __builtin_convertvector((unsigned_char2){val1[2],val2[2]}, int2);
int2 cast8 = __builtin_convertvector((unsigned_char2){val1[3],val2[3]}, int2);
int2 cast10 = __builtin_convertvector((unsigned_char2){val1[4],val2[4]}, int2);
int2 cast12 = __builtin_convertvector((unsigned_char2){val1[5],val2[5]}, int2);
int2 cast14 = __builtin_convertvector((unsigned_char2){val1[6],val2[6]}, int2);
int2 cast16 = __builtin_convertvector((unsigned_char2){val1[7],val2[7]}, int2);
acc4 = (acc4+cast2+cast4+cast6+cast8+cast10+cast12+cast14+cast16);
"""
#r0 = UOp(Ops.VECTORIZE, dtypes.uchar.vec(8), tuple(x.src[0].gep(0) for x in aa.values()))
#r1 = UOp(Ops.VECTORIZE, dtypes.uchar.vec(8), tuple(x.src[0].gep(1) for x in aa.values()))
if acc is not None:
return UOp(Ops.CUSTOMI, dtypes.int.vec(2), (acc, r0.bitcast(dtypes.uint64), r1.bitcast(dtypes.uint64)),
arg="__builtin_HEXAGON_A2_vraddub_acc({0}, {1}, {2})")
#return UOp(Ops.CUSTOMI, dtypes.uint64, (acc.bitcast(dtypes.uint64), r0.bitcast(dtypes.uint64), r1.bitcast(dtypes.uint64)),
# arg="__builtin_HEXAGON_A2_vraddub_acc({0}, {1}, {2})").bitcast(dtypes.int.vec(2))
return UOp(Ops.CUSTOMI, dtypes.int.vec(2), (acc, r0.bitcast(dtypes.int64), r1.bitcast(dtypes.int64)),
arg="__builtin_HEXAGON_A2_vraddub_acc({0}, {1}, {2})")
else:
return UOp(Ops.CUSTOMI, dtypes.uint64,
(r0.bitcast(dtypes.uint64), r1.bitcast(dtypes.uint64)), arg="__builtin_HEXAGON_A2_vraddub({0}, {1})").bitcast(dtypes.int.vec(2))
#return UOp(Ops.CUSTOMI, dtypes.uint64,
# (r0.bitcast(dtypes.uint64), r1.bitcast(dtypes.uint64)), arg="__builtin_HEXAGON_A2_vraddub({0}, {1})").bitcast(dtypes.int.vec(2))
return UOp(Ops.CUSTOMI, dtypes.int.vec(2), (r0.bitcast(dtypes.int64), r1.bitcast(dtypes.int64)), arg="__builtin_HEXAGON_A2_vraddub({0}, {1})")
conv_pm = PatternMatcher([
# __builtin_HEXAGON_V6_vrmpybus x3
@ -159,10 +184,10 @@ dsp_pm = PatternMatcher([
UPat(Ops.CAST,name="c0")+UPat(Ops.CAST,name="d0"), multi_add_int32),
# build __builtin_HEXAGON_A2_vraddub
#(UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a0")+UPat(Ops.CAST,name="a1")+UPat(Ops.CAST,name="a2")+UPat(Ops.CAST,name="a3")+ \
# UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a4")+UPat(Ops.CAST,name="a5")+UPat(Ops.CAST,name="a6")+UPat(Ops.CAST,name="a7"), multi_add_int2),
#(UPat(name="acc")+UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a0")+UPat(Ops.CAST,name="a1")+UPat(Ops.CAST,name="a2")+UPat(Ops.CAST,name="a3")+ \
# UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a4")+UPat(Ops.CAST,name="a5")+UPat(Ops.CAST,name="a6")+UPat(Ops.CAST,name="a7"), multi_add_int2),
(UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a0")+UPat(Ops.CAST,name="a1")+UPat(Ops.CAST,name="a2")+UPat(Ops.CAST,name="a3")+ \
UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a4")+UPat(Ops.CAST,name="a5")+UPat(Ops.CAST,name="a6")+UPat(Ops.CAST,name="a7"), multi_add_int2),
(UPat(name="acc")+UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a0")+UPat(Ops.CAST,name="a1")+UPat(Ops.CAST,name="a2")+UPat(Ops.CAST,name="a3")+ \
UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a4")+UPat(Ops.CAST,name="a5")+UPat(Ops.CAST,name="a6")+UPat(Ops.CAST,name="a7"), multi_add_int2),
# we upcast 3 as 4
(UPat(Ops.REDUCE, name="r", src=(UPat(dtype=dtypes.int.vec(32), name="a0")*UPat(name="a1") +
@ -317,9 +342,9 @@ dsp_pm_late = PatternMatcher([
(UPat(Ops.CUSTOMI, dtype=dtypes.int.vec(32), name="c")+UPat.var("x"), add_to_mul),
# add acc to __builtin_HEXAGON_A2_vraddub (must be after the reduce expansion)
#(UPat(Ops.CUSTOMI, name="cu", arg="__builtin_HEXAGON_A2_vraddub({0}, {1})") + UPat.var("x"),
# lambda x,cu: cu.replace(dtype=dtypes.int64, src=(x.bitcast(dtypes.int64), cu.src[0], cu.src[1]),
# arg="__builtin_HEXAGON_A2_vraddub_acc({0}, {1}, {2})").bitcast(dtypes.int.vec(2))),
(UPat(Ops.CUSTOMI, name="cu", arg="__builtin_HEXAGON_A2_vraddub({0}, {1})") + UPat.var("x"),
lambda x,cu: cu.replace(dtype=dtypes.int64, src=(x.bitcast(dtypes.int64), cu.src[0], cu.src[1]),
arg="__builtin_HEXAGON_A2_vraddub_acc({0}, {1}, {2})").bitcast(dtypes.int.vec(2))),
#(UPat(Ops.BITCAST, src=(UPat(Ops.LOAD, name="ld"),), name="bc"),
# lambda ld, bc: ld.src[0].src[0].cast(bc.dtype.ptr(ld.src[0].dtype.size)).load(dtype=bc.dtype)),