mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
late numbering of var params (#16640)
* do_number_param * fix sort order in x86 * we don't want this
This commit is contained in:
parent
c7055d658f
commit
be9b570cb2
3 changed files with 22 additions and 8 deletions
|
|
@ -46,10 +46,19 @@ pm_remove_vec_dtypes = PatternMatcher([
|
|||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="x"), lambda x:
|
||||
x.replace(op=Ops.BUFFER, arg=ParamArg(x.arg, addrspace=AddrSpace.LOCAL if x.op == Ops.DEFINE_LOCAL else AddrSpace.REG))),
|
||||
# replace DEFINE_VAR with PARAM
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x:
|
||||
x.replace(op=Ops.PARAM, src=(UOp(Ops.STACK),), arg=ParamArg(slot=ctx[x.arg[0]], name=x.arg[0], vmin_vmax=x.arg[1:], addrspace=None))),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x:
|
||||
x.replace(op=Ops.PARAM, src=(UOp(Ops.STACK),), arg=ParamArg(slot=-1, name=x.arg[0], vmin_vmax=x.arg[1:], addrspace=None))),
|
||||
])+pm_clean_up_group_sink
|
||||
|
||||
def do_number_param(ctx:list[int], x:UOp):
  """Give a PARAM that still has the placeholder slot -1 the next free slot number.

  ctx is a one-element list used as a mutable counter: ctx[0] is the slot handed out next and is
  bumped every time a PARAM gets numbered. Returns None (no rewrite) when x already has a slot.
  """
  if x.arg.slot != -1: return None
  slot = ctx[0]
  # advance the shared counter before building the replacement node
  ctx[0] = slot + 1
  return x.replace(arg=replace(x.arg, slot=slot))
|
||||
|
||||
# every PARAM is offered to do_number_param, which only rewrites the ones whose slot is still -1
pm_number_params = PatternMatcher([(UPat(Ops.PARAM, name="x"), do_number_param)])
|
||||
|
||||
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
|
||||
if DEBUG >= 5: print(pyrender(ast))
|
||||
|
|
@ -124,9 +133,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
|
||||
# this is new style
|
||||
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
|
||||
num_params = len([x for x in sink.toposort() if x.op is Ops.PARAM])
|
||||
name_to_slot = {nm:num_params+i for i,nm in enumerate(sorted([x.arg[0] for x in sink.toposort() if x.op is Ops.DEFINE_VAR]))}
|
||||
sink = graph_rewrite(sink, pm_remove_vec_dtypes, ctx=name_to_slot, name="transform to new style")
|
||||
sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style")
|
||||
|
||||
# move gates from unrenderable INVALID where
|
||||
sink = graph_rewrite(sink, pm_move_gates_from_index, name="move gates from index")
|
||||
|
|
@ -139,6 +146,10 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
# this was the linearizer
|
||||
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
|
||||
|
||||
# put unnumbered DEFINE_VAR in slots
|
||||
num_params = len([x for x in sink.toposort() if x.op is Ops.PARAM and x.arg.slot != -1])
|
||||
sink = graph_rewrite(sink, pm_number_params, ctx=[num_params], name="number params with -1", walk=True)
|
||||
|
||||
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Output AST")
|
||||
if SPEC: type_verify(sink, spec_program)
|
||||
|
||||
|
|
|
|||
|
|
@ -17,8 +17,10 @@ class IselContext:
|
|||
def __init__(self, sink:UOp):
  """Build the per-kernel state: consumer map, register-name counter and ordered function args.

  sink: root UOp of the graph; its toposort is used to build the consumer map.
  """
  self.uses = consumer_map_from_toposort(sink.toposort())
  # monotonically increasing counter; vreg() pulls the next value to name a new Register
  self.reg_n = itertools.count()
  def arg_key(u:UOp):
    # SPECIAL ops go last, ordered by their arg
    if u.op is Ops.SPECIAL: return (2, u.arg)
    # PARAMs with an address space go first, ordered by slot;
    # PARAMs with addrspace=None go after them, ordered by u.expr
    return (0, u.arg.slot) if u.arg.addrspace is not None else (1, u.expr)
  # func_args is computed once: an earlier sort keyed on the raw k.arg was a dead store
  # (its result was immediately overwritten by this assignment) and has been dropped
  self.func_args = sorted([u for u in self.uses if u.op in {Ops.PARAM, Ops.SPECIAL}], key=arg_key)
|
||||
|
||||
def vreg(self, cons:tuple[Register, ...]|Register):
  """Return a new Register named v<N>, where N is the next value drawn from self.reg_n.

  cons may be a single Register or a tuple of them; a single one is wrapped in a 1-tuple
  before being passed as `_cons`.
  """
  name = f"v{next(self.reg_n)}"
  cons_tuple = cons if isinstance(cons, tuple) else (cons,)
  return Register(name, 0, _cons=cons_tuple)
|
||||
|
|
@ -39,4 +41,4 @@ class ISARenderer(Renderer):
|
|||
def copy(self, x:UOp, reg:Register) -> UOp:
  """Arch-specific hook; this base version always raises NotImplementedError."""
  raise NotImplementedError("arch specific")
|
||||
def spill(self, disp:UOp, x:UOp) -> UOp:
  """Arch-specific hook; this base version always raises NotImplementedError."""
  raise NotImplementedError("arch specific")
|
||||
def fill(self, disp:UOp, x:UOp, reg:Register) -> UOp:
  """Arch-specific hook; this base version always raises NotImplementedError."""
  raise NotImplementedError("arch specific")
|
||||
def asm_str(self, uops:list[UOp], function_name:str) -> str:
  """Arch-specific hook; this base version always raises NotImplementedError.

  Defined once: the byte-identical second definition that followed only rebound the same name.
  """
  raise NotImplementedError("arch specific")
|
||||
|
|
|
|||
|
|
@ -76,6 +76,7 @@ def _mop_index(r:UOp, idx:UOp):
|
|||
return ret if ret.shape == idx.shape else None
|
||||
|
||||
pm_mops = PatternMatcher([
|
||||
# handle movement ops on INDEX
|
||||
(UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), _mop_index),
|
||||
# move movement ops and INDEX after AFTER (but not when AFTER has a raw STORE with shaped children — from replace_contig_with_store_after)
|
||||
(UPat(GroupOp.Movement|{Ops.INDEX}, name="r").after(name="a", allow_any_len=True),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue