small heuristic cleanup [pr] (#12892)

This commit is contained in:
chenyu 2025-10-23 10:50:15 -04:00 committed by GitHub
commit 6e4ee8deea
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -27,15 +27,15 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
# NOTE: unless TC_OPT is > 0, we only trigger tensor cores if there's only one reduce axis
if USE_TC > 0 and (len(k.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)) == 1 or (TC_OPT.value >= 1)):
good_tc_opt = False
tk = k.copy()
try: # check TC first and apply hand-coded opts if successful
tk = k.copy()
rngs = tk.apply_opt(Opt(OptOps.TC, 0, (TC_SELECT.value, TC_OPT.value, USE_TC.value)))
good_tc_opt = True
except KernelOptError:
pass
if good_tc_opt:
# skip hand-coded TC opts if AMX, upcasting will make kernel slower
if rngs is not None and not AMX:
# skip hand-coded TC opts if AMX, upcasting will make kernel slower
if good_tc_opt and not AMX:
if rngs is not None:
for tc_dim in [1,0]: # attempt to upcast M and N
szs = [sz for sz in [5,4,3,2] if rngs[tc_dim].src[0].divides(sz) is not None]
if szs:
@@ -149,7 +149,6 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
# if nothing at all is upcasted and it's easy to, do an upcast
for splits in [4]:
# TODO: somehow this never hits a reduce
if not k.upcasted and k.upcastable_dims and k.full_shape[k.upcastable_dims[-1]] % splits == 0:
k.apply_opt(Opt(OptOps.UPCAST, k.upcastable_dims[-1], splits))