mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
hcq2: rebind var params (#16610)
This commit is contained in:
parent
8efc8d064f
commit
5a9227b30a
1 changed files with 5 additions and 0 deletions
|
|
@ -472,6 +472,11 @@ pm_to_param = PatternMatcher([(UPat({Ops.MSELECT, Ops.MSTACK, Ops.BUFFER}, name=
|
|||
|
||||
def parametrize_host_buffers(call:UOp) -> UOp:
|
||||
body = graph_rewrite(call.src[0], pm_to_param, ctx=(bufs:=[]), bottom_up=True, name="parametrize host buffers")
|
||||
|
||||
# move vars to new slots
|
||||
var_slots = {nm:len(bufs)+i for i,nm in enumerate(sorted({v.expr for v in body.variables() if v.op is Ops.PARAM}))}
|
||||
body = body.substitute({v:v.replace(arg=replace(v.arg, slot=var_slots[v.expr])) for v in body.variables() if v.op is Ops.PARAM})
|
||||
|
||||
return call.replace(src=(body, *bufs) + call.src[1:], tag="hcq_param")
|
||||
pm_parametrize_host_buffers = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), parametrize_host_buffers)])
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue