This commit is contained in:
George Hotz 2025-03-22 18:59:54 +08:00
commit fd73ec2b1b
2 changed files with 6 additions and 1 deletion

View file

@@ -159,6 +159,11 @@ if __name__ == "__main__":
elif knum == 37:
k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
k.apply_opt(Opt(OptOps.UPCAST, 1, 384))
elif knum == 66:
k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
k.apply_opt(Opt(OptOps.UPCAST, 0, 8))
#k.apply_opt(Opt(OptOps.PADTO, 0, 32))
#k.apply_opt(Opt(OptOps.UPCAST, 0, 32))
else:
full_shape = k.full_shape
out_shape = k.sts[0].shape

View file

@@ -386,7 +386,7 @@ class Kernel:
self.group_for_reduces += 1
elif opt.op is OptOps.UNROLL: # purple
check(axis < self.first_upcast, "can't upcasted already upcasted")
check(amt <= 32, "don't unroll more than 32")
#check(amt <= 32, "don't unroll more than 32")
# TODO: fix upcast_count to put purples before yellows. broken because of METAL tensor cores
#upcast_count = sum(x == y for x,y in zip(self.full_shape[-self.upcasted:], self.output_shape[-self.upcasted:])) if self.upcasted else 0
#self.shift_to(axis, amt, insert_before=None if upcast_count == 0 else self.shape_len-upcast_count)