mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
598 lines
32 KiB
Python
598 lines
32 KiB
Python
from __future__ import annotations
|
|
from typing import cast, Callable, TypeVar, Generic, Any
|
|
import struct, functools, time, collections, importlib, itertools, weakref
|
|
from dataclasses import replace, dataclass, field
|
|
from tinygrad.helpers import DEV, getenv, select_first_inited, select_by_name, suppress_finalizing, DEBUG, dedup, flatten, pluralize
|
|
from tinygrad.helpers import to_tuple, round_up
|
|
from tinygrad.device import Device, Buffer, BufferSpec, Compiled, LRUAllocator, MultiBuffer
|
|
from tinygrad.uop.ops import Ops, sint, UOp, UPat, PatternMatcher, KernelInfo, graph_rewrite, track_rewrites, GroupOp
|
|
from tinygrad.uop.symbolic import symbolic_simple, symbolic
|
|
from tinygrad.dtype import dtypes, AddrSpace
|
|
from tinygrad.runtime.support.hcq import MMIOInterface
|
|
from tinygrad.renderer import Renderer, Estimates
|
|
from tinygrad.engine.realize import to_program, get_call_arg_uops, get_call_name, get_call_outs_ins, estimate_uop, pm_flatten_linear
|
|
from tinygrad.engine.jit import DepsTracker
|
|
|
|
HCQDeviceType = TypeVar('HCQDeviceType', bound='HCQ2Compiled')

class HCQ2Compiled(Compiled):
  """Base device class for HCQ2 backends.

  Owns the per-device timeline buffers and the default `pm_bufferize` matcher that turns tagged
  placeholder BUFFER uops into real `Buffer`s. Subclasses provide `ifaces` for `_select_iface`;
  `iface` and `on_device_hang` are used here but not defined in this class.
  """
  timestamp_divider: float = 1000.0 # GPU timestamp counter ticks per microsecond; override per device

  def __init__(self, device:str, allocator:'HCQAllocator', compilers:list[type[Renderer]], runtime, can_recover:bool=False, arch=None):
    # "NAME:idx" selects device idx, a bare "NAME" is device 0
    if ":" in device: self.device_id:int = int(device.split(":")[1])
    else: self.device_id = 0

    # default pm bufferize: timeline placeholders resolve to the memoized per-device buffers,
    # every other placeholder gets a fresh device buffer
    self.pm_bufferize = PatternMatcher([
      (UPat(Ops.BUFFER, tag="timeline_signal"), lambda ctx: ctx.timeline_signal()),
      (UPat(Ops.BUFFER, tag="timeline_value"), lambda ctx: ctx.timeline_value()),
      (UPat(Ops.BUFFER, tag="sentinel_signal"), lambda ctx: ctx.timeline_signal("sentinel", (1 << 64) - 1)),
      (UPat(Ops.BUFFER, name="b"), lambda ctx, b:
        Buffer(ctx.device, b.arg, b.dtype, options=BufferSpec(host=False, uncached=True, cpu_access=True, nolru=True))), # TODO: remove nolru
    ])

    # NOTE(review): `runtime` and `can_recover` are accepted but not forwarded; a no-op runtime is installed instead
    super().__init__(device, allocator, compilers, lambda *a, **kw: None, None, arch=arch)

  @functools.cache
  def timeline_signal(self, queue:str|None=None, init_value:int=0) -> Buffer:
    """Return the memoized one-element uint64 signal buffer, initialised to `init_value`.

    `queue` is not read in the body; it only distinguishes cache entries.
    """
    signal = Buffer(self.device, 1, dtypes.uint64, options=BufferSpec(host=True, uncached=True, cpu_access=True), preallocate=True)
    signal._buf.cpu_view().mv.cast('Q')[0] = init_value
    return signal

  @functools.cache
  def timeline_value(self, queue:str|None=None, init_value:int=1) -> Buffer:
    """Return the memoized one-element uint64 CPU buffer holding the timeline value, initialised to `init_value`."""
    value = Buffer("CPU", 1, dtypes.uint64, preallocate=True)
    value.as_memoryview(force_zero_copy=True).cast('Q')[0] = init_value
    return value

  @functools.cached_property
  def timestamps_buf(self) -> Buffer:
    """A 0x1000-byte CPU-accessible device buffer, created on first access."""
    return Buffer(self.device, 0x1000, dtypes.uint8, options=BufferSpec(cpu_access=True), preallocate=True)

  def synchronize(self, timeout:int|None=None):
    """Spin until the timeline signal reaches `timeline_value - 1`.

    `timeout` is in milliseconds (3000 when None or 0). While the limit is exceeded,
    `self.on_device_hang()` is called on every iteration. Does nothing without an `iface`.
    """
    if not hasattr(self, 'iface'): return
    signal = self.timeline_signal()._buf.cpu_view().mv.cast('Q')
    value = self.timeline_value().as_memoryview(force_zero_copy=True).cast('Q')
    limit_s = (timeout or 3000) / 1000
    started = time.perf_counter()
    while signal[0] < value[0] - 1:
      if time.perf_counter() - started > limit_s: self.on_device_hang()

  def device_props(self) -> dict[str,Any]: return {} # to be overridden if needed. dict keys are backend dependent.

  def count(self) -> int:
    """Number of devices reported by the interface, or 1 when no interface is attached."""
    if hasattr(self, 'iface'): return self.iface.count
    return 1

  def _select_iface(self):
    """Pick and initialise the first usable interface class from `self.ifaces`.

    The backend name is the class name without its last six characters; interface names are the
    interface class names without their last five. Mock interfaces are only eligible when the
    requested interface name itself starts with "MOCK".
    """
    dev = type(self).__name__[:-6]
    assert (v:=getenv(k:=f'{dev.upper()}_IFACE', "")) == "", \
      f"{k}={v} is deprecated, use DEV={replace(DEV.target(dev), interface=v)} instead"
    assert hasattr(self, "ifaces"), "must have ifaces to select an iface"
    t = DEV.target(dev)
    candidates = select_by_name(self.ifaces, lambda i: i.__name__[:-5], t.interface, f"{dev} has no interface {t.interface!r}")
    # never fall back to mock ifaces
    candidates = [i for i in candidates if t.interface.startswith("MOCK") or not i.__name__[:-5].startswith("MOCK")]
    factories = [functools.partial(cast(Callable, iface), self, self.device_id) for iface in candidates]
    return select_first_inited(factories, f"No interface for {dev}:{self.device_id} is available")

  def _is_cpu(self) -> bool:
    """True when `self.device` exists and its name before ':' is "CPU"."""
    if not hasattr(self, 'device'): return False
    return self.device.split(":")[0] == "CPU"

  def finalize(self):
    """Synchronize (best effort) and then let the interface release its resources."""
    # try to finalize the device in any case
    try: self.synchronize()
    except RuntimeError as e: print(f"{self.device} synchronization failed before finalizing: {e}")

    # if the device has an interface, call device_fini to clean up resources
    if hasattr(self, 'iface') and hasattr(self.iface, 'device_fini'): self.iface.device_fini()
|
|
|
|
class HCQ2Buffer:
  """A device allocation (or a sub-range of one) described by address, size and an optional CPU view."""
  def __init__(self, va_addr:sint, size:int, meta:Any=None, _base:HCQ2Buffer|None=None, view:MMIOInterface|None=None, owner:HCQ2Compiled|None=None):
    self.va_addr, self.size = va_addr, size
    self.meta, self._base = meta, _base
    self.view, self.owner = view, owner

  def offset(self, offset:int=0, size:int|None=None) -> HCQ2Buffer:
    """Return a sub-buffer starting `offset` bytes in.

    A falsy `size` (None or 0) means "everything after offset". The result shares `meta` and `owner`,
    and its base is this buffer's root allocation.
    """
    sub_view = None if self.view is None else self.view.view(offset=offset, size=size)
    sub_size = size or (self.size - offset)
    return HCQ2Buffer(self.va_addr+offset, sub_size, meta=self.meta, _base=self.base, view=sub_view, owner=self.owner)

  def cpu_view(self) -> MMIOInterface:
    """Return the CPU view; asserts that one exists."""
    assert self.view is not None, "buffer has no cpu_view"
    return self.view

  @property
  def base(self) -> HCQ2Buffer:
    """The root allocation this buffer was sliced from, or the buffer itself."""
    return self._base or self
|
|
|
|
class HCQAllocator(LRUAllocator[HCQDeviceType], Generic[HCQDeviceType]):
  """LRU allocator for HCQ2 devices; host copies are staged through a host buffer and run as device copies.

  Subclasses may provide `_do_map` and `_do_free`; both are looked up with `hasattr`.
  """
  def _map(self, buf:HCQ2Buffer) -> HCQ2Buffer:
    if hasattr(self, '_do_map'): return self._do_map(buf)
    raise NotImplementedError("map failed: no method implemented")

  @suppress_finalizing
  def _free(self, buf:HCQ2Buffer, options:BufferSpec|None=None):
    # the device is synchronized first, even when nothing ends up being freed
    self.dev.synchronize()
    # buffers wrapping an external pointer are never released here
    is_external = options is not None and options.external_ptr is not None
    if is_external or not hasattr(self, '_do_free'): return
    self._do_free(buf, options)

  def _unmap(self, mb):
    self.dev.synchronize()
    self.dev.iface.free(mb)

  def _offset(self, buf, size:int, offset:int) -> HCQ2Buffer:
    return buf.offset(offset=offset, size=size)

  def _wrap(self, dev:str, sz:int, opaque:HCQ2Buffer) -> Buffer:
    """Wrap an existing HCQ2Buffer into a uint8 `Buffer` flagged with `external_ptr=1`."""
    return Buffer(dev, sz, dtypes.uint8, opaque=opaque, options=BufferSpec(external_ptr=1))

  def _copy(self, dst:Buffer, src:Buffer):
    """Run a one-call LINEAR that copies `src` to `dst`'s device, without updating stats."""
    from tinygrad.engine.realize import run_linear
    src_uop = UOp.from_buffer(src)
    copy_call = src_uop.copy_to_device(dst.device).call(UOp.from_buffer(dst), src_uop)
    run_linear(UOp(Ops.LINEAR, dtypes.void, (copy_call,)), update_stats=False)

  def _copyin(self, dest:HCQ2Buffer, src:memoryview):
    nbytes = len(src)
    staging = Buffer(self.dev.device, nbytes, dtypes.uint8, options=BufferSpec(host=True), preallocate=True)
    staging._buf.cpu_view()[:nbytes] = src
    self._copy(self._wrap(self.dev.device, nbytes, dest), staging)

  def _copyout(self, dest:memoryview, src:HCQ2Buffer):
    nbytes = len(dest)
    staging = Buffer(self.dev.device, nbytes, dtypes.uint8, options=BufferSpec(host=True), preallocate=True)
    self._copy(staging, self._wrap(self.dev.device, nbytes, src))
    # the copy must have landed in the staging buffer before it is read on the host
    self.dev.synchronize()
    dest[:] = staging._buf.cpu_view()[:nbytes]
|
|
|
|
# def _as_buffer(self, buf): return buf.cpu_view().mv
|
|
|
|
def unwrap_after(uop):
  """Follow `src[0]` through any chain of AFTER uops and return the first non-AFTER node."""
  node = uop
  while node.op is Ops.AFTER:
    node = node.src[0]
  return node
|
|
|
|
def make_getaddr(u, device=None):
  """Wrap buffer-like uops in GETADDR(u, DEVICE); anything else is returned unchanged.

  The DEVICE arg is `device` when truthy, otherwise the first device of `u`.
  """
  addressable = (Ops.BUFFER, Ops.SLICE, Ops.BINARY, Ops.MSTACK, Ops.MSELECT)
  if unwrap_after(u).op not in addressable: return u
  target = device or to_tuple(u.device)[0]
  return UOp(Ops.GETADDR, dtypes.uint64, src=(u, UOp(Ops.DEVICE, arg=target)))
|
|
|
|
def make_ins(op, *srcs):
  """Build an INS uop with arg `op`; int operands become uint32 consts, uop operands are cast to uint32."""
  words = []
  for s in srcs:
    if isinstance(s, int): words.append(UOp.const(dtypes.uint32, s))
    else: words.append(s.cast(dtypes.uint32))
  return UOp(Ops.INS, dtypes.void, tuple(words), op)
|
|
|
|
def make_cmdbuf(lin, devs, tag):
  """Serialize the operands of every instruction in `lin` into a command-buffer placeholder.

  CONST operands are packed little-endian into the blob. Every other operand is packed as 0 and
  recorded as a patch: a store of that operand at its byte offset. Returns the tagged uint8
  buffer AFTER the blob store and all patch stores.
  """
  # bytearray: appending to immutable bytes would copy the whole blob for every operand
  blob, patches = bytearray(), []
  for ins in lin.src:
    for s in ins.src:
      is_const = s.op is Ops.CONST
      if not is_const: patches.append((len(blob), s))
      blob += struct.pack(f'<{s.dtype.fmt}', s.arg if is_const else 0x0)
  buf = UOp.new_buffer(devs, len(blob), dtypes.uint8).rtag(tag)
  stores = [buf.index(UOp.const(dtypes.int, off), dtype=buf.dtype.ptr()).cast(s.dtype.ptr()).store(s) for off, s in patches]
  # freeze back to bytes so the BINARY arg keeps its original (immutable) type
  return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=bytes(blob))), *stores)
|
|
|
|
def make_mstack(uops):
  """Return the only element of `uops`, or an MSTACK over all of them typed like the first."""
  if len(uops) == 1: return uops[0]
  return UOp(Ops.MSTACK, uops[0].dtype, tuple(uops))
|
|
|
|
def make_signal(devs, queue=None, sentinel=False):
  """Placeholder for a one-element uint64 signal buffer on `devs`.

  Tag is "sentinel_signal" when `sentinel`, `(queue, "timeline_signal")` for a truthy `queue`,
  and plain "timeline_signal" otherwise.
  """
  if sentinel: tag = "sentinel_signal"
  elif queue: tag = (queue, "timeline_signal")
  else: tag = "timeline_signal"
  return UOp.new_buffer(devs, 1, dtypes.uint64).rtag(tag)
|
|
def make_signal_value(devs, queue=None):
  """Placeholder for a one-element uint64 timeline value, tagged per queue when `queue` is truthy."""
  tag = (queue, "timeline_value") if queue else "timeline_value"
  return UOp.new_buffer(devs, 1, dtypes.uint64).rtag(tag)
|
|
|
|
# *****************
|
|
# 0. helpers
|
|
|
|
# backends scheduled through HCQ
HCQ_DEVS = frozenset({"AMD"})
# backends an HCQ device can copy to/from without a staging buffer (see _need_staging)
HCQ_P2P_DEVS = HCQ_DEVS.union({"CPU"})

def all_devices_in(d:Any, c:frozenset[str]) -> bool:
  """True when every device in `d` (one name or several) has its backend name, the part before ':', in `c`."""
  return all(x.split(":")[0] in c for x in to_tuple(d))
|
|
|
|
@dataclass(frozen=True)
class HCQInfo:
  """Immutable metadata carried as `aux` on HCQ calls."""
  name:str = ""
  estimates:Estimates = Estimates()
  outs:tuple[int, ...] = ()
  # filled in by pm_annotate_devs from the submit queue's devices
  devs:tuple[str, ...] = ()

  # filled in by replace_params: sorted global PARAM slots and the number of call argument uops
  params:tuple[int, ...] = ()
  inputs:int|None = None

  @staticmethod
  def from_call(call:UOp) -> HCQInfo:
    """Derive name, estimates and outs from an existing call."""
    return HCQInfo(name=get_call_name(call, get_call_arg_uops(call)), estimates=estimate_uop(call), outs=get_call_outs_ins(call)[0])
|
|
|
|
# *****************
|
|
# 1.1. prep runtimes: staging copies
|
|
|
|
def _need_staging(a, b):
  """True when `a` lives only on HCQ devices while `b` has a device outside the HCQ peer-to-peer set."""
  if not all_devices_in(a.device, HCQ_DEVS): return False
  return not all_devices_in(b.device, HCQ_P2P_DEVS)
|
|
|
|
def stage_copy(dst:UOp, src:UOp) -> UOp|None:
  """Split a copy that needs staging into src -> CPU staging buffer -> dst; None when neither direction does."""
  if not _need_staging(src, dst) and not _need_staging(dst, src): return None

  stage = UOp.new_buffer("CPU", src.buffer.nbytes, dtypes.uint8)
  to_stage = src.copy_to_device("CPU").call(stage, src)
  from_stage = stage.copy_to_device(dst.device).call(dst, stage)
  return UOp(Ops.LINEAR, dtypes.void, (to_stage, from_stage))
pm_insert_copy_staging = PatternMatcher([(UPat(Ops.CALL, src=(UPat(Ops.COPY), UPat(name="dst"), UPat(name="src"))), stage_copy)])
|
|
|
|
# *****************
|
|
# 1.2. prep runtimes: programs/kernargs
|
|
|
|
@functools.cache
def get_pm_prep_program(name:str) -> PatternMatcher|None:
  """Return `pm_prep_program` from `extra.hcq2.ops_<name>2`, or None if either backend module fails to import.

  Results are memoized per `name`.
  """
  backend = name.lower()
  try:
    importlib.import_module(f'tinygrad.runtime.ops_{backend}') # TODO: remove that
    hcq2_module = importlib.import_module(f'extra.hcq2.ops_{backend}2')
    return hcq2_module.pm_prep_program
  except ImportError:
    return None
|
|
|
|
def prep_program(call:UOp, prg:UOp) -> UOp|None:
  """Run the backend's program lowering and attach the program image as a "program" placeholder.

  Returns None when the backend has no `pm_prep_program` or that matcher does not rewrite `prg`.
  The new PROGRAM has the stored image as its only source and `(data, original arg)` as arg.
  """
  dev = call.src[1].device
  pm = get_pm_prep_program(to_tuple(dev)[0].split(":")[0])
  if pm is None: return None
  lowered = pm.rewrite(prg)
  if lowered is None: return None

  data, image_bytes = lowered
  image = UOp.new_buffer(dev, len(image_bytes), dtypes.uint8).rtag("program")
  blob = UOp(Ops.BINARY, dtypes.void, src=(), arg=image_bytes)
  new_prg = prg.replace(src=(image.after(image.store(blob)),), arg=(data, prg.arg))
  return new_prg.call(*call.src[1:], aux=HCQInfo.from_call(call))
|
|
|
|
def prep_kernargs(call:UOp, prg:UOp) -> UOp:
  """Append a "kernargs" placeholder to the PROGRAM sources.

  Layout: one uint64 address per entry of `info.globals`, followed by one uint32 per entry of `info.vars`.
  """
  data, info = prg.arg
  ptr_size, var_size = dtypes.uint64.itemsize, dtypes.uint32.itemsize

  # (byte offset, value, dtype) triples: global addresses first, then vars
  patches = []
  for i, gi in enumerate(info.globals):
    arg = call.src[1+gi]
    patches.append((i*ptr_size, UOp(Ops.GETADDR, dtypes.uint64, src=(arg, UOp(Ops.DEVICE, arg=arg.device))), dtypes.uint64))
  vars_base = len(info.globals)*ptr_size
  for i, v in enumerate(info.vars):
    patches.append((vars_base + i*var_size, v, dtypes.uint32))

  buf = UOp.new_buffer(call.src[1].device, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
  stores = [buf.index(UOp.const(dtypes.int, o), dtype=buf.dtype.ptr()).cast(dt.ptr()).store(val.cast(dt)) for o, val, dt in patches]
  kernargs = buf.after(*stores)

  new_prg = prg.replace(src=prg.src + (kernargs,), arg=(data, info))
  return call.replace(src=(new_prg,) + call.src[1:])
|
|
|
|
pm_prep_runtime = PatternMatcher([
  # a PROGRAM with five sources ending in BINARY: run the device-specific lowering for the call's device(s)
  (UPat(Ops.CALL,
        src=(UPat(Ops.PROGRAM, src=(UPat(), UPat(), UPat(), UPat(), UPat(Ops.BINARY)), name="prg"),),
        name="call", allow_any_len=True),
   prep_program),

  # lower kernargs (PROGRAM.src[0] is now AFTER(BUFFER, STORE), the lowered program image)
  (UPat(Ops.CALL,
        src=(UPat(Ops.PROGRAM, src=(UPat(Ops.BUFFER).or_after(),), name="prg"),),
        name="call", allow_any_len=True),
   prep_kernargs),
])
|
|
|
|
# *****************
|
|
# 2. lowering to hcq ir
|
|
|
|
def make_submit(*cmds, devs:str|tuple[str, ...], queue:str) -> UOp:
  """Wrap `cmds` in a LINEAR queue with arg `(devs as tuple, queue)` under a "submit" CUSTOM_FUNCTION."""
  queue_uop = UOp(Ops.LINEAR, dtypes.void, src=cmds, arg=(to_tuple(devs), queue))
  return UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(queue_uop,), arg="submit")
|
|
|
|
def lower_program(call:UOp, prg:UOp) -> UOp:
  """Turn a prepared PROGRAM call into an "hcq" call submitting it on COMPUTE:0 of the first argument's device(s)."""
  submit = make_submit(prg, devs=call.src[1].device, queue="COMPUTE:0")
  return submit.sink().call(*call.src[1:], aux=call.arg.aux).rtag("hcq")
|
|
|
|
def lower_copy(call:UOp, copy:UOp) -> UOp|None:
  """Turn a COPY call into an "hcq" call on COPY:0 of the first HCQ device among (dst, src).

  Returns None when neither side is on an HCQ device. `copy` is bound by the pattern but unused.
  """
  dst, src = call.src[1], call.src[2]
  # NOTE(review): `.device.split` assumes a single device string here; confirm tuple devices cannot reach this
  hcq_dev = None
  for b in (dst, src):
    if b.device.split(":")[0] in HCQ_DEVS:
      hcq_dev = b.device
      break
  if hcq_dev is None: return None

  cp_op = UOp(Ops.COPY, dtypes.void, src=(dst, src), arg=src.buffer.nbytes)
  submit = make_submit(cp_op, devs=hcq_dev, queue="COPY:0")
  return submit.sink().call(*call.src[1:], aux=HCQInfo.from_call(call)).rtag("hcq")
|
|
|
|
pm_lower_ops = PatternMatcher([
  # PROGRAM with exactly two (possibly AFTER-wrapped) BUFFER sources: program image and kernargs
  (UPat(Ops.CALL,
        src=(UPat(Ops.PROGRAM, src=(UPat(Ops.BUFFER).or_after(), UPat(Ops.BUFFER).or_after()), name="prg"),),
        name="call", allow_any_len=True),
   lower_program),
  # any COPY call
  (UPat(Ops.CALL, src=(UPat(Ops.COPY, name="copy"),), name="call", allow_any_len=True), lower_copy),
])
|
|
|
|
# *****************
|
|
# 3.1. deps tracking
|
|
# device.timeline_signal/value are the per-device schedule epoch. Before a schedule queue accesses memory owned by device N for the first time,
|
|
# it waits for device[N].timeline_signal >= device[N].timeline_value - 1. This orders the schedule after all prior schedules that touched device N.
|
|
#
|
|
# queue.timeline_signal/value are per-queue progress counters used only inside a schedule.
|
|
# Only the owner queue signals its queue.timeline_signal. Values are monotonic.
|
|
#
|
|
# At schedule end, one finalizer queue per touched device[N] waits for every active queue on device[N] to reach its schedule-local
|
|
# final queue.timeline value, then signals device[N].timeline_signal with the schedule's reserved device epoch. After that, buffers/transients
|
|
# for device N from this schedule are safe for the next schedule
|
|
#
|
|
# C programs reserve and bump timeline values, then patch command buffers with the concrete wait/signal values.
|
|
|
|
@dataclass
class DepsCtx:
  """Mutable state shared by the dependency-scheduling passes of one schedule."""
  deps:DepsTracker = field(default_factory=DepsTracker)
  # monotonically increasing op ids, starting at 0
  opid:itertools.count = field(default_factory=itertools.count)
  # queue arg -> most recent tagged queue uop, held weakly
  last_per_queue:weakref.WeakValueDictionary[tuple[Any, str], UOp] = field(default_factory=weakref.WeakValueDictionary)
  # (param slot, lane) -> stand-in buffer used for dependency tracking
  params:dict[tuple[int, int], Buffer] = field(default_factory=dict)
|
|
|
|
def get_dep_buf(ctx:DepsCtx, u:UOp, lane:int) -> Buffer:
  """Resolve `u` to the concrete Buffer that lane `lane` touches, for dependency tracking.

  PARAMs map to one "NULL"-device stand-in Buffer per (slot, lane); MSTACK/MSELECT pick a source lane;
  SLICE defers to its source; a MultiBuffer yields its `lane`-th buffer.
  """
  # TODO: should this be a part of DepsTracker?
  if u.op is Ops.PARAM: return ctx.params.setdefault((u.arg.slot, lane), Buffer("NULL", u.max_numel(), u.dtype.base))
  if u.op is Ops.MSTACK: return get_dep_buf(ctx, u.src[lane], 0)
  if u.op is Ops.MSELECT: return get_dep_buf(ctx, u.src[0], u.arg)
  if u.op is Ops.SLICE: return get_dep_buf(ctx, u.src[0], lane)
  b = u.buffer
  if isinstance(b, MultiBuffer): return b.bufs[lane]
  return b
|
|
|
|
def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
  """Tag every hcq submit queue with a fresh op id and attach the queues it has to wait for.

  Non-hcq calls pass through untouched. When a queue has dependencies, it is replaced by an
  AFTER(queue, *deps) tagged "deps" whose arg lists, per dependency, the lanes that need it.
  """
  scheduled:list[UOp] = []
  for call in linear.src:
    if call.tag != "hcq":
      scheduled.append(call)
      continue

    queue = get_submit(call.src[0]).src[0]
    tagged = queue.rtag(next(ctx.opid))
    ctx.last_per_queue[queue.arg] = tagged
    qdevs, refs = to_tuple(tagged.arg[0]), get_call_arg_uops(call)
    qname = tagged.arg[1]

    # per-lane deps, tracked per (device, queue). skip self
    found:list[tuple[UOp, int]] = []
    for lane, d in enumerate(qdevs):
      lane_bufs = [get_dep_buf(ctx, b, lane) for b in refs]
      for dep in ctx.deps.access_resources(lane_bufs, call.arg.aux.outs, tagged.replace(arg=(d, qname))):
        if dep.tag != tagged.tag: found.append((dep, lane))

    # drop self-queue waits, queue self-orders
    if qdevs[0].split(":")[0] in {"AMD", "QCOM"} or qname.startswith("COPY"):
      found = [(dep, lane) for dep, lane in found if dep.arg != (qdevs[lane], qname)]

    # keep latest dep per lane (highest op id wins), group lanes
    latest:dict[tuple[Any, int], UOp] = {}
    for dep, lane in sorted(found, key=lambda x: x[0].tag): latest[(dep.arg, lane)] = dep
    lanes_of:dict[UOp, tuple[int, ...]] = collections.defaultdict(tuple)
    for (_, lane), dep in latest.items(): lanes_of[dep] += (lane,)

    synced = tagged.after(*lanes_of, arg=tuple(lanes_of.values())).rtag("deps") if lanes_of else tagged
    scheduled.append(call.replace(src=(call.src[0].substitute({queue:synced}), *call.src[1:])))
  return linear.replace(src=tuple(scheduled))
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)])
|
|
|
|
# *****************
|
|
# 3.2. finalizer
|
|
|
|
def make_finalizer(queues:list[UOp], nbump:int) -> UOp:
  """Build the "hcq finalizer" call for the devices touched by `queues`.

  A COMPUTE:0 submit waits on every (device, queue) pair of `queues` and then stores the device
  timeline value into the device timeline signal. After the submit, the device timeline value is
  bumped by 1 and every distinct queue's timeline value by `nbump`.
  """
  devs = tuple(dedup([d for q in queues for d in to_tuple(q.arg[0])]))
  lane_of = {d:i for i,d in enumerate(devs)}
  zero = UOp.const(dtypes.int, 0)
  tl = make_signal_value(devs)

  # the submit signals the device timeline with its current value
  submit = make_submit(make_signal(devs).store(tl.index(zero)), devs=devs, queue="COMPUTE:0")

  # split each (multi-device) queue into per-device deps so each finalizer lane waits on the matching device's signal
  per_dev, lanes = [], []
  for q in queues:
    for d in to_tuple(q.arg[0]):
      per_dev.append(q.replace(arg=(d, q.arg[1])))
      lanes.append((lane_of[d],))
  submit = submit.replace(src=(submit.src[0].after(*per_dev, arg=tuple(lanes)).rtag("deps"),))

  bumps = [(tl, 1)]
  for qn in dedup([q.arg[1] for q in queues]): bumps.append((make_signal_value(devs, queue=qn), nbump))
  patches = [s.after(submit).index(zero, dtype=s.dtype.ptr()).store(s.index(zero) + inc) for s, inc in bumps]
  return UOp.barrier(*patches).sink().call(aux=HCQInfo("hcq finalizer")).rtag("hcq")
|
|
|
|
def add_finalizer(ctx:DepsCtx, linear:UOp) -> UOp:
  """Append one finalizer call per backend, covering the last tagged queue of each (devices, queue) pair.

  Queues are grouped by the backend name (text before ':') of their first device. All finalizers
  share the same bump value, the next op id.
  """
  by_backend:dict[str, list[UOp]] = collections.defaultdict(list)
  for key, q in ctx.last_per_queue.items():
    backend = to_tuple(key[0])[0].split(':')[0]
    by_backend[backend].append(q)

  nbump = next(ctx.opid)
  finalizers = tuple(make_finalizer(queues, nbump) for queues in by_backend.values())
  return linear.replace(src=linear.src + finalizers)
pm_add_finalizer = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), add_finalizer)])
|
|
|
|
# *****************
|
|
# 3.3. lower loads/stores
|
|
|
|
def add_loads(ctx:set[int], deps:UOp) -> UOp:
  """Replace an AFTER tagged "deps" by its queue with one wait per dependency prepended.

  Each dependency's tag is added to `ctx` so that add_stores later emits the matching signal.
  """
  cur = deps.src[0]
  cur_devs = to_tuple(cur.arg[0])
  zero = UOp.const(dtypes.int, 0)

  waits = []
  for dep_lanes, dep in zip(deps.arg, deps.src[1:]):
    dep_dev, queue = dep.arg # dep_dev is a single device (deps are recorded per-device)
    ctx.add(dep.tag) # mark op to update signal.

    # for lanes that need this dep, wait on the dep device's signal/value; other lanes get a passing sentinel
    needed = set(dep_lanes)
    sigs = [make_signal(dep_dev if j in needed else d, queue=queue, sentinel=j not in needed) for j, d in enumerate(cur_devs)]
    vals = [make_signal_value(dep_dev if j in needed else d, queue=queue) for j, d in enumerate(cur_devs)]
    waits.append(make_mstack(sigs).wait(make_mstack(vals).index(zero) + dep.tag))
  return cur.replace(src=tuple(waits) + cur.src)
pm_add_inner_loads = PatternMatcher([(UPat(Ops.AFTER, tag="deps", name="deps"), add_loads)])
|
|
|
|
def add_stores(ctx:set[int], submit:UOp, q:UOp) -> UOp|None:
  """For a queue whose tag is in `ctx`, append a store of `queue timeline value + tag` to its signal and clear the tag."""
  if q.tag not in ctx: return None
  devs, queue = q.arg
  target = make_signal_value(devs, queue=queue).index(UOp.const(dtypes.int, 0)) + q.tag
  signal_store = make_signal(devs, queue=queue).store(target)
  return submit.replace(src=(q.replace(src=q.src + (signal_store,), tag=None),))
pm_add_inner_stores = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_stores)])
|
|
|
|
# *****************
|
|
# 4.1. merge queues
|
|
|
|
def get_submit(ast:UOp) -> UOp:
  """Return the first "submit" CUSTOM_FUNCTION in `ast`'s toposort; raises StopIteration when there is none."""
  submits = (u for u in ast.toposort() if u.op is Ops.CUSTOM_FUNCTION and u.arg == "submit")
  return next(submits)
|
|
|
|
def merge_sink(sinks:list[UOp]) -> UOp:
  """Merge several sinks on the same queue into one whose submit carries all their commands, in order.

  Earlier sinks that wrap their submit in more uops are re-rooted onto the merged submit.
  """
  if len(sinks) == 1: return sinks[0]
  submits = [get_submit(sink) for sink in sinks]
  queues = [submit.src[0] for submit in submits]
  *earlier_sinks, last_sink = sinks
  *earlier_submits, last_submit = submits

  # the last submit becomes the anchor holding every queue's commands
  all_cmds = tuple(itertools.chain.from_iterable(q.src for q in queues))
  anchor = last_submit.replace(src=(queues[-1].replace(src=all_cmds),))
  for sink, submit in zip(earlier_sinks, earlier_submits):
    if sink.src[0] is not submit: anchor = sink.src[0].substitute({submit: anchor}, walk=True)
  return last_sink.substitute({last_submit: anchor}, walk=True)
|
|
|
|
def merge_queues(linear:UOp) -> UOp:
  """Coalesce consecutive hcq calls that target the same (devs, queue) into a single call.

  Open groups are flushed when a non-hcq call or a finalizer is met, and a group is flushed early
  when a different device set on the same queue shares a device with it. Merged calls are named
  "<queue lowercased> submit" and their estimates are summed.
  """
  merged:list[UOp] = []
  opened:dict[tuple[tuple[str, ...], str], tuple[list[UOp], HCQInfo]] = {} # (devs, queue) -> (sinks, aux), kept in submit order

  def flush(keys):
    for key in keys:
      sinks, aux = opened.pop(key)
      merged.append(merge_sink(sinks).call(aux=aux).rtag("hcq"))

  for call in linear.src:
    # finalizer cannot be merged, since it bumps inner signal (this introduces race when multidevs).
    if call.tag != "hcq" or call.arg.aux.name == "hcq finalizer":
      flush(list(opened))
      merged.append(call)
      continue

    sink = call.src[0]
    devs, queue = get_submit(sink).src[0].arg
    key, aux = (devs, queue), call.arg.aux
    prev = opened.pop(key, None)
    if prev is None:
      # no such queue opened: close every open submit on this queue that shares a device, so submit order is kept
      flush([k for k in opened if k[1] == queue and set(k[0]) & set(devs)])
      opened[key] = ([sink], aux)
    else:
      # re-inserting moves the group to the end of the open list
      prev_sinks, prev_aux = prev
      opened[key] = (prev_sinks + [sink], replace(aux, name=f"{queue.lower()} submit", estimates=prev_aux.estimates + aux.estimates))

  flush(list(opened))
  return linear.replace(src=tuple(merged))
pm_merge_queues = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), merge_queues)])
|
|
|
|
# *****************
|
|
# 4.2. global sync
|
|
|
|
def add_global_sync(ctx:set[tuple[str, ...]], submit:UOp, q:UOp) -> UOp|None:
  """Prefix the first submit seen for each device tuple with a BARRIER and a wait on the device timeline.

  `ctx` records the device tuples already handled; later submits on the same tuple are left alone.
  """
  devs = q.arg[0]
  if devs in ctx: return None
  ctx.add(devs)

  # some devices from a command buffer might be used for the first time this schedule, so we wait for their global timeline epoch.
  epoch = make_signal_value(devs).index(UOp.const(dtypes.int, 0)) - 1
  wait = make_signal(devs).wait(epoch)
  return submit.replace(src=(q.replace(src=(UOp(Ops.BARRIER, dtypes.void), wait, *q.src)),))
pm_add_global_sync = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_global_sync)])
|
|
|
|
# *****************
|
|
# 4.3. annotate exec devs
|
|
|
|
def _annotate_devs(call:UOp) -> UOp:
  """Copy the submit queue's device tuple into the call's `aux.devs`."""
  devs = get_submit(call.src[0]).src[0].arg[0]
  return call.replace(arg=replace(call.arg, aux=replace(call.arg.aux, devs=devs)))
pm_annotate_devs = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), _annotate_devs)])
|
|
|
|
# *****************
|
|
# 4.4. replace params with per-submit input address loads
|
|
|
|
def replace_params(call:UOp) -> UOp|None:
  """Replace global PARAMs in the call body with loads from a per-submit "inputs" address table.

  The table holds one uint64 per distinct param slot, in ascending slot order, and is appended to
  the call sources. `aux.params` and `aux.inputs` record the slots and the argument count.
  Returns None when the body has no global PARAM.
  """
  slot_of = {u:u.arg.slot for u in call.src[0].toposort() if u.op is Ops.PARAM and u.addrspace is AddrSpace.GLOBAL}
  if not slot_of: return None

  # fill new info
  slots = tuple(sorted(set(slot_of.values())))
  hcqinfo = replace(call.arg.aux, params=slots, inputs=len(get_call_arg_uops(call)))

  inputs = UOp.new_buffer(get_submit(call.src[0]).src[0].arg[0], len(hcqinfo.params), dtypes.uint64).rtag("inputs")

  slot2idx = {s:i for i,s in enumerate(hcqinfo.params)}
  loads = {u:inputs.index(UOp.const(dtypes.int, slot2idx[s])).load() for u,s in slot_of.items()}
  body = call.src[0].substitute(loads)

  return call.replace(src=(body, *call.src[1:], inputs), arg=replace(call.arg, aux=hcqinfo))
pm_replace_params = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), replace_params)])
|
|
|
|
# *****************
|
|
# 5.1. encode cmdbufs
|
|
|
|
@functools.cache
def get_pm_lower(name:str) -> PatternMatcher|None:
  """Return `pm_lower` from `extra.hcq2.ops_<name>2`, or None if either backend module fails to import.

  Results are memoized per `name`.
  """
  backend = name.lower()
  try:
    importlib.import_module(f'tinygrad.runtime.ops_{backend}') # TODO: remove that
    hcq2_module = importlib.import_module(f'extra.hcq2.ops_{backend}2')
    return hcq2_module.pm_lower
  except ImportError:
    return None
|
|
|
|
def encode_cmdbuf(submit:UOp, lin:UOp) -> UOp|None:
  """Rewrite a submit with the `pm_lower` of the queue's first device backend; None when there is no such matcher."""
  backend = to_tuple(lin.arg[0])[0].split(":")[0]
  pm = get_pm_lower(backend)
  if pm is None: return None
  return pm.rewrite(submit)
pm_encode_cmdbufs = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="lin"),), name="submit"), encode_cmdbuf)])
|
|
|
|
# *****************
|
|
# 5.2. lift patches to the command buffer (root)
|
|
|
|
def lift_patches_to_cmdbuf(cmdbuf:UOp) -> UOp|None:
  """Hoist the dependencies of AFTER uops nested in the cmdbuf's stores up to the cmdbuf itself.

  Each nested AFTER is replaced by its first source and its remaining sources are appended to the
  cmdbuf sources. Returns None when no store contains an AFTER.
  """
  nested = dedup(u for store in cmdbuf.src[1:] for u in store.toposort() if u.op is Ops.AFTER)
  if not nested: return None
  lifted = tuple(d for p in nested for d in p.src[1:])
  unwrapped = {p: p.src[0] for p in nested}
  return cmdbuf.replace(src=cmdbuf.src + lifted).substitute(unwrapped)
pm_lift_patches_to_cmdbuf = PatternMatcher([
  (UPat(Ops.AFTER, src=(UPat(Ops.BUFFER, tag={"compute", "copy"}),), allow_any_len=True, name="cmdbuf"), lift_patches_to_cmdbuf),
])
|
|
|
|
# *****************
|
|
# 5.3. pack placeholders buffers
|
|
|
|
def pack_hcq_placeholders(call:UOp) -> UOp|None:
  """Pack same-tagged placeholder buffers of a call into one base buffer per tag, addressed through SLICEs.

  "scratch" buffers all alias offset 0 of a base sized to the largest one. "program" and "kernargs"
  buffers are laid out back to back, aligned to 0x1000 and 128 bytes respectively. Tags that occur
  only once are left alone. Returns None when nothing is packed.
  """
  maxtags, sumtags = {"scratch"}, {"program", "kernargs"}
  packable = maxtags | sumtags
  bufs = [b for b in call.src[0].toposort() if b.op is Ops.BUFFER and b.tag in packable]

  off_per_buf:dict[UOp, int] = {}
  size_per_tag:dict[str, int] = {}
  for b in bufs:
    used = size_per_tag.get(b.tag, 0)
    if b.tag in maxtags:
      size_per_tag[b.tag] = max(used, b.arg)
    elif b.tag in sumtags:
      align = {"program": 0x1000}.get(b.tag, 128)
      off_per_buf[b] = round_up(used, align)
      size_per_tag[b.tag] = off_per_buf[b] + b.arg

  count_per_tag = collections.Counter(b.tag for b in bufs)
  # the last buffer of each repeated tag supplies the device and dtype of the packed base
  ref_bufs = {b.tag:b for b in bufs if count_per_tag[b.tag] > 1}
  bases = {tag:UOp.new_buffer(b.src[1].arg, size_per_tag[tag], b.dtype).rtag(tag) for tag,b in ref_bufs.items()}
  subs = {b:UOp(Ops.SLICE, b.dtype, (bases[b.tag], UOp.const(dtypes.weakint, off_per_buf.get(b, 0))), b.arg) for b in bufs if b.tag in bases}
  if not subs: return None
  return call.replace(src=(call.src[0].substitute(subs, walk=True), *call.src[1:]))
pm_pack_placeholders = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), pack_hcq_placeholders)])
|
|
|
|
# *****************
|
|
# 5.4. capture buffers reachable from each hcq call as BIND, so we don't drop their refs
|
|
|
|
def hold_call_buffers(call:UOp) -> UOp|None:
  """Append a BIND over every BUFFER in the call body that is not already a call source; None when there is none."""
  held = tuple(dedup(u for u in call.src[0].toposort() if u.op is Ops.BUFFER and u not in call.src))
  if not held: return None
  return call.replace(src=call.src + (UOp(Ops.BIND, dtypes.void, src=held),))
pm_hold_call_buffers = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), hold_call_buffers)])
|
|
|
|
# *****************
|
|
# 6. bufferize placeholders: replace placeholders with real buffers.
|
|
|
|
def bufferize_buf(buf:UOp) -> UOp|None:
  """Materialize a tagged placeholder BUFFER through each target device's `pm_bufferize`.

  One real buffer uop is produced per device and the results are stacked with make_mstack.
  Untagged buffers are left alone.
  """
  if buf.tag is None: return None
  real = []
  for dev in to_tuple(buf.src[1].arg):
    dv = Device[dev]
    real.append(UOp.from_buffer(dv.pm_bufferize.rewrite(buf, ctx=dv), "CPU"))
  return make_mstack(tuple(real))
pm_bufferize = PatternMatcher([(UPat(Ops.BUFFER, name="buf"), bufferize_buf)])
|
|
|
|
# *****************
|
|
# 7. resolve patches
|
|
|
|
def push_stack(op, s):
  """Distribute `op` over the elements of STACK `s`: one scalar copy of `op` per element, restacked."""
  scalar = op.dtype.scalar()
  per_elem = []
  for x in s.src:
    # swap the stack operand for this element, keep every other operand
    per_elem.append(op.replace(dtype=scalar, src=tuple(x if y is s else y for y in op.src)))
  return UOp(Ops.STACK, scalar.vec(len(s.src)), tuple(per_elem))
|
|
|
|
def fold_blob_store(buf:UOp, blob:UOp) -> UOp:
  """Write the BINARY blob to the start of every underlying buffer right away and fold the store to a NOOP."""
  mb = buf.buffer
  targets = mb.bufs if isinstance(mb, MultiBuffer) else (mb,)
  for b in targets:
    b.ensure_allocated()._buf.cpu_view().mv.cast('B')[:len(blob.arg)] = blob.arg
  return UOp(Ops.NOOP)
|
|
|
|
def fold_const_store(buf:UOp, off:UOp, val:UOp) -> UOp:
  """Write a constant into every underlying buffer right away and fold the store to a NOOP.

  A STACK value supplies one element per buffer; a plain constant is written to all of them.
  The byte offset is `off.arg` scaled by the buffer's base dtype size.
  """
  mb = buf.buffer
  targets = mb.bufs if isinstance(mb, MultiBuffer) else (mb,)
  values = val.src if val.op is Ops.STACK else (val,)*len(targets)
  for b, v in zip(targets, values):
    struct.pack_into(f'<{v.dtype.fmt}', b.ensure_allocated()._buf.cpu_view().mv.cast('B'), off.arg * buf.dtype.base.itemsize, v.arg)
  return UOp(Ops.NOOP)
|
|
|
|
def resolve_getaddr(buf:UOp, g:UOp) -> UOp:
  """Resolve GETADDR of a real buffer into uint64 address constants, one per device of `g`.

  A single address is returned as a CONST, several as a STACK. Operands that are not
  BUFFER/MSTACK/MSELECT are returned as they are, replacing the GETADDR.
  """
  if buf.op not in (Ops.BUFFER, Ops.MSTACK, Ops.MSELECT): return buf
  devs, b = to_tuple(g.src[1].arg), buf.buffer
  if buf.op is Ops.MSTACK: bufs = tuple(cast(Buffer, x.buffer) for x in buf.src)
  elif isinstance(b, MultiBuffer): bufs = tuple(b.bufs)
  else: bufs = (b,)*len(devs) # one buffer addressed from every device
  assert len(bufs) == len(devs), f"can't resolve {len(bufs)} buffers on {len(devs)} devices"
  addrs = tuple(UOp.const(dtypes.uint64, x.get_buf(d).va_addr) for x, d in zip(bufs, devs))
  if len(addrs) == 1: return addrs[0]
  return UOp(Ops.STACK, dtypes.uint64.vec(len(addrs)), addrs)
|
|
|
|
def resolve_getaddr_slice(bv:UOp, dev:UOp) -> UOp:
  """Rewrite GETADDR(SLICE(x, off)) as GETADDR(x) plus the byte offset.

  The offset is scaled by x's dtype size when x is buffer-like (ignoring AFTER wrappers),
  otherwise by the slice's own dtype size.
  """
  parent = bv.src[0]
  if unwrap_after(parent).op in (Ops.BUFFER, Ops.SLICE, Ops.MSTACK, Ops.MSELECT): itemsize = parent.dtype.itemsize
  else: itemsize = bv.dtype.itemsize
  byte_off = UOp.const(dtypes.uint64, bv.src[1].arg * itemsize)
  return UOp(Ops.GETADDR, dtypes.uint64, src=(parent, dev)) + byte_off
|
|
|
|
pm_resolve_patches = PatternMatcher([
  # multi: push ALU-with-const and CAST through a STACK so every element is handled on its own
  (UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack),
  (UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack),

  # index on slice is index on the slice's source, with the slice offset added to the index
  (UPat(Ops.INDEX, src=(UPat(Ops.SLICE, name="bv"), UPat()), name="idx", allow_any_len=True),
   lambda idx, bv: idx.replace(src=(bv.src[0], idx.src[1] + bv.src[1].cast(idx.src[1].dtype), *idx.src[2:]))),

  # getaddr: the slice rule is listed before the generic one
  (UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), resolve_getaddr_slice), # getaddr(slice(x)) -> offset+getaddr(x)
  (UPat(Ops.GETADDR, src=(UPat(name="buf"), UPat(Ops.DEVICE)), name="g"), resolve_getaddr),

  # folders: stores of blobs and of constants are written to the buffers while rewriting
  (UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
  (UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").index(UPat.cvar("off")).or_casted()
     .store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))),
   fold_const_store),
]) + symbolic_simple
|
|
|
|
# *****************
|
|
# 8. callify hcq programs
|
|
|
|
def to_param(bufs:list[UOp], ref:UOp) -> UOp:
  """Return a placeholder for `ref` whose slot is its position in `bufs`, appending `ref` first if it is new."""
  if ref not in bufs: bufs.append(ref)
  slot = bufs.index(ref)
  return UOp.placeholder((ref.buffer.size,), ref.dtype, slot)
pm_to_param = PatternMatcher([(UPat({Ops.MSELECT, Ops.MSTACK, Ops.BUFFER}, name="r"), lambda ctx, r: to_param(ctx, r))])
|
|
|
|
def parametrize_host_buffers(call:UOp) -> UOp:
  """Turn every buffer referenced by the call body into a placeholder parameter and pass the buffers as call sources.

  The original call arguments keep their slots, newly found buffers are appended after them, and
  PARAM variables move to the slots after all buffers, ordered by their `expr`. Of the previous
  extra sources only BIND uops are kept.
  """
  # preserve original order of args
  bufs = list(get_call_arg_uops(call))
  body = graph_rewrite(call.src[0], pm_to_param, ctx=bufs, bottom_up=True, name="parametrize host buffers")

  # move vars to new slots
  var_names = sorted({v.expr for v in body.variables() if v.op is Ops.PARAM})
  var_slots = {nm:len(bufs)+i for i,nm in enumerate(var_names)}
  body = body.substitute({v:v.replace(arg=replace(v.arg, slot=var_slots[v.expr])) for v in body.variables() if v.op is Ops.PARAM})

  binds = tuple(x for x in call.src[1:] if x.op is Ops.BIND)
  return call.replace(src=(body, *bufs) + binds)
pm_parametrize_host_buffers = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), parametrize_host_buffers)])
|
|
|
|
def callify_hcq(call:UOp) -> UOp:
  """Compile the call body into an "hcq_submit" program with the CPU renderer and wrap it in an "hcq" CUSTOM_FUNCTION call."""
  sink = call.src[0].sink(arg=KernelInfo("hcq_submit"), tag=1)
  prg = to_program(sink, Device["CPU"].renderer)
  return UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(prg,), arg="hcq").call(*call.src[1:], aux=call.arg.aux)
pm_callify_hcq = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), callify_hcq)])
|
|
|
|
@track_rewrites(lambda _,ret: f"HCQ Schedule {pluralize('Kernel', len(ret.src))}")
def hcq_schedule(linear:UOp) -> UOp:
  """Lower a LINEAR schedule to HCQ submit calls by running the rewrite passes of this file in a fixed order."""
  # deps_ctx is shared by inner sync and the finalizer; `waited` is filled by "add loads" and read by "add stores"
  deps_ctx, waited = DepsCtx(), set()
  passes:list[tuple[PatternMatcher, dict[str, Any]]] = [
    (pm_insert_copy_staging + pm_flatten_linear, dict(name="insert copy staging")),
    (pm_prep_runtime, dict(name="prepare runtime")),

    (pm_lower_ops, dict(name="lower ops into hcq ir")),
    (pm_schedule_inner_sync, dict(ctx=deps_ctx, walk=True, name="schedule inner sync")),
    (pm_add_finalizer, dict(ctx=deps_ctx, walk=True, name="add finalizer")),
    (pm_add_inner_loads, dict(ctx=waited, walk=True, name="add loads", enter_calls=True)),
    (pm_add_inner_stores, dict(ctx=waited, walk=True, name="add stores", enter_calls=True)),
    (pm_merge_queues, dict(name="merge queues")),
    (pm_add_global_sync, dict(ctx=set(), walk=True, name="add global sync", enter_calls=True)),
    (pm_annotate_devs, dict(name="annotate devs")),
    (pm_replace_params, dict(name="replace params")),
    (pm_encode_cmdbufs, dict(walk=True, name="encode cmdbufs", enter_calls=True)),
    (pm_lift_patches_to_cmdbuf, dict(name="lift patches to cmdbuf", enter_calls=True)),
    (pm_pack_placeholders, dict(walk=True, name="pack placeholders")),
    (pm_hold_call_buffers, dict(walk=True, name="hold call buffers")),

    # realize starts from here
    (pm_bufferize, dict(bottom_up=True, walk=True, name="bufferize placeholders", enter_calls=True)),
    (pm_resolve_patches, dict(bottom_up=False, name="simplify patches", enter_calls=True)),
    (pm_parametrize_host_buffers, dict(walk=True, name="parametrize host buffers")),
    (pm_callify_hcq, dict(name="callify hcq")),
  ]
  for pm, kwargs in passes:
    linear = graph_rewrite(linear, pm, **kwargs)

  return linear
|