This commit is contained in:
George Hotz 2025-04-02 00:13:38 +08:00
commit e18cdbcbe2
2 changed files with 10 additions and 5 deletions

View file

@@ -442,17 +442,19 @@ class Kernel:
# special path for DSP
if k.full_shape[-3:] == (32,3,3):
# 3x3 dwconv
if k.full_shape[-4]%4 != 0: k.apply_opt(Opt(OptOps.PADTO, len(k.full_shape)-4, 4))
# kernel 49 is broken
if k.full_shape[-4]%4 != 0 and k.full_shape[-4] != 7: k.apply_opt(Opt(OptOps.PADTO, len(k.full_shape)-4, 4))
k.apply_opt(Opt(OptOps.UNROLL, 0, 0))
k.apply_opt(Opt(OptOps.UNROLL, 0, 0))
k.apply_opt(Opt(OptOps.UPCAST, len(k.full_shape)-3, 32))
if k.full_shape[len(k.full_shape)-4]%4 == 0:
if k.full_shape[len(k.full_shape)-4] <= 8: k.apply_opt(Opt(OptOps.UPCAST, len(k.full_shape)-4, 0))
else: k.apply_opt(Opt(OptOps.UPCAST, len(k.full_shape)-4, 4))
#if k.full_shape[len(k.full_shape)-4] <= 8: k.apply_opt(Opt(OptOps.UPCAST, len(k.full_shape)-4, 0))
#else: k.apply_opt(Opt(OptOps.UPCAST, len(k.full_shape)-4, 4))
k.apply_opt(Opt(OptOps.UPCAST, len(k.full_shape)-4, 4))
# if this is small, swap it
# NOTE: this is breaking something (should be fixed w/o padto)
# kernel 23 is broken with this
if k.full_shape[0] <= 6: k.apply_opt(Opt(OptOps.SWAP, 0, 1))
#if k.full_shape[0] <= 6: k.apply_opt(Opt(OptOps.SWAP, 0, 1))
elif k.full_shape[-4:] == (32,3,3,3):
# 3x3 normal conv
k.apply_opt(Opt(OptOps.UNROLL, 2, 0))
@@ -476,7 +478,7 @@ class Kernel:
k.apply_opt(Opt(OptOps.UPCAST, 2, 32))
if k.full_shape[1]%4 == 0: k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
# if the more is small, upcast it (kernel 50 is broken with this)
if k.full_shape[0] <= 6: k.apply_opt(Opt(OptOps.UPCAST, 0, 0))
#if k.full_shape[0] <= 6: k.apply_opt(Opt(OptOps.UPCAST, 0, 0))
elif len(k.full_shape) == 2 and k.first_reduce == 1:
# unroll to 4 if we can
if k.full_shape[k.first_reduce]%4 == 0: k.apply_opt(Opt(OptOps.UNROLL, 0, 4))

View file

@@ -257,6 +257,9 @@ def vectorize_shuffle(vec:UOp):
gepped = dedup([s.src[0] for s in vec.src if s.op is Ops.GEP])
if len(gepped) == 0: return None
if len(gepped) == 1:
# this pattern is broken in DSP clang
if gepped[0].dtype.count == 4: return None
#return None
arg = []
for s in vec.src:
if s.op is Ops.GEP: