hcq2: fix repeated calls (#16552)

This commit is contained in:
nimlgen 2026-06-09 19:11:42 +03:00 committed by GitHub
commit 2ab2d51099
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 37 additions and 33 deletions

View file

@ -226,19 +226,17 @@ pm_prep_runtime = PatternMatcher([
def make_submit(*cmds, devs:str|tuple[str, ...], queue:str) -> UOp:
devs:tuple[str, ...] = to_tuple(devs)
cmds = tuple([cmd.replace(arg=(devs, queue)).rtag("hcq_cmd") if cmd.op is Ops.CALL else cmd for cmd in cmds])
queue = UOp(Ops.LINEAR, dtypes.void, src=cmds, arg=(devs, queue))
return UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(queue,), arg="submit")
return UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(UOp(Ops.LINEAR, dtypes.void, src=tuple(cmds), arg=(devs, queue)),), arg="submit")
def lower_program(call:UOp, prg:UOp) -> UOp:
return make_submit(prg.call(*call.src[1:]), devs=call.src[1].device, queue="COMPUTE:0").sink().call().rtag("hcq")
return make_submit(prg, devs=call.src[1].device, queue="COMPUTE:0").sink().call(*call.src[1:]).rtag(("hcq", tuple(prg.arg[1].outs)))
def lower_copy(call:UOp, copy:UOp) -> UOp|None:
dst, src = call.src[1], call.src[2]
if (hcq_dev:=next((b.device for b in (dst, src) if b.device.split(":")[0] in HCQ_DEVS), None)) is None: return None
cp_op = UOp(Ops.COPY, dtypes.void, src=(dst, src), arg=src.buffer.nbytes)
return make_submit(cp_op.call(*call.src[1:]), devs=hcq_dev, queue="COPY:0").sink().call().rtag("hcq")
return make_submit(cp_op, devs=hcq_dev, queue="COPY:0").sink().call(*call.src[1:]).rtag(("hcq", (0,)))
pm_lower_ops = PatternMatcher([
(UPat(Ops.CALL, src=(UPat(Ops.PROGRAM, src=(UPat(Ops.BUFFER).or_after(), UPat(Ops.BUFFER).or_after()), name="prg"),),
@ -266,35 +264,42 @@ class DepsCtx:
opid:itertools.count = field(default_factory=lambda: itertools.count(0))
last_per_queue:weakref.WeakValueDictionary[tuple[Any, str], UOp] = field(default_factory=weakref.WeakValueDictionary)
def schedule_inner_sync(ctx:DepsCtx, call:UOp) -> UOp:
refs = [b.buffer for b in get_call_arg_uops(call)]
write_bufs = ast.arg[1].outs if (ast:=call.src[0]).op is Ops.PROGRAM else (0,)
def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
new_src = []
for call in linear.src:
if not isinstance(call.tag, tuple) or not call.tag[0] == "hcq":
new_src.append(call)
continue
# tag carries (queue arg, opid)
ctx.last_per_queue[call.arg[0]] = (op:=call.src[0].rtag((call.arg, next(ctx.opid))))
q = get_submit(call.src[0]).src[0]
new_q = ctx.last_per_queue[q.arg] = q.rtag(next(ctx.opid))
deps = []
for lane in range(len(refs[0].bufs) if isinstance(refs[0], MultiBuffer) else 1):
deps += ctx.deps.access_resources([b.bufs[lane] if isinstance(b, MultiBuffer) else b for b in refs], write_bufs, op)
return op.after(*dps).rtag("deps") if (dps:=dedup(deps)) else op
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.CALL, tag="hcq_cmd", name="call", allow_any_len=True), schedule_inner_sync)])
deps = []
refs = [b.buffer for b in get_call_arg_uops(call)]
for lane in range(len(refs[0].bufs) if isinstance(refs[0], MultiBuffer) else 1):
deps += ctx.deps.access_resources([b.bufs[lane] if isinstance(b, MultiBuffer) else b for b in refs], call.tag[1], new_q)
new_q = new_q.after(*dps).rtag("deps") if (dps:=dedup(deps)) else new_q
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}), *call.src[1:]), tag="hcq"))
return linear.replace(src=tuple(new_src))
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)])
# *****************
# 3.2. finalizer
def make_finalizer(queues:list[UOp], nbump:int) -> UOp:
devs = tuple(dedup([d for q in queues for d in to_tuple(q.tag[0][0])]))
devs = tuple(dedup([d for q in queues for d in to_tuple(q.arg[0])]))
zero = UOp.const(dtypes.int, 0)
tl = make_signal_value(devs)
submit = make_submit(make_signal(devs).store(tl.index(zero) + 1), devs=devs, queue="COMPUTE:0")
upd = [(tl, 1)] + [(make_signal_value(devs, queue=qn), nbump) for qn in dedup([q.tag[0][1] for q in queues])]
upd = [(tl, 1)] + [(make_signal_value(devs, queue=qn), nbump) for qn in dedup([q.arg[1] for q in queues])]
return UOp.barrier(*[s.after(submit).index(zero, dtype=s.dtype.ptr()).store(s.index(zero) + inc) for s, inc in upd]).sink().call().rtag("hcq")
def add_finalizer(ctx:DepsCtx, linear:UOp) -> UOp:
parts:dict[str, list[UOp]] = collections.defaultdict(list)
for d, q in ctx.last_per_queue.items(): parts[d[0].split(':')[0]].append(q)
for d, q in ctx.last_per_queue.items(): parts[to_tuple(d[0])[0].split(':')[0]].append(q)
nbump = next(ctx.opid)
return linear.replace(src=linear.src + tuple([make_finalizer(queues, nbump) for queues in parts.values()]))
@ -304,26 +309,25 @@ pm_add_finalizer = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), add_finaliz
# 3.3. lower loads/stores
def add_loads(ctx:set[int], deps:UOp) -> UOp:
cur_devs = to_tuple((cur:=deps.src[0]).tag[0][0])
cur_devs = to_tuple((cur:=deps.src[0]).arg[0])
waits = []
for (devs, queue), opid in [dq.tag for dq in deps.src[1:]]:
ctx.add(opid) # mark op to update signal.
for dep in deps.src[1:]:
devs, queue = dep.arg
ctx.add(dep.tag) # mark op to update signal.
sig = make_mstack([make_signal(d, queue=queue, sentinel=d not in devs) for d in cur_devs])
val = make_signal_value(cur_devs, queue=queue).index(UOp.const(dtypes.int, 0))
waits.append(sig.wait(val + opid))
return UOp(Ops.LINEAR, dtypes.void, (*waits, cur), arg=cur.tag[0])
waits.append(sig.wait(val + dep.tag))
return cur.replace(src=tuple(waits) + cur.src)
pm_add_inner_loads = PatternMatcher([(UPat(Ops.AFTER, tag="deps", name="deps"), add_loads)])
def add_stores(ctx:set[int], submit:UOp, q:UOp) -> UOp:
new_src = []
for op in q.src:
new_src.append(op.rtag(None) if op.tag else op)
if op.tag and op.tag[1] in ctx:
(devs, queue), opid = op.tag
new_src.append(make_signal(devs, queue=queue).store(make_signal_value(devs, queue=queue).index(UOp.const(dtypes.int, 0)) + opid))
return submit.replace(src=(q.replace(src=tuple(new_src)),))
src = q.src
if q.tag in ctx:
devs, queue = q.arg
src += (make_signal(devs, queue=queue).store(make_signal_value(devs, queue=queue).index(UOp.const(dtypes.int, 0)) + q.tag),)
return submit.replace(src=(q.replace(src=src, tag=None),))
pm_add_inner_stores = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_stores)])
# *****************
@ -473,9 +477,9 @@ def hcq_schedule(linear:UOp) -> UOp:
linear = graph_rewrite(linear, pm_prep_runtime, name="prepare runtime")
linear = graph_rewrite(linear, pm_lower_ops, name="lower ops into hcq ir")
linear = graph_rewrite(linear, pm_schedule_inner_sync, ctx=(deps_ctx:=DepsCtx()), walk=True, name="schedule inner sync", enter_calls=True)
linear = graph_rewrite(linear, pm_schedule_inner_sync, ctx=(deps_ctx:=DepsCtx()), name="schedule inner sync")
linear = graph_rewrite(linear, pm_add_finalizer, ctx=deps_ctx, walk=True, name="add finalizer")
linear = graph_rewrite(linear, pm_add_inner_loads + pm_flatten_linear, ctx=(waited:=set()), walk=True, name="add loads", enter_calls=True)
linear = graph_rewrite(linear, pm_add_inner_loads, ctx=(waited:=set()), walk=True, name="add loads", enter_calls=True)
linear = graph_rewrite(linear, pm_add_inner_stores, ctx=waited, walk=True, name="add stores", enter_calls=True)
linear = graph_rewrite(linear, pm_merge_queues, name="merge queues")
linear = graph_rewrite(linear, pm_add_global_sync, ctx=set(), walk=True, name="add global sync", enter_calls=True)

View file

@ -272,7 +272,7 @@ class PCIIfaceBase:
return HCQBuffer(mapping.va_addr, size, view=barview, meta=PCIAllocationMeta(mapping, cpu_access, hMemory=mapping.paddrs[0][0]), owner=self.dev)
def free(self, b:HCQBuffer):
if b.owner != self.dev: self.dev.iface.dev_impl.mm.unmap_range(b.va_addr, b.size)
if b.owner != self.dev: self.dev.iface.dev_impl.mm.unmap_range(b.va_addr, round_up(b.size, 0x1000))
if b.owner == self.dev and b.meta.mapping.aspace is AddrSpace.PHYS: self.dev_impl.mm.vfree(b.meta.mapping)
if b.owner == self.dev and self.is_local() and b.meta.has_cpu_mapping: FileIOInterface.munmap(b.va_addr, b.size)