hcq2: optims (#16624)

* hcq2: optims

* x
This commit is contained in:
nimlgen 2026-06-16 03:58:28 +07:00 committed by GitHub
commit 4a0488ae97
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -27,7 +27,7 @@ class HCQ2Compiled(Compiled):
(UPat(Ops.BUFFER, tag="timeline_value"), lambda ctx: ctx.timeline_value()),
(UPat(Ops.BUFFER, tag="sentinel_signal"), lambda ctx: ctx.timeline_signal("sentinel", (1 << 64) - 1)),
(UPat(Ops.BUFFER, name="b"), lambda ctx, b:
Buffer(ctx.device, b.arg, b.dtype, options=BufferSpec(host=True, uncached=True, cpu_access=True, nolru=True))), # TODO: remove nolru
Buffer(ctx.device, b.arg, b.dtype, options=BufferSpec(host=False, uncached=True, cpu_access=True, nolru=True))), # TODO: remove nolru
])
super().__init__(device, allocator, compilers, lambda *a, **kw: None, None, arch=arch)
@ -294,7 +294,12 @@ def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
refs = get_call_arg_uops(call)
deps = dedup(flatten(ctx.deps.access_resources([get_dep_buf(ctx, b, l) for b in refs], call.arg.aux.outs, new_q) for l in range(len(q.arg[0]))))
new_q = new_q.after(*deps).rtag("deps") if deps else new_q
# optims: keep only the max wait per queue, and drop self-queue waits when the queue self-orders
deps = {dep.arg:dep for dep in sorted(deps, key=lambda x: x.tag)}
if to_tuple(new_q.arg[0])[0].split(":")[0] in {"AMD", "QCOM"} or new_q.arg[1].startswith("COPY"):
deps.pop(new_q.arg, None)
new_q = new_q.after(*deps.values()).rtag("deps") if deps else new_q
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}), *call.src[1:])))
return linear.replace(src=tuple(new_src))
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)])
@ -353,34 +358,35 @@ pm_add_inner_stores = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", s
def get_submit(ast:UOp) -> UOp: return next(u for u in ast.toposort() if u.op is Ops.CUSTOM_FUNCTION and u.arg == "submit")
def merge_sinks(old_sink:UOp, new_sink:UOp) -> UOp:
old_submit, new_submit = get_submit(old_sink), get_submit(new_sink)
old_queue, new_queue = old_submit.src[0], new_submit.src[0]
merged_submit = new_submit.replace(src=(new_queue.replace(src=old_queue.src + new_queue.src),))
old_root = old_sink.src[0].substitute({old_submit: merged_submit})
new_anchor = merged_submit if old_sink.src[0] is old_submit else old_root
return new_sink.substitute({new_submit: new_anchor})
def merge_sink(sinks:list[UOp]) -> UOp:
if len(sinks) == 1: return sinks[0]
submits = [get_submit(sink) for sink in sinks]
queues = [submit.src[0] for submit in submits]
anchor = submits[-1].replace(src=(queues[-1].replace(src=tuple(x for q in queues for x in q.src)),))
for sink, submit in zip(sinks[:-1], submits[:-1]):
if sink.src[0] is not submit: anchor = sink.src[0].substitute({submit: anchor}, walk=True)
return sinks[-1].substitute({submits[-1]: anchor}, walk=True)
def merge_queues(linear:UOp) -> UOp:
new_src:list[UOp] = []
opened_qs:dict[tuple[tuple[str, ...], str], tuple[UOp, HCQInfo]] = {} # (devs, queue) -> (sink, aux), kept in submit order
opened_qs:dict[tuple[tuple[str, ...], str], tuple[list[UOp], HCQInfo]] = {} # (devs, queue) -> (sinks, aux), kept in submit order
for call in linear.src:
if call.tag != "hcq":
new_src += [(sa:=opened_qs.pop(k))[0].call(aux=sa[1]).rtag('hcq') for k in list(opened_qs)] + [call]
new_src += [merge_sink((sa:=opened_qs.pop(k))[0]).call(aux=sa[1]).rtag("hcq") for k in list(opened_qs)] + [call]
continue
devs, queue = get_submit(new_sink:=call.src[0]).src[0].arg
aux = call.arg.aux
new_rec = ([new_sink], call.arg.aux)
if (old:=opened_qs.pop((devs, queue), None)) is not None:
new_sink = merge_sinks(old[0], new_sink) # exact same queue: merge, and re-insert at the end
aux = replace(aux, name=f"{queue.lower()} submit", estimates=old[1].estimates + aux.estimates)
new_rec = (old[0] + [new_sink], replace(new_rec[1], name=f"{queue.lower()} submit", estimates=old[1].estimates + new_rec[1].estimates))
else:
# no such queue opened: close every open submit on this queue that shares a device, so submit order is kept
new_src += [(sa:=opened_qs.pop(k))[0].call(aux=sa[1]).rtag('hcq') for k in [k for k in opened_qs if k[1] == queue and set(k[0]) & set(devs)]]
opened_qs[(devs, queue)] = (new_sink, aux)
closing = [k for k in opened_qs if k[1] == queue and set(k[0]) & set(devs)]
new_src += [merge_sink((sa:=opened_qs.pop(k))[0]).call(aux=sa[1]).rtag("hcq") for k in closing]
opened_qs[(devs, queue)] = new_rec
return linear.replace(src=tuple(new_src + [sink.call(aux=aux).rtag('hcq') for sink, aux in opened_qs.values()]))
return linear.replace(src=tuple(new_src + [merge_sink(sinks).call(aux=aux).rtag("hcq") for sinks, aux in opened_qs.values()]))
pm_merge_queues = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), merge_queues)])
# *****************