late numbering of var params (#16640)

* do_number_param

* fix sort order in x86

* we don't want this
This commit is contained in:
George Hotz 2026-06-17 00:36:08 -07:00 committed by GitHub
commit be9b570cb2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 22 additions and 8 deletions

View file

@ -46,10 +46,19 @@ pm_remove_vec_dtypes = PatternMatcher([
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="x"), lambda x:
x.replace(op=Ops.BUFFER, arg=ParamArg(x.arg, addrspace=AddrSpace.LOCAL if x.op == Ops.DEFINE_LOCAL else AddrSpace.REG))),
# replace DEFINE_VAR with PARAM
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x:
x.replace(op=Ops.PARAM, src=(UOp(Ops.STACK),), arg=ParamArg(slot=ctx[x.arg[0]], name=x.arg[0], vmin_vmax=x.arg[1:], addrspace=None))),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x:
x.replace(op=Ops.PARAM, src=(UOp(Ops.STACK),), arg=ParamArg(slot=-1, name=x.arg[0], vmin_vmax=x.arg[1:], addrspace=None))),
])+pm_clean_up_group_sink
# Give a PARAM that is still marked slot=-1 the next free slot.
# ctx is a one-element list used as a mutable counter: ctx[0] is the next slot to hand out.
# Returns None (no rewrite) when the PARAM already has a slot other than -1.
def do_number_param(ctx:list[int], x:UOp):
  if x.arg.slot != -1: return None
  # take the current counter value for this PARAM and bump the counter before building the replacement
  slot = ctx[0]
  ctx[0] = slot + 1
  return x.replace(arg=replace(x.arg, slot=slot))
# Rewrite pass that hands every Ops.PARAM node to do_number_param; do_number_param itself only
# renumbers PARAMs whose slot is still -1 and leaves the rest untouched.
pm_number_params = PatternMatcher([
(UPat(Ops.PARAM, name="x"), do_number_param),
])
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
if DEBUG >= 5: print(pyrender(ast))
@ -124,9 +133,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# this is new style
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
num_params = len([x for x in sink.toposort() if x.op is Ops.PARAM])
name_to_slot = {nm:num_params+i for i,nm in enumerate(sorted([x.arg[0] for x in sink.toposort() if x.op is Ops.DEFINE_VAR]))}
sink = graph_rewrite(sink, pm_remove_vec_dtypes, ctx=name_to_slot, name="transform to new style")
sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style")
# move gates from unrenderable INVALID where
sink = graph_rewrite(sink, pm_move_gates_from_index, name="move gates from index")
@ -139,6 +146,10 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# this was the linearizer
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
# assign slots to PARAMs still at slot=-1 (the former DEFINE_VARs), numbering after the already-slotted PARAMs
num_params = len([x for x in sink.toposort() if x.op is Ops.PARAM and x.arg.slot != -1])
sink = graph_rewrite(sink, pm_number_params, ctx=[num_params], name="number params with -1", walk=True)
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Output AST")
if SPEC: type_verify(sink, spec_program)

View file

@ -17,8 +17,10 @@ class IselContext:
# Build the per-kernel selection state from the graph rooted at `sink`.
def __init__(self, sink:UOp):
# presumably maps each UOp to its consumers (built from the toposorted graph) — confirm against the helper
self.uses = consumer_map_from_toposort(sink.toposort())
# counter consumed by vreg() to generate v0, v1, ... register names
self.reg_n = itertools.count()
# NOTE(review): func_args is assigned twice below. The arg_order version and the arg_key version look like the
# old and new sides of a diff hunk shown together — confirm which one the real file keeps.
arg_order = {Ops.PARAM: 0, Ops.SPECIAL: 1}
self.func_args = sorted([u for u in self.uses if u.op in arg_order], key=lambda k: (arg_order[k.op], k.arg))
# sort key: PARAMs with an addrspace come first (ordered by slot), then PARAMs whose addrspace is None
# (ordered by u.expr), then SPECIALs (ordered by arg)
def arg_key(u:UOp):
if u.op is Ops.SPECIAL: return (2, u.arg)
return (0, u.arg.slot) if u.arg.addrspace is not None else (1, u.expr)
self.func_args = sorted([u for u in self.uses if u.op in {Ops.PARAM, Ops.SPECIAL}], key=arg_key)
# Create a fresh virtual register named v<N>, where N is the next value of this context's counter.
def vreg(self, cons:tuple[Register, ...]|Register):
  reg_name = f"v{next(self.reg_n)}"
  # a single Register constraint is wrapped into a 1-tuple; a tuple is passed through unchanged
  cons_tuple = cons if isinstance(cons, tuple) else (cons,)
  return Register(reg_name, 0, _cons=cons_tuple)
@ -39,4 +41,4 @@ class ISARenderer(Renderer):
# arch-specific hook: this base version has no implementation and always raises
def copy(self, x:UOp, reg:Register) -> UOp:
  raise NotImplementedError("arch specific")
# arch-specific hook: this base version has no implementation and always raises
def spill(self, disp:UOp, x:UOp) -> UOp:
  raise NotImplementedError("arch specific")
# arch-specific hook: this base version has no implementation and always raises
def fill(self, disp:UOp, x:UOp, reg:Register) -> UOp:
  raise NotImplementedError("arch specific")
# arch-specific hook: takes the uop list and a function name and is annotated to return a str;
# this base version always raises NotImplementedError.
# NOTE(review): the definition appears twice with identical text — presumably the old and new sides of a
# diff hunk (e.g. a trailing-newline change); confirm the real file contains it only once.
def asm_str(self, uops:list[UOp], function_name:str) -> str: raise NotImplementedError("arch specific")
def asm_str(self, uops:list[UOp], function_name:str) -> str: raise NotImplementedError("arch specific")

View file

@ -76,6 +76,7 @@ def _mop_index(r:UOp, idx:UOp):
return ret if ret.shape == idx.shape else None
pm_mops = PatternMatcher([
# handle movement ops on INDEX
(UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), _mop_index),
# move movement ops and INDEX after AFTER (but not when AFTER has a raw STORE with shaped children — from replace_contig_with_store_after)
(UPat(GroupOp.Movement|{Ops.INDEX}, name="r").after(name="a", allow_any_len=True),