mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
hcq2: fix multi (#16661)
This commit is contained in:
parent
5989d0b150
commit
eda0a402d1
3 changed files with 56 additions and 33 deletions
|
|
@ -7,7 +7,7 @@ from tinygrad.helpers import to_tuple, round_up
|
|||
from tinygrad.device import Device, Buffer, BufferSpec, Compiled, LRUAllocator, MultiBuffer
|
||||
from tinygrad.uop.ops import Ops, sint, UOp, UPat, PatternMatcher, KernelInfo, graph_rewrite, track_rewrites, GroupOp
|
||||
from tinygrad.uop.symbolic import symbolic_simple, symbolic
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.dtype import dtypes, AddrSpace
|
||||
from tinygrad.runtime.support.hcq import MMIOInterface
|
||||
from tinygrad.renderer import Renderer, Estimates
|
||||
from tinygrad.engine.realize import to_program, get_call_arg_uops, get_call_name, get_call_outs_ins, estimate_uop, pm_flatten_linear
|
||||
|
|
@ -288,18 +288,25 @@ def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
|
|||
new_src.append(call)
|
||||
continue
|
||||
|
||||
q = get_submit(call.src[0]).src[0]
|
||||
new_q = ctx.last_per_queue[q.arg] = q.rtag(next(ctx.opid))
|
||||
new_q = ctx.last_per_queue[q.arg] = (q:=get_submit(call.src[0]).src[0]).rtag(next(ctx.opid))
|
||||
qdevs, refs = to_tuple(new_q.arg[0]), get_call_arg_uops(call)
|
||||
|
||||
refs = get_call_arg_uops(call)
|
||||
deps = dedup(flatten(ctx.deps.access_resources([get_dep_buf(ctx, b, l) for b in refs], call.arg.aux.outs, new_q) for l in range(len(q.arg[0]))))
|
||||
# per-lane deps, tracked per (device, queue). skip self
|
||||
dep_lanes:list[tuple[UOp, int]] = []
|
||||
for lane, d in enumerate(qdevs):
|
||||
for dep in ctx.deps.access_resources([get_dep_buf(ctx, b, lane) for b in refs], call.arg.aux.outs, new_q.replace(arg=(d, new_q.arg[1]))):
|
||||
if dep.tag != new_q.tag: dep_lanes.append((dep, lane))
|
||||
|
||||
# optims: keep only the max wait per queue, and drop self-queue waits when the queue self-orders
|
||||
deps = {dep.arg:dep for dep in sorted(deps, key=lambda x: x.tag)}
|
||||
if to_tuple(new_q.arg[0])[0].split(":")[0] in {"AMD", "QCOM"} or new_q.arg[1].startswith("COPY"):
|
||||
deps.pop(new_q.arg, None)
|
||||
# drop self-queue waits, queue self-orders
|
||||
if qdevs[0].split(":")[0] in {"AMD", "QCOM"} or new_q.arg[1].startswith("COPY"):
|
||||
dep_lanes = [(dep, lane) for dep, lane in dep_lanes if dep.arg != (qdevs[lane], new_q.arg[1])]
|
||||
|
||||
new_q = new_q.after(*deps.values()).rtag("deps") if deps else new_q
|
||||
# keep latest dep per lane, group lanes
|
||||
latest = {(dep.arg, lane): dep for dep, lane in sorted(dep_lanes, key=lambda x: x[0].tag)}
|
||||
deps:dict[UOp, tuple[int, ...]] = collections.defaultdict(tuple)
|
||||
for (_, lane), dep in latest.items(): deps[dep] += (lane,)
|
||||
|
||||
if deps: new_q = new_q.after(*deps, arg=tuple(deps.values())).rtag("deps")
|
||||
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}), *call.src[1:])))
|
||||
return linear.replace(src=tuple(new_src))
|
||||
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)])
|
||||
|
|
@ -314,7 +321,10 @@ def make_finalizer(queues:list[UOp], nbump:int) -> UOp:
|
|||
|
||||
# queue is inc with deps
|
||||
submit = make_submit(make_signal(devs).store(tl.index(zero)), devs=devs, queue="COMPUTE:0")
|
||||
submit = submit.replace(src=(submit.src[0].after(*queues).rtag("deps"),))
|
||||
|
||||
# split each (multi-device) queue into per-device deps so each finalizer lane waits on the matching device's signal
|
||||
lane_queues = [(q.replace(arg=(d, q.arg[1])), (devs.index(d),)) for q in queues for d in to_tuple(q.arg[0])]
|
||||
submit = submit.replace(src=(submit.src[0].after(*(q for q, _ in lane_queues), arg=tuple(l for _, l in lane_queues)).rtag("deps"),))
|
||||
|
||||
upd = [(tl, 1)] + [(make_signal_value(devs, queue=qn), nbump) for qn in dedup([q.arg[1] for q in queues])]
|
||||
patches = [s.after(submit).index(zero, dtype=s.dtype.ptr()).store(s.index(zero) + inc) for s, inc in upd]
|
||||
|
|
@ -335,12 +345,14 @@ def add_loads(ctx:set[int], deps:UOp) -> UOp:
|
|||
cur_devs = to_tuple((cur:=deps.src[0]).arg[0])
|
||||
|
||||
waits = []
|
||||
for dep in deps.src[1:]:
|
||||
devs, queue = dep.arg
|
||||
for lanes, dep in zip(deps.arg, deps.src[1:]):
|
||||
dep_dev, queue = dep.arg # dep_dev is a single device (deps are recorded per-device)
|
||||
ctx.add(dep.tag) # mark op to update signal.
|
||||
|
||||
sig = make_mstack([make_signal(d, queue=queue, sentinel=d not in devs) for d in cur_devs])
|
||||
val = make_signal_value(cur_devs, queue=queue).index(UOp.const(dtypes.int, 0))
|
||||
# for lanes that need this dep, wait on the dep device's signal/value; other lanes get a passing sentinel
|
||||
lanes = set(lanes)
|
||||
sig = make_mstack([make_signal(dep_dev if j in lanes else d, queue=queue, sentinel=j not in lanes) for j, d in enumerate(cur_devs)])
|
||||
val = make_mstack([make_signal_value(dep_dev if j in lanes else d, queue=queue) for j, d in enumerate(cur_devs)]).index(UOp.const(dtypes.int, 0))
|
||||
waits.append(sig.wait(val + dep.tag))
|
||||
return cur.replace(src=tuple(waits) + cur.src)
|
||||
pm_add_inner_loads = PatternMatcher([(UPat(Ops.AFTER, tag="deps", name="deps"), add_loads)])
|
||||
|
|
@ -371,7 +383,8 @@ def merge_queues(linear:UOp) -> UOp:
|
|||
opened_qs:dict[tuple[tuple[str, ...], str], tuple[list[UOp], HCQInfo]] = {} # (devs, queue) -> (sinks, aux), kept in submit order
|
||||
|
||||
for call in linear.src:
|
||||
if call.tag != "hcq":
|
||||
# finalizer cannot be merged, since it bumps inner signal (this introduces race when multidevs).
|
||||
if call.tag != "hcq" or (call.tag == "hcq" and call.arg.aux.name == "hcq finalizer"):
|
||||
new_src += [merge_sink((sa:=opened_qs.pop(k))[0]).call(aux=sa[1]).rtag("hcq") for k in list(opened_qs)] + [call]
|
||||
continue
|
||||
|
||||
|
|
@ -410,7 +423,7 @@ pm_annotate_devs = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"),
|
|||
# 4.4. replace params with per-submit input address loads
|
||||
|
||||
def replace_params(call:UOp) -> UOp|None:
|
||||
if not (params:={u:u.arg.slot for u in call.src[0].toposort() if u.op is Ops.PARAM and u.addrspace is not None}): return None
|
||||
if not (params:={u:u.arg.slot for u in call.src[0].toposort() if u.op is Ops.PARAM and u.addrspace is AddrSpace.GLOBAL}): return None
|
||||
|
||||
# fill new info
|
||||
hcqinfo = replace(call.arg.aux, params=tuple(sorted(set(params.values()))), inputs=len(get_call_arg_uops(call)))
|
||||
|
|
@ -504,8 +517,11 @@ def fold_const_store(buf:UOp, off:UOp, val:UOp) -> UOp:
|
|||
|
||||
def resolve_getaddr(buf:UOp, g:UOp) -> UOp:
|
||||
if buf.op not in (Ops.BUFFER, Ops.MSTACK, Ops.MSELECT): return buf
|
||||
if isinstance(b:=buf.buffer, Buffer): return UOp.const(dtypes.uint64, b.get_buf(g.src[1].arg).va_addr)
|
||||
return UOp(Ops.STACK, dtypes.uint64.vec(len(b.bufs)), tuple(UOp.const(dtypes.uint64, x.ensure_allocated()._buf.va_addr) for x in b.bufs))
|
||||
devs, b = to_tuple(g.src[1].arg), buf.buffer
|
||||
bufs = tuple(cast(Buffer, x.buffer) for x in buf.src) if buf.op is Ops.MSTACK else tuple(b.bufs if isinstance(b, MultiBuffer) else (b,)*len(devs))
|
||||
assert len(bufs) == len(devs), f"can't resolve {len(bufs)} buffers on {len(devs)} devices"
|
||||
addrs = tuple(UOp.const(dtypes.uint64, x.get_buf(d).va_addr) for x, d in zip(bufs, devs))
|
||||
return addrs[0] if len(addrs) == 1 else UOp(Ops.STACK, dtypes.uint64.vec(len(addrs)), addrs)
|
||||
|
||||
def resolve_getaddr_slice(bv:UOp, dev:UOp) -> UOp:
|
||||
itemsize = bv.src[0].dtype.itemsize if unwrap_after(bv.src[0]).op in (Ops.BUFFER, Ops.SLICE, Ops.MSTACK, Ops.MSELECT) else bv.dtype.itemsize
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from __future__ import annotations
|
||||
from typing import cast
|
||||
from typing import cast, Any, Callable
|
||||
import os, ctypes, struct, hashlib, functools, importlib, mmap, errno, array, contextlib, sys, weakref, itertools, collections, atexit
|
||||
assert sys.platform != 'win32'
|
||||
from dataclasses import dataclass
|
||||
|
|
@ -89,25 +89,25 @@ def memory_barrier(ctx):
|
|||
reg_done=getattr(ctx.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_DONE').addr[0], value=0xffffffff),
|
||||
acquire_mem(ctx)))
|
||||
|
||||
def pm4_wait(ctx, dst, val): return wait_reg_mem(ctx, val, mem=make_getaddr(dst, ctx.device))
|
||||
def pm4_wait(ctx, dst, val): return wait_reg_mem(ctx, val, mem=make_getaddr(dst, ctx.devs))
|
||||
|
||||
def pm4_barrier(ctx): return memory_barrier(ctx)
|
||||
|
||||
def pm4_store(ctx, dst, val):
|
||||
if val.op is Ops.BINARY: return None
|
||||
return release_mem(ctx, make_getaddr(dst, ctx.device), val, ctx.pm4.data_sel__mec_release_mem__send_32_bit_low,
|
||||
return release_mem(ctx, make_getaddr(dst, ctx.devs), val, ctx.pm4.data_sel__mec_release_mem__send_32_bit_low,
|
||||
ctx.pm4.int_sel__mec_release_mem__send_interrupt_after_write_confirm, cache_flush=True)
|
||||
|
||||
def pm4_timestamp(ctx, dst):
|
||||
return release_mem(ctx, make_getaddr(dst, ctx.device), 0, ctx.pm4.data_sel__mec_release_mem__send_gpu_clock_counter,
|
||||
return release_mem(ctx, make_getaddr(dst, ctx.devs), 0, ctx.pm4.data_sel__mec_release_mem__send_gpu_clock_counter,
|
||||
ctx.pm4.int_sel__mec_release_mem__none)
|
||||
|
||||
def pm4_program(ctx, prg):
|
||||
data, info = prg.arg
|
||||
lib_gpu, args = prg.src
|
||||
prog_addr = make_getaddr(lib_gpu, ctx.device) + data.entry_point_offset
|
||||
scratch_addr = make_getaddr(UOp.new_buffer(lib_gpu.device, data.private_segment_size, dtypes.uint8).rtag("scratch"), ctx.device)
|
||||
args_addr = make_getaddr(args, ctx.device)
|
||||
prog_addr = make_getaddr(lib_gpu, ctx.devs) + data.entry_point_offset
|
||||
scratch_addr = make_getaddr(UOp.new_buffer(lib_gpu.device, data.private_segment_size, dtypes.uint8).rtag("scratch"), ctx.devs)
|
||||
args_addr = make_getaddr(args, ctx.devs)
|
||||
|
||||
user_regs = []
|
||||
if data.enable_private_segment_sgpr:
|
||||
|
|
@ -174,7 +174,7 @@ pm_pm4_submit = PatternMatcher([(UPat(Ops.LINEAR, name="lin"),
|
|||
class SDMAOps(FastEnum): COPY = auto(); POLL_REGMEM = auto(); FENCE = auto(); TRAP = auto(); TIMESTAMP = auto() # noqa: E702
|
||||
|
||||
def sdma_copy(ctx, dst, src, copy):
|
||||
src_addr, dst_addr = make_getaddr(src, ctx.device), make_getaddr(dst, ctx.device)
|
||||
src_addr, dst_addr = make_getaddr(src, ctx.devs), make_getaddr(dst, ctx.devs)
|
||||
return UOp(Ops.LINEAR, dtypes.void, tuple([make_ins(SDMAOps.COPY,
|
||||
ctx.sdma.SDMA_OP_COPY | ctx.sdma.SDMA_PKT_COPY_LINEAR_HEADER_SUB_OP(ctx.sdma.SDMA_SUBOP_COPY_LINEAR),
|
||||
ctx.sdma.SDMA_PKT_COPY_LINEAR_COUNT_COUNT(min(copy.arg - off, ctx.max_copy_size) - 1), 0,
|
||||
|
|
@ -183,17 +183,17 @@ def sdma_copy(ctx, dst, src, copy):
|
|||
def sdma_wait(ctx, dst, val):
|
||||
op = ctx.sdma.SDMA_OP_POLL_REGMEM | ctx.sdma.SDMA_PKT_POLL_REGMEM_HEADER_FUNC(WAIT_REG_MEM_FUNCTION_GEQ) \
|
||||
| ctx.sdma.SDMA_PKT_POLL_REGMEM_HEADER_MEM_POLL(1)
|
||||
return make_ins(SDMAOps.POLL_REGMEM, op, *data64_le(make_getaddr(dst, ctx.device)), val, 0xffffffff,
|
||||
return make_ins(SDMAOps.POLL_REGMEM, op, *data64_le(make_getaddr(dst, ctx.devs)), val, 0xffffffff,
|
||||
ctx.sdma.SDMA_PKT_POLL_REGMEM_DW5_INTERVAL(0x04) | ctx.sdma.SDMA_PKT_POLL_REGMEM_DW5_RETRY_COUNT(0xfff))
|
||||
|
||||
def sdma_store(ctx, dst, val):
|
||||
op = ctx.sdma.SDMA_OP_FENCE | (ctx.sdma.SDMA_PKT_FENCE_HEADER_MTYPE(3) if ctx.target[0] != 9 else 0)
|
||||
return UOp(Ops.LINEAR, dtypes.void, (
|
||||
make_ins(SDMAOps.FENCE, op, *data64_le(make_getaddr(dst, ctx.device)), val), make_ins(SDMAOps.TRAP, ctx.sdma.SDMA_OP_TRAP, 0)))
|
||||
make_ins(SDMAOps.FENCE, op, *data64_le(make_getaddr(dst, ctx.devs)), val), make_ins(SDMAOps.TRAP, ctx.sdma.SDMA_OP_TRAP, 0)))
|
||||
|
||||
def sdma_timestamp(ctx, dst):
|
||||
op = ctx.sdma.SDMA_OP_TIMESTAMP | ctx.sdma.SDMA_PKT_TIMESTAMP_GET_HEADER_SUB_OP(ctx.sdma.SDMA_SUBOP_TIMESTAMP_GET_GLOBAL)
|
||||
return make_ins(SDMAOps.TIMESTAMP, op, *data64_le(make_getaddr(dst, ctx.device)))
|
||||
return make_ins(SDMAOps.TIMESTAMP, op, *data64_le(make_getaddr(dst, ctx.devs)))
|
||||
|
||||
pm_sdma_opsel = PatternMatcher([
|
||||
(UPat(Ops.BARRIER), lambda: UOp(Ops.NOOP, dtypes.void, ())),
|
||||
|
|
@ -516,11 +516,17 @@ class PCIIface(PCIIfaceBase):
|
|||
|
||||
def _mock(iface, name=None): return type(name or f"MOCK{iface.__name__}", (iface,), {})
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AMDEncodeCtx: # encode-time constants for one queue: devs (every cmdbuf address resolves into these) + gfx version + packet/ip modules
|
||||
devs: tuple[str, ...]; target: tuple[int, ...]; pm4: Any; sdma: Any; soc: Any # noqa: E702
|
||||
gc: AMDIP; nbio: AMDIP; xccs: int; max_copy_size: int; tmpring_size: Callable # noqa: E702
|
||||
|
||||
def encode_queue(q:UOp) -> UOp|None:
|
||||
if not (isinstance(q.arg, tuple) and len(q.arg) == 2 and isinstance(q.arg[1], str) and q.arg[1].startswith(("COMPUTE", "COPY"))): return None
|
||||
devs = to_tuple(q.arg[0])
|
||||
d = Device[(devs:=to_tuple(q.arg[0]))[0]]
|
||||
ctx = AMDEncodeCtx(devs, d.target, d.pm4, d.sdma, d.soc, d.gc, d.nbio, d.xccs, d.max_copy_size, d.tmpring_size)
|
||||
opsel, submit = (pm_pm4_opsel, pm_pm4_submit) if q.arg[1].startswith("COMPUTE") else (pm_sdma_opsel, pm_sdma_submit)
|
||||
return submit.rewrite(graph_rewrite(q, opsel + pm_flatten_linear, walk=True, ctx=Device[devs[0]], name=f"{q.arg[1]} opsel"))
|
||||
return submit.rewrite(graph_rewrite(q, opsel + pm_flatten_linear, walk=True, ctx=ctx, name=f"{q.arg[1]} opsel"))
|
||||
|
||||
pm_lower = PatternMatcher([
|
||||
(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),)), encode_queue),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue