mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
41aa2fe119
commit
4a0488ae97
1 changed files with 23 additions and 17 deletions
|
|
@ -27,7 +27,7 @@ class HCQ2Compiled(Compiled):
|
|||
(UPat(Ops.BUFFER, tag="timeline_value"), lambda ctx: ctx.timeline_value()),
|
||||
(UPat(Ops.BUFFER, tag="sentinel_signal"), lambda ctx: ctx.timeline_signal("sentinel", (1 << 64) - 1)),
|
||||
(UPat(Ops.BUFFER, name="b"), lambda ctx, b:
|
||||
Buffer(ctx.device, b.arg, b.dtype, options=BufferSpec(host=True, uncached=True, cpu_access=True, nolru=True))), # TODO: remove nolru
|
||||
Buffer(ctx.device, b.arg, b.dtype, options=BufferSpec(host=False, uncached=True, cpu_access=True, nolru=True))), # TODO: remove nolru
|
||||
])
|
||||
|
||||
super().__init__(device, allocator, compilers, lambda *a, **kw: None, None, arch=arch)
|
||||
|
|
@ -294,7 +294,12 @@ def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
|
|||
refs = get_call_arg_uops(call)
|
||||
deps = dedup(flatten(ctx.deps.access_resources([get_dep_buf(ctx, b, l) for b in refs], call.arg.aux.outs, new_q) for l in range(len(q.arg[0]))))
|
||||
|
||||
new_q = new_q.after(*deps).rtag("deps") if deps else new_q
|
||||
# optims: keep only the max wait per queue, and drop self-queue waits when the queue self-orders
|
||||
deps = {dep.arg:dep for dep in sorted(deps, key=lambda x: x.tag)}
|
||||
if to_tuple(new_q.arg[0])[0].split(":")[0] in {"AMD", "QCOM"} or new_q.arg[1].startswith("COPY"):
|
||||
deps.pop(new_q.arg, None)
|
||||
|
||||
new_q = new_q.after(*deps.values()).rtag("deps") if deps else new_q
|
||||
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}), *call.src[1:])))
|
||||
return linear.replace(src=tuple(new_src))
|
||||
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)])
|
||||
|
|
@ -353,34 +358,35 @@ pm_add_inner_stores = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", s
|
|||
|
||||
def get_submit(ast:UOp) -> UOp: return next(u for u in ast.toposort() if u.op is Ops.CUSTOM_FUNCTION and u.arg == "submit")
|
||||
|
||||
def merge_sinks(old_sink:UOp, new_sink:UOp) -> UOp:
|
||||
old_submit, new_submit = get_submit(old_sink), get_submit(new_sink)
|
||||
old_queue, new_queue = old_submit.src[0], new_submit.src[0]
|
||||
merged_submit = new_submit.replace(src=(new_queue.replace(src=old_queue.src + new_queue.src),))
|
||||
old_root = old_sink.src[0].substitute({old_submit: merged_submit})
|
||||
new_anchor = merged_submit if old_sink.src[0] is old_submit else old_root
|
||||
return new_sink.substitute({new_submit: new_anchor})
|
||||
def merge_sink(sinks:list[UOp]) -> UOp:
|
||||
if len(sinks) == 1: return sinks[0]
|
||||
submits = [get_submit(sink) for sink in sinks]
|
||||
queues = [submit.src[0] for submit in submits]
|
||||
anchor = submits[-1].replace(src=(queues[-1].replace(src=tuple(x for q in queues for x in q.src)),))
|
||||
for sink, submit in zip(sinks[:-1], submits[:-1]):
|
||||
if sink.src[0] is not submit: anchor = sink.src[0].substitute({submit: anchor}, walk=True)
|
||||
return sinks[-1].substitute({submits[-1]: anchor}, walk=True)
|
||||
|
||||
def merge_queues(linear:UOp) -> UOp:
|
||||
new_src:list[UOp] = []
|
||||
opened_qs:dict[tuple[tuple[str, ...], str], tuple[UOp, HCQInfo]] = {} # (devs, queue) -> (sink, aux), kept in submit order
|
||||
opened_qs:dict[tuple[tuple[str, ...], str], tuple[list[UOp], HCQInfo]] = {} # (devs, queue) -> (sinks, aux), kept in submit order
|
||||
|
||||
for call in linear.src:
|
||||
if call.tag != "hcq":
|
||||
new_src += [(sa:=opened_qs.pop(k))[0].call(aux=sa[1]).rtag('hcq') for k in list(opened_qs)] + [call]
|
||||
new_src += [merge_sink((sa:=opened_qs.pop(k))[0]).call(aux=sa[1]).rtag("hcq") for k in list(opened_qs)] + [call]
|
||||
continue
|
||||
|
||||
devs, queue = get_submit(new_sink:=call.src[0]).src[0].arg
|
||||
aux = call.arg.aux
|
||||
new_rec = ([new_sink], call.arg.aux)
|
||||
if (old:=opened_qs.pop((devs, queue), None)) is not None:
|
||||
new_sink = merge_sinks(old[0], new_sink) # exact same queue: merge, and re-insert at the end
|
||||
aux = replace(aux, name=f"{queue.lower()} submit", estimates=old[1].estimates + aux.estimates)
|
||||
new_rec = (old[0] + [new_sink], replace(new_rec[1], name=f"{queue.lower()} submit", estimates=old[1].estimates + new_rec[1].estimates))
|
||||
else:
|
||||
# no such queue opened: close every open submit on this queue that shares a device, so submit order is kept
|
||||
new_src += [(sa:=opened_qs.pop(k))[0].call(aux=sa[1]).rtag('hcq') for k in [k for k in opened_qs if k[1] == queue and set(k[0]) & set(devs)]]
|
||||
opened_qs[(devs, queue)] = (new_sink, aux)
|
||||
closing = [k for k in opened_qs if k[1] == queue and set(k[0]) & set(devs)]
|
||||
new_src += [merge_sink((sa:=opened_qs.pop(k))[0]).call(aux=sa[1]).rtag("hcq") for k in closing]
|
||||
opened_qs[(devs, queue)] = new_rec
|
||||
|
||||
return linear.replace(src=tuple(new_src + [sink.call(aux=aux).rtag('hcq') for sink, aux in opened_qs.values()]))
|
||||
return linear.replace(src=tuple(new_src + [merge_sink(sinks).call(aux=aux).rtag("hcq") for sinks, aux in opened_qs.values()]))
|
||||
pm_merge_queues = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), merge_queues)])
|
||||
|
||||
# *****************
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue