hcq2: move global sync (#16504)

This commit is contained in:
nimlgen 2026-06-04 17:32:40 +03:00 committed by GitHub
commit 3838c8df1b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -285,14 +285,6 @@ def schedule_inner_sync(ctx:DepsCtx, call:UOp) -> UOp:
return op.after(*dps).rtag("deps") if (dps:=dedup(deps)) else op
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.CALL, tag="hcq_cmd", name="call", allow_any_len=True), schedule_inner_sync)])
def add_global_sync(ctx:DepsCtx, submit:UOp, q:UOp) -> UOp|None:
  """Prepend a global-timeline wait to the queue `q` carried by a submit op.

  Returns None (no rewrite) when the device key stored in `q.arg[0]` is already present in
  `ctx.last_per_queue`. Otherwise returns `submit` with `q` replaced by a copy whose first
  source is a wait on that device's signal.
  """
  dev = q.arg[0]
  if dev in ctx.last_per_queue: return None
  # a device in this command buffer may not have been used earlier in the schedule, so the queue
  # first waits for that device's global timeline epoch (current signal value minus one).
  sig = make_signal(dev)
  epoch = make_signal_value(dev).index(UOp.const(dtypes.int, 0))
  synced_q = q.replace(src=(sig.wait(epoch - 1), *q.src))
  return submit.replace(src=(synced_q,))
pm_add_global_sync = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_global_sync)])
# *****************
# 3.2. finalizer
@ -328,24 +320,26 @@ def add_loads(ctx:set[int], deps:UOp) -> UOp:
val = make_signal_value(cur_devs, queue=queue).index(UOp.const(dtypes.int, 0))
waits.append(sig.wait(val + opid))
return UOp(Ops.LINEAR, dtypes.void, (*waits, cur), arg=cur.tag[0])
pm_add_loads = PatternMatcher([(UPat(Ops.AFTER, tag="deps", name="deps"), add_loads)])
pm_add_inner_loads = PatternMatcher([(UPat(Ops.AFTER, tag="deps", name="deps"), add_loads)])
def add_stores(ctx:set[int], submit:UOp, lin:UOp) -> UOp:
def add_stores(ctx:set[int], submit:UOp, q:UOp) -> UOp:
new_src = []
for op in lin.src:
for op in q.src:
new_src.append(op.rtag(None) if op.tag else op)
if op.tag and op.tag[1] in ctx:
(devs, queue), opid = op.tag
new_src.append(make_signal(devs, queue=queue).store(make_signal_value(devs, queue=queue).index(UOp.const(dtypes.int, 0)) + opid))
return submit.replace(src=(lin.replace(src=tuple(new_src)),))
pm_add_stores = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="lin"),), name="submit"), add_stores)])
return submit.replace(src=(q.replace(src=tuple(new_src)),))
pm_add_inner_stores = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_stores)])
# *****************
# 3.4. barriers
def add_global_sync(ctx:set[tuple[str, ...]], submit:UOp, q:UOp) -> UOp|None:
if (devs:=q.arg[0]) in ctx: return None
ctx.add(devs)
def add_barriers(submit:UOp, lin:UOp) -> UOp:
  """Return `submit` with its LINEAR source `lin` rebuilt so a BARRIER op comes first."""
  barrier = UOp(Ops.BARRIER, dtypes.void)
  guarded = lin.replace(src=(barrier, *lin.src))
  return submit.replace(src=(guarded,))
pm_add_barriers = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="lin"),), name="submit"), add_barriers)])
# some devices from a command buffer might be used for the first time this schedule, so we wait for their global timeline epoch.
wait = make_signal(devs).wait(make_signal_value(devs).index(UOp.const(dtypes.int, 0)) - 1)
return submit.replace(src=(q.replace(src=(UOp(Ops.BARRIER, dtypes.void), wait, *q.src)),))
pm_add_global_sync = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_global_sync)])
# *****************
# 4.1. encode cmdbufs
@ -450,11 +444,10 @@ def hcq_schedule(linear:UOp) -> UOp:
linear = graph_rewrite(linear, pm_lower_ops, name="lower ops into hcq ir")
linear = graph_rewrite(linear, pm_schedule_inner_sync, ctx=(deps_ctx:=DepsCtx()), walk=True, name="schedule inner sync", enter_calls=True)
linear = graph_rewrite(linear, pm_add_global_sync, ctx=deps_ctx, walk=True, name="add global sync", enter_calls=True)
linear = graph_rewrite(linear, pm_add_finalizer, ctx=deps_ctx, walk=True, name="add finalizer")
linear = graph_rewrite(linear, pm_add_loads + pm_flatten_linear, ctx=(waited:=set()), walk=True, name="add loads", enter_calls=True)
linear = graph_rewrite(linear, pm_add_stores, ctx=waited, walk=True, name="add stores", enter_calls=True)
linear = graph_rewrite(linear, pm_add_barriers, walk=True, name="add barriers", enter_calls=True)
linear = graph_rewrite(linear, pm_add_inner_loads + pm_flatten_linear, ctx=(waited:=set()), walk=True, name="add loads", enter_calls=True)
linear = graph_rewrite(linear, pm_add_inner_stores, ctx=waited, walk=True, name="add stores", enter_calls=True)
linear = graph_rewrite(linear, pm_add_global_sync, ctx=set(), walk=True, name="add global sync", enter_calls=True)
linear = graph_rewrite(linear, pm_encode_cmdbufs, walk=True, name="encode cmdbufs", enter_calls=True)
linear = graph_rewrite(linear, pm_lift_patches_to_cmdbuf, name="lift patches to cmdbuf", enter_calls=True)