hcq2: move pre bufferize (#16589)

* hcq2: move pre bufferize

* x
This commit is contained in:
nimlgen 2026-06-12 16:11:59 +03:00 committed by GitHub
commit 2bfdf85f87
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -410,6 +410,14 @@ pm_lift_patches_to_cmdbuf = PatternMatcher([
(UPat(Ops.AFTER, src=(UPat(Ops.BUFFER, tag={"compute", "copy"}),), allow_any_len=True, name="cmdbuf"), lift_patches_to_cmdbuf),
])
# *****************
# 5.3. capture buffers reachable from each hcq call as BIND, so we don't drop their refs
def hold_call_buffers(call:UOp) -> UOp|None:
if not (bufs:=tuple(dedup(u for u in call.src[0].toposort() if u.op is Ops.BUFFER and u not in call.src))): return None
return call.replace(src=call.src + (UOp(Ops.BIND, dtypes.void, src=bufs),))
pm_hold_call_buffers = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), hold_call_buffers)])
# *****************
# 6. bufferize placeholders: replace placeholders with real buffers.
@ -420,15 +428,7 @@ def bufferize_buf(buf:UOp) -> UOp|None:
pm_bufferize = PatternMatcher([(UPat(Ops.BUFFER, name="buf"), bufferize_buf)])
# *****************
# 7.1. capture buffers reachable from each hcq call as BIND, so we don't drop their refs
def hold_call_buffers(call:UOp) -> UOp|None:
if not (bufs:=tuple(dedup(u for u in call.src[0].toposort() if u.op is Ops.BUFFER and u not in call.src))): return None
return call.replace(src=call.src + (UOp(Ops.BIND, dtypes.void, src=bufs),))
pm_hold_call_buffers = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), hold_call_buffers)])
# *****************
# 7.2. resolve patches
# 7. resolve patches
def push_stack(op, s): return UOp(Ops.STACK, op.dtype.scalar().vec(len(s.src)),
tuple(op.replace(dtype=op.dtype.scalar(), src=tuple(x if y is s else y for y in op.src)) for x in s.src))
@ -494,10 +494,10 @@ def hcq_schedule(linear:UOp) -> UOp:
linear = graph_rewrite(linear, pm_add_global_sync, ctx=set(), walk=True, name="add global sync", enter_calls=True)
linear = graph_rewrite(linear, pm_encode_cmdbufs, walk=True, name="encode cmdbufs", enter_calls=True)
linear = graph_rewrite(linear, pm_lift_patches_to_cmdbuf, name="lift patches to cmdbuf", enter_calls=True)
linear = graph_rewrite(linear, pm_hold_call_buffers, walk=True, name="hold call buffers")
# realize starts from here
linear = graph_rewrite(linear, pm_bufferize, bottom_up=True, name="bufferize placeholders", enter_calls=True)
linear = graph_rewrite(linear, pm_hold_call_buffers, walk=True, name="hold call buffers")
linear = graph_rewrite(linear, pm_resolve_patches, bottom_up=False, name="simplify patches", enter_calls=True)
linear = graph_rewrite(linear, pm_parametrize_host_buffers, name="parametrize host buffers")
linear = graph_rewrite(linear, pm_callify_hcq, name="callify hcq")