mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
correctness
This commit is contained in:
parent
8340d9c1c2
commit
661431ee75
1 changed file with 38 additions and 13 deletions
|
|
@@ -106,14 +106,39 @@ def multi_add_int2(**aa):
|
|||
del aa['acc']
|
||||
else:
|
||||
acc = None
|
||||
r0 = UOp(Ops.VECTORIZE, dtypes.uchar.vec(8), tuple(x.src[0].gep(0) for x in aa.values()))
|
||||
r1 = UOp(Ops.VECTORIZE, dtypes.uchar.vec(8), tuple(x.src[0].gep(1) for x in aa.values()))
|
||||
eles = []
|
||||
for k in sorted(aa.keys()): eles.append(aa[k].src[0].gep(0))
|
||||
for k in sorted(aa.keys()): eles.append(aa[k].src[0].gep(1))
|
||||
r0 = UOp(Ops.VECTORIZE, dtypes.uchar.vec(8), tuple(eles[0:4]+eles[8:12]))
|
||||
r1 = UOp(Ops.VECTORIZE, dtypes.uchar.vec(8), tuple(eles[4:8]+eles[12:16]))
|
||||
|
||||
"""
|
||||
unsigned long long precast0 = __builtin_HEXAGON_A2_vraddub((*((unsigned long long*)&val1)), (*((unsigned long long*)&val2)));
|
||||
acc4 = (acc4+(*((int2*)&precast0)));
|
||||
|
||||
int2 cast2 = __builtin_convertvector((unsigned_char2){val1[0],val2[0]}, int2);
|
||||
int2 cast4 = __builtin_convertvector((unsigned_char2){val1[1],val2[1]}, int2);
|
||||
int2 cast6 = __builtin_convertvector((unsigned_char2){val1[2],val2[2]}, int2);
|
||||
int2 cast8 = __builtin_convertvector((unsigned_char2){val1[3],val2[3]}, int2);
|
||||
int2 cast10 = __builtin_convertvector((unsigned_char2){val1[4],val2[4]}, int2);
|
||||
int2 cast12 = __builtin_convertvector((unsigned_char2){val1[5],val2[5]}, int2);
|
||||
int2 cast14 = __builtin_convertvector((unsigned_char2){val1[6],val2[6]}, int2);
|
||||
int2 cast16 = __builtin_convertvector((unsigned_char2){val1[7],val2[7]}, int2);
|
||||
acc4 = (acc4+cast2+cast4+cast6+cast8+cast10+cast12+cast14+cast16);
|
||||
"""
|
||||
|
||||
#r0 = UOp(Ops.VECTORIZE, dtypes.uchar.vec(8), tuple(x.src[0].gep(0) for x in aa.values()))
|
||||
#r1 = UOp(Ops.VECTORIZE, dtypes.uchar.vec(8), tuple(x.src[0].gep(1) for x in aa.values()))
|
||||
if acc is not None:
|
||||
return UOp(Ops.CUSTOMI, dtypes.int.vec(2), (acc, r0.bitcast(dtypes.uint64), r1.bitcast(dtypes.uint64)),
|
||||
arg="__builtin_HEXAGON_A2_vraddub_acc({0}, {1}, {2})")
|
||||
#return UOp(Ops.CUSTOMI, dtypes.uint64, (acc.bitcast(dtypes.uint64), r0.bitcast(dtypes.uint64), r1.bitcast(dtypes.uint64)),
|
||||
# arg="__builtin_HEXAGON_A2_vraddub_acc({0}, {1}, {2})").bitcast(dtypes.int.vec(2))
|
||||
|
||||
return UOp(Ops.CUSTOMI, dtypes.int.vec(2), (acc, r0.bitcast(dtypes.int64), r1.bitcast(dtypes.int64)),
|
||||
arg="__builtin_HEXAGON_A2_vraddub_acc({0}, {1}, {2})")
|
||||
else:
|
||||
return UOp(Ops.CUSTOMI, dtypes.uint64,
|
||||
(r0.bitcast(dtypes.uint64), r1.bitcast(dtypes.uint64)), arg="__builtin_HEXAGON_A2_vraddub({0}, {1})").bitcast(dtypes.int.vec(2))
|
||||
#return UOp(Ops.CUSTOMI, dtypes.uint64,
|
||||
# (r0.bitcast(dtypes.uint64), r1.bitcast(dtypes.uint64)), arg="__builtin_HEXAGON_A2_vraddub({0}, {1})").bitcast(dtypes.int.vec(2))
|
||||
return UOp(Ops.CUSTOMI, dtypes.int.vec(2), (r0.bitcast(dtypes.int64), r1.bitcast(dtypes.int64)), arg="__builtin_HEXAGON_A2_vraddub({0}, {1})")
|
||||
|
||||
conv_pm = PatternMatcher([
|
||||
# __builtin_HEXAGON_V6_vrmpybus x3
|
||||
|
|
@@ -159,10 +184,10 @@ dsp_pm = PatternMatcher([
|
|||
UPat(Ops.CAST,name="c0")+UPat(Ops.CAST,name="d0"), multi_add_int32),
|
||||
|
||||
# build __builtin_HEXAGON_A2_vraddub
|
||||
#(UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a0")+UPat(Ops.CAST,name="a1")+UPat(Ops.CAST,name="a2")+UPat(Ops.CAST,name="a3")+ \
|
||||
# UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a4")+UPat(Ops.CAST,name="a5")+UPat(Ops.CAST,name="a6")+UPat(Ops.CAST,name="a7"), multi_add_int2),
|
||||
#(UPat(name="acc")+UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a0")+UPat(Ops.CAST,name="a1")+UPat(Ops.CAST,name="a2")+UPat(Ops.CAST,name="a3")+ \
|
||||
# UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a4")+UPat(Ops.CAST,name="a5")+UPat(Ops.CAST,name="a6")+UPat(Ops.CAST,name="a7"), multi_add_int2),
|
||||
(UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a0")+UPat(Ops.CAST,name="a1")+UPat(Ops.CAST,name="a2")+UPat(Ops.CAST,name="a3")+ \
|
||||
UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a4")+UPat(Ops.CAST,name="a5")+UPat(Ops.CAST,name="a6")+UPat(Ops.CAST,name="a7"), multi_add_int2),
|
||||
(UPat(name="acc")+UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a0")+UPat(Ops.CAST,name="a1")+UPat(Ops.CAST,name="a2")+UPat(Ops.CAST,name="a3")+ \
|
||||
UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a4")+UPat(Ops.CAST,name="a5")+UPat(Ops.CAST,name="a6")+UPat(Ops.CAST,name="a7"), multi_add_int2),
|
||||
|
||||
# we upcast 3 as 4
|
||||
(UPat(Ops.REDUCE, name="r", src=(UPat(dtype=dtypes.int.vec(32), name="a0")*UPat(name="a1") +
|
||||
|
|
@@ -317,9 +342,9 @@ dsp_pm_late = PatternMatcher([
|
|||
(UPat(Ops.CUSTOMI, dtype=dtypes.int.vec(32), name="c")+UPat.var("x"), add_to_mul),
|
||||
|
||||
# add acc to __builtin_HEXAGON_A2_vraddub (must be after the reduce expansion)
|
||||
#(UPat(Ops.CUSTOMI, name="cu", arg="__builtin_HEXAGON_A2_vraddub({0}, {1})") + UPat.var("x"),
|
||||
# lambda x,cu: cu.replace(dtype=dtypes.int64, src=(x.bitcast(dtypes.int64), cu.src[0], cu.src[1]),
|
||||
# arg="__builtin_HEXAGON_A2_vraddub_acc({0}, {1}, {2})").bitcast(dtypes.int.vec(2))),
|
||||
(UPat(Ops.CUSTOMI, name="cu", arg="__builtin_HEXAGON_A2_vraddub({0}, {1})") + UPat.var("x"),
|
||||
lambda x,cu: cu.replace(dtype=dtypes.int64, src=(x.bitcast(dtypes.int64), cu.src[0], cu.src[1]),
|
||||
arg="__builtin_HEXAGON_A2_vraddub_acc({0}, {1}, {2})").bitcast(dtypes.int.vec(2))),
|
||||
|
||||
#(UPat(Ops.BITCAST, src=(UPat(Ops.LOAD, name="ld"),), name="bc"),
|
||||
# lambda ld, bc: ld.src[0].src[0].cast(bc.dtype.ptr(ld.src[0].dtype.size)).load(dtype=bc.dtype)),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue