full_shape

This commit is contained in:
George Hotz 2025-03-27 16:39:15 +08:00
commit a8bd26d9bc

View file

@ -44,13 +44,13 @@ if __name__ == "__main__":
#elif k.full_shap[-4] == 7: k.apply_opt(Opt(OptOps.UPCAST, len(k.full_shape)-4, 7))
#elif k.full_shape[-4] == 14: k.apply_opt(Opt(OptOps.UPCAST, len(k.full_shape)-4, 2))
elif len(k.full_shape) == 3 and k.full_shape[1] == 32:
#if k.full_shape[0]%4 != 0: k.apply_opt(Opt(OptOps.PADTO, 0, 4))
if k.full_shape[0]%4 != 0: k.apply_opt(Opt(OptOps.PADTO, 0, 4))
# weight without more
k.apply_opt(Opt(OptOps.UNROLL, 0, 8))
k.apply_opt(Opt(OptOps.UPCAST, 1, 32))
if k.full_shape[0]%4 == 0: k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
elif len(k.full_shape) == 4 and k.full_shape[2] == 32:
#if k.full_shape[1]%4 != 0: k.apply_opt(Opt(OptOps.PADTO, 1, 4))
if k.full_shape[1]%4 != 0: k.apply_opt(Opt(OptOps.PADTO, 1, 4))
# weight with more
k.apply_opt(Opt(OptOps.UNROLL, 0, 8))
k.apply_opt(Opt(OptOps.UPCAST, 2, 32))