This commit is contained in:
George Hotz 2025-03-21 13:04:18 +08:00
commit ff3438be4e
3 changed files with 70 additions and 10 deletions

View file

@ -87,7 +87,7 @@ if __name__ == "__main__":
pass
elif knum == 20:
# 784x192 * 192x32 -> 784x32
k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
k.apply_opt(Opt(OptOps.UNROLL, 0, 8))
k.apply_opt(Opt(OptOps.UPCAST, 1, 32))
k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
elif knum == 35:

View file

@ -297,11 +297,66 @@ class TestDSPCache(unittest.TestCase):
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
for opt in opts: k.apply_opt(opt)
prg = k.to_program()
src = prg.src
src = src.replace("128))) data3) {\n", "128))) data3) {\n __builtin_HEXAGON_Y4_l2fetch(data2, 6144);\n")
src = src.replace("int32 acc3 = cast0;\n", "int32 acc3 = cast0;\n __builtin_HEXAGON_Y4_l2fetch(data1+(ridx0*768), 768);\n #pragma nounroll\n")
print(src)
prg = replace(prg, src=src)
print(prg.src)
new_src = """
typedef int int32 __attribute__((aligned(128),vector_size(128)));
typedef signed char signed_char128 __attribute__((aligned(128),vector_size(128)));
typedef unsigned char unsigned_char8 __attribute__((aligned(8),vector_size(8)));
typedef unsigned char unsigned_char4 __attribute__((aligned(4),vector_size(4)));
typedef unsigned char unsigned_char128 __attribute__((aligned(128),vector_size(128)));
__attribute__((noinline)) void r_196_24_8_32_4(unsigned char* restrict __attribute__((align_value(128))) data0, unsigned char* restrict __attribute__((align_value(128))) data1, signed char* restrict __attribute__((align_value(
128))) data2, int* restrict __attribute__((align_value(128))) data3) {
int32 cast0 = (int32){0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
int32 val0 = *((int32*)((data3+0)));
for (int ridx0 = 0; ridx0 < 196; ridx0++) {
int32 acc0 = cast0;
int32 acc1 = cast0;
int32 acc2 = cast0;
int32 acc3 = cast0;
__builtin_HEXAGON_Y2_dcfetch(data1+ridx0*768);
__builtin_HEXAGON_Y2_dcfetch(data1+ridx0*768+192);
__builtin_HEXAGON_Y2_dcfetch(data1+ridx0*768+384);
__builtin_HEXAGON_Y2_dcfetch(data1+ridx0*768+576);
for (int ridx1 = 0; ridx1 < 24; ridx1++) {
signed_char128 val1 = *((signed_char128*)((data2+(ridx1<<8))));
signed_char128 val2 = *((signed_char128*)((data2+((1+(ridx1<<1))<<7))));
int alu0 = ((ridx0*768)+(ridx1<<3));
unsigned_char8 val3 = *((unsigned_char8*)((data1+alu0)));
__builtin_HEXAGON_Y2_dcfetch(((data1+alu0)+16));
unsigned_char8 val4 = *((unsigned_char8*)((data1+(alu0+192))));
__builtin_HEXAGON_Y2_dcfetch(((data1+(alu0+192))+16));
unsigned_char8 val5 = *((unsigned_char8*)((data1+(alu0+384))));
__builtin_HEXAGON_Y2_dcfetch(((data1+(alu0+384))+16));
unsigned_char8 val6 = *((unsigned_char8*)((data1+(alu0+576))));
__builtin_HEXAGON_Y2_dcfetch(((data1+(alu0+576))+16));
unsigned_char4 alu5 = __builtin_shufflevector(val3, val3, 0, 1, 2, 3);
unsigned_char4 alu6 = __builtin_shufflevector(val4, val4, 0, 1, 2, 3);
unsigned_char4 alu7 = __builtin_shufflevector(val5, val5, 0, 1, 2, 3);
unsigned_char4 alu8 = __builtin_shufflevector(val6, val6, 0, 1, 2, 3);
acc0 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc0, val1, (*((unsigned int*)&alu5)));
acc1 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc1, val1, (*((unsigned int*)&alu6)));
acc2 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc2, val1, (*((unsigned int*)&alu7)));
acc3 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc3, val1, (*((unsigned int*)&alu8)));
unsigned_char4 alu9 = __builtin_shufflevector(val3, val3, 4, 5, 6, 7);
unsigned_char4 alu10 = __builtin_shufflevector(val4, val4, 4, 5, 6, 7);
unsigned_char4 alu11 = __builtin_shufflevector(val5, val5, 4, 5, 6, 7);
unsigned_char4 alu12 = __builtin_shufflevector(val6, val6, 4, 5, 6, 7);
acc0 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc0, val2, (*((unsigned int*)&alu9)));
acc1 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc1, val2, (*((unsigned int*)&alu10)));
acc2 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc2, val2, (*((unsigned int*)&alu11)));
acc3 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc3, val2, (*((unsigned int*)&alu12)));
}
unsigned_char128 alu18 = __builtin_HEXAGON_V6_vpackhub_sat_128B(__builtin_HEXAGON_V6_vpackwh_sat_128B((((((acc3+val0)*203)+32767)/65536)+136), (((((acc2+val0)*203)+32767)/65536)+136)), __builtin_HEXAGON_V6_vpackwh_sat_128B((((((acc1+val0)*203)+32767)/65536)+136), (((((acc0+val0)*203)+32767)/65536)+136)));
*((unsigned_char128*)((data0+(ridx0<<7)))) = alu18;
}
}
"""
prg = replace(prg, src=new_src+prg.src.split("/* DSP boilerplate */ ")[1])
rt = CompiledRunner(prg)
Device.default.compiler.disassemble(rt.lib)
ei = ExecItem(rt, bufs_from_lin(k))

View file

@ -79,12 +79,17 @@ def add_to_mul(c:UOp, x:UOp):
else:
return None
def prefetch_l1(ld:UOp):
  """Attach Hexagon `dcfetch` prefetch hints to a LOAD as extra sources.

  Returns None (no rewrite) when the LOAD's last source is already an
  Ops.CUSTOM, so the rule does not fire again on a LOAD it has annotated.

  Otherwise returns a copy of `ld` with CUSTOM uops appended to its sources:
    * a look-ahead prefetch at `ld.src[0].src[0]` indexed by `ld.dtype.count*2`
    * when the index expression contains at least one Ops.RANGE, a prefetch of
      the same expression with the highest-`arg` RANGE replaced by that RANGE's
      `src[0]` (presumably the loop start, i.e. the first address the loop
      touches -- TODO confirm against the RANGE src layout)
  """
  # already annotated by a previous application of this rule
  if ld.src[-1].op is Ops.CUSTOM: return None
  base = ld.src[0].src[0]
  fetch = "__builtin_HEXAGON_Y2_dcfetch({0});"
  prefetches = [UOp(Ops.CUSTOM, dtypes.void, src=(base.index(UOp.const(dtypes.int, ld.dtype.count*2)),), arg=fetch)]
  ranges = sorted([x for x in base.toposort if x.op is Ops.RANGE], key=lambda x: x.arg)
  # a LOAD outside any loop has no RANGE to substitute; indexing ranges[-1] would raise IndexError
  if ranges:
    first = base.substitute({ranges[-1]: ranges[-1].src[0]})
    prefetches.append(UOp(Ops.CUSTOM, dtypes.void, src=(first,), arg=fetch))
  return ld.replace(src=ld.src+tuple(prefetches))
dsp_pm_late = PatternMatcher([
# prefetch L1
(UPat(Ops.LOAD, dtype=(dtypes.uchar.vec(4), dtypes.uchar.vec(8)), name="ld"),
lambda ld: ld.replace(src=ld.src+(UOp(Ops.CUSTOM, dtypes.void, src=(ld.src[0].src[0].index(UOp.const(dtypes.int, ld.dtype.count*2)),),
arg="__builtin_HEXAGON_Y2_dcfetch({0});"),)) if ld.src[-1].op is not Ops.CUSTOM else None),
(UPat(Ops.LOAD, dtype=(dtypes.uchar.vec(4), dtypes.uchar.vec(8)), name="ld"), prefetch_l1),
(UPat(Ops.CUSTOMI, dtype=dtypes.int.vec(32), name="c")+UPat.var("x"), add_to_mul),
#(UPat(Ops.BITCAST, src=(UPat(Ops.LOAD, name="ld"),), name="bc"),