hcq2: merge queues (#16514)

* hcq2: mergw queues

* cleaner
This commit is contained in:
nimlgen 2026-06-05 21:20:25 +03:00 committed by GitHub
commit 5ebd44aa12
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -332,6 +332,42 @@ def add_stores(ctx:set[int], submit:UOp, q:UOp) -> UOp:
return submit.replace(src=(q.replace(src=tuple(new_src)),))
pm_add_inner_stores = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_stores)])
# *****************
# 4.1. merge queues
def get_submit(ast:UOp) -> UOp: return next(u for u in ast.toposort() if u.op is Ops.CUSTOM_FUNCTION and u.arg == "submit")
def merge_sinks(old_sink:UOp, new_sink:UOp) -> UOp:
old_submit, new_submit = get_submit(old_sink), get_submit(new_sink)
old_queue, new_queue = old_submit.src[0], new_submit.src[0]
merged_submit = new_submit.replace(src=(new_queue.replace(src=old_queue.src + new_queue.src),))
old_root = old_sink.src[0].substitute({old_submit: merged_submit})
new_anchor = merged_submit if old_sink.src[0] is old_submit else old_root
return new_sink.substitute({new_submit: new_anchor})
def merge_queues(linear:UOp) -> UOp:
new_src:list[UOp] = []
opened_qs:dict[tuple[tuple[str, ...], str], UOp] = {} # (devs, queue) -> sink, kept in submit order
for call in linear.src:
if call.tag != "hcq":
new_src += [opened_qs.pop(k).call().rtag('hcq') for k in list(opened_qs)] + [call]
continue
devs, queue = get_submit(new_sink:=call.src[0]).src[0].arg
if (old_sink:=opened_qs.pop((devs, queue), None)) is not None:
new_sink = merge_sinks(old_sink, new_sink) # exact same queue: merge, and re-insert at the end
else:
# no such queue opened: close every open submit on this queue that shares a device, so submit order is kept
new_src += [opened_qs.pop(k).call().rtag('hcq') for k in [k for k in opened_qs if k[1] == queue and set(k[0]) & set(devs)]]
opened_qs[(devs, queue)] = new_sink
return linear.replace(src=tuple(new_src + [sink.call().rtag('hcq') for sink in opened_qs.values()]))
pm_merge_queues = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), merge_queues)])
# *****************
# 4.2. global sync
def add_global_sync(ctx:set[tuple[str, ...]], submit:UOp, q:UOp) -> UOp|None:
if (devs:=q.arg[0]) in ctx: return None
ctx.add(devs)
@ -342,7 +378,7 @@ def add_global_sync(ctx:set[tuple[str, ...]], submit:UOp, q:UOp) -> UOp|None:
pm_add_global_sync = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_global_sync)])
# *****************
# 4.1. encode cmdbufs
# 5.1. encode cmdbufs
@functools.cache
def get_pm_lower(name:str) -> PatternMatcher|None:
@ -447,6 +483,7 @@ def hcq_schedule(linear:UOp) -> UOp:
linear = graph_rewrite(linear, pm_add_finalizer, ctx=deps_ctx, walk=True, name="add finalizer")
linear = graph_rewrite(linear, pm_add_inner_loads + pm_flatten_linear, ctx=(waited:=set()), walk=True, name="add loads", enter_calls=True)
linear = graph_rewrite(linear, pm_add_inner_stores, ctx=waited, walk=True, name="add stores", enter_calls=True)
linear = graph_rewrite(linear, pm_merge_queues, name="merge queues")
linear = graph_rewrite(linear, pm_add_global_sync, ctx=set(), walk=True, name="add global sync", enter_calls=True)
linear = graph_rewrite(linear, pm_encode_cmdbufs, walk=True, name="encode cmdbufs", enter_calls=True)
linear = graph_rewrite(linear, pm_lift_patches_to_cmdbuf, name="lift patches to cmdbuf", enter_calls=True)