mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
hotfix bug in get_kernel_actions after TC_SEARCH_OVER_SHAPE was introduced (#8904)
* hotfix search bug * copy actions
This commit is contained in:
parent
15f94ac964
commit
0f6109ec00
1 changed files with 3 additions and 2 deletions
|
|
@ -102,7 +102,8 @@ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]:
|
|||
|
||||
# get dictionary of all possible actions
|
||||
def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
|
||||
acted_lins, max_up, max_lcl, kernel_actions = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024), actions
|
||||
acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
|
||||
kernel_actions = actions.copy()
|
||||
|
||||
if TC_SEARCH_OVER_SHAPE and len(lin.applied_opts) == 0: # tensor core opts must be first
|
||||
for i, action in enumerate(kernel_actions):
|
||||
|
|
@ -112,7 +113,7 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
|
|||
|
||||
for i,a in enumerate(kernel_actions):
|
||||
if a.axis is not None and a.op is not OptOps.TC:
|
||||
if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in actions): continue
|
||||
if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue
|
||||
lin2 = lin.copy()
|
||||
try:
|
||||
lin2.apply_opt(a)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue