hcq2: debug=2 info (#16569)

* hcq2: debug=2 info

* t

* x

* hcq2: debug=2 info

* x
This commit is contained in:
nimlgen 2026-06-11 19:52:01 +03:00 committed by GitHub
commit e5f498de3b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 55 additions and 33 deletions

View file

@ -1,9 +1,8 @@
from __future__ import annotations
from typing import cast, Callable, TypeVar, Generic, Any, TYPE_CHECKING
from typing import cast, Callable, TypeVar, Generic, Any
import struct, functools, time, collections, importlib, itertools, weakref
from dataclasses import replace
if TYPE_CHECKING: from tinygrad.engine.realize import ExecContext
from tinygrad.helpers import DEV, getenv, select_first_inited, select_by_name, suppress_finalizing, mv_address, DEBUG, dedup, pluralize
from tinygrad.helpers import DEV, getenv, select_first_inited, select_by_name, suppress_finalizing, mv_address, DEBUG, dedup, pluralize, to_tuple
from tinygrad.device import Device, Buffer, BufferSpec, Compiled, LRUAllocator, MultiBuffer
from tinygrad.uop.ops import Ops, sint, UOp, UPat, PatternMatcher, KernelInfo, graph_rewrite, track_rewrites, GroupOp
from tinygrad.uop.symbolic import symbolic_simple, symbolic
@ -12,7 +11,7 @@ from dataclasses import dataclass, field
from tinygrad.runtime.support.memory import BumpAllocator
from tinygrad.runtime.support.hcq import MMIOInterface
from tinygrad.renderer import Renderer, Estimates
from tinygrad.engine.realize import to_program, track_stats, get_call_arg_uops, resolve_params, pm_flatten_linear
from tinygrad.engine.realize import to_program, get_call_arg_uops, get_call_name, get_call_outs_ins, estimate_uop, pm_flatten_linear
from tinygrad.engine.jit import DepsTracker
HCQDeviceType = TypeVar('HCQDeviceType', bound='HCQ2Compiled')
@ -166,10 +165,17 @@ def make_signal_value(devs, queue=None): return UOp.new_buffer(devs, 1, dtypes.u
HCQ_DEVS = frozenset(("AMD",))
HCQ_P2P_DEVS = HCQ_DEVS | frozenset(("CPU",))
def to_tuple(d):
  # a tuple is returned as-is (same object); anything else, including a list, becomes a 1-tuple
  if isinstance(d, tuple): return d
  return (d,)
def all_devices_in(d:Any, c:frozenset[str]) -> bool:
  # d is one device string or a tuple of them; compare only the part before the first ":" against c
  base_names = {dev.split(":")[0] for dev in to_tuple(d)}
  return base_names <= c
@dataclass(frozen=True)
class HCQInfo:
  # frozen record handed to UOp.call(..., aux=...) for the hcq calls built in this file
  name:str = ""
  estimates:Estimates = Estimates()
  # forwarded to access_resources by schedule_inner_sync
  outs:tuple[int, ...] = ()
  @staticmethod
  def from_call(call:UOp) -> HCQInfo:
    # snapshot name, estimates and the first element of get_call_outs_ins from the given call
    call_name = get_call_name(call, get_call_arg_uops(call))
    call_estimates = estimate_uop(call)
    call_outs = get_call_outs_ins(call)[0]
    return HCQInfo(name=call_name, estimates=call_estimates, outs=call_outs)
# *****************
# 1.1. prep runtimes: staging copies
@ -199,7 +205,7 @@ def prep_program(call:UOp, prg:UOp) -> UOp|None:
data, image_bytes = lowered
buf = UOp.new_buffer(dev, len(image_bytes), dtypes.uint8).rtag("program")
blob = UOp(Ops.BINARY, dtypes.void, src=(), arg=image_bytes)
return call.replace(src=(prg.replace(src=(buf.after(buf.store(blob)),), arg=(data, prg.arg)),) + call.src[1:])
return prg.replace(src=(buf.after(buf.store(blob)),), arg=(data, prg.arg)).call(*call.src[1:], aux=HCQInfo.from_call(call))
def prep_kernargs(call:UOp, prg:UOp) -> UOp:
data, info = prg.arg
@ -229,14 +235,14 @@ def make_submit(*cmds, devs:str|tuple[str, ...], queue:str) -> UOp:
return UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(UOp(Ops.LINEAR, dtypes.void, src=tuple(cmds), arg=(devs, queue)),), arg="submit")
def lower_program(call:UOp, prg:UOp) -> UOp:
return make_submit(prg, devs=call.src[1].device, queue="COMPUTE:0").sink().call(*call.src[1:]).rtag(("hcq", tuple(prg.arg[1].outs)))
return make_submit(prg, devs=call.src[1].device, queue="COMPUTE:0").sink().call(*call.src[1:], aux=call.arg.aux).rtag("hcq")
def lower_copy(call:UOp, copy:UOp) -> UOp|None:
dst, src = call.src[1], call.src[2]
if (hcq_dev:=next((b.device for b in (dst, src) if b.device.split(":")[0] in HCQ_DEVS), None)) is None: return None
cp_op = UOp(Ops.COPY, dtypes.void, src=(dst, src), arg=src.buffer.nbytes)
return make_submit(cp_op, devs=hcq_dev, queue="COPY:0").sink().call(*call.src[1:]).rtag(("hcq", (0,)))
return make_submit(cp_op, devs=hcq_dev, queue="COPY:0").sink().call(*call.src[1:], aux=HCQInfo.from_call(call)).rtag("hcq")
pm_lower_ops = PatternMatcher([
(UPat(Ops.CALL, src=(UPat(Ops.PROGRAM, src=(UPat(Ops.BUFFER).or_after(), UPat(Ops.BUFFER).or_after()), name="prg"),),
@ -267,7 +273,7 @@ class DepsCtx:
def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
new_src = []
for call in linear.src:
if not isinstance(call.tag, tuple) or not call.tag[0] == "hcq":
if call.tag != "hcq":
new_src.append(call)
continue
@ -277,10 +283,10 @@ def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
deps = []
refs = [b.buffer for b in get_call_arg_uops(call)]
for lane in range(len(refs[0].bufs) if isinstance(refs[0], MultiBuffer) else 1):
deps += ctx.deps.access_resources([b.bufs[lane] if isinstance(b, MultiBuffer) else b for b in refs], call.tag[1], new_q)
deps += ctx.deps.access_resources([b.bufs[lane] if isinstance(b, MultiBuffer) else b for b in refs], call.arg.aux.outs, new_q)
new_q = new_q.after(*dps).rtag("deps") if (dps:=dedup(deps)) else new_q
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}), *call.src[1:]), tag="hcq"))
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}), *call.src[1:])))
return linear.replace(src=tuple(new_src))
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)])
@ -295,7 +301,8 @@ def make_finalizer(queues:list[UOp], nbump:int) -> UOp:
submit = make_submit(make_signal(devs).store(tl.index(zero) + 1), devs=devs, queue="COMPUTE:0")
upd = [(tl, 1)] + [(make_signal_value(devs, queue=qn), nbump) for qn in dedup([q.arg[1] for q in queues])]
return UOp.barrier(*[s.after(submit).index(zero, dtype=s.dtype.ptr()).store(s.index(zero) + inc) for s, inc in upd]).sink().call().rtag("hcq")
return UOp.barrier(*[s.after(submit).index(zero, dtype=s.dtype.ptr()).store(s.index(zero) + inc) for s, inc in upd]) \
.sink().call(aux=HCQInfo("hcq finalizer")).rtag("hcq")
def add_finalizer(ctx:DepsCtx, linear:UOp) -> UOp:
parts:dict[str, list[UOp]] = collections.defaultdict(list)
@ -345,22 +352,24 @@ def merge_sinks(old_sink:UOp, new_sink:UOp) -> UOp:
def merge_queues(linear:UOp) -> UOp:
new_src:list[UOp] = []
opened_qs:dict[tuple[tuple[str, ...], str], UOp] = {} # (devs, queue) -> sink, kept in submit order
opened_qs:dict[tuple[tuple[str, ...], str], tuple[UOp, HCQInfo]] = {} # (devs, queue) -> (sink, aux), kept in submit order
for call in linear.src:
if call.tag != "hcq":
new_src += [opened_qs.pop(k).call().rtag('hcq') for k in list(opened_qs)] + [call]
new_src += [(sa:=opened_qs.pop(k))[0].call(aux=sa[1]).rtag('hcq') for k in list(opened_qs)] + [call]
continue
devs, queue = get_submit(new_sink:=call.src[0]).src[0].arg
if (old_sink:=opened_qs.pop((devs, queue), None)) is not None:
new_sink = merge_sinks(old_sink, new_sink) # exact same queue: merge, and re-insert at the end
aux = call.arg.aux
if (old:=opened_qs.pop((devs, queue), None)) is not None:
new_sink = merge_sinks(old[0], new_sink) # exact same queue: merge, and re-insert at the end
aux = replace(aux, name=f"{queue.lower()} submit", estimates=old[1].estimates + aux.estimates)
else:
# no such queue opened: close every open submit on this queue that shares a device, so submit order is kept
new_src += [opened_qs.pop(k).call().rtag('hcq') for k in [k for k in opened_qs if k[1] == queue and set(k[0]) & set(devs)]]
opened_qs[(devs, queue)] = new_sink
new_src += [(sa:=opened_qs.pop(k))[0].call(aux=sa[1]).rtag('hcq') for k in [k for k in opened_qs if k[1] == queue and set(k[0]) & set(devs)]]
opened_qs[(devs, queue)] = (new_sink, aux)
return linear.replace(src=tuple(new_src + [sink.call().rtag('hcq') for sink in opened_qs.values()]))
return linear.replace(src=tuple(new_src + [sink.call(aux=aux).rtag('hcq') for sink, aux in opened_qs.values()]))
pm_merge_queues = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), merge_queues)])
# *****************
@ -467,8 +476,8 @@ def parametrize_host_buffers(call:UOp) -> UOp:
pm_parametrize_host_buffers = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), parametrize_host_buffers)])
def callify_hcq(call:UOp) -> UOp:
sink = UOp.sink(call.src[0], arg=KernelInfo(name="hcq_submit", estimates=Estimates()), tag=1)
return to_program(sink, Device["CPU"].renderer).call(*call.src[1:])
prg = to_program(call.src[0].sink(arg=KernelInfo("hcq_submit"), tag=1), Device["CPU"].renderer)
return UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(prg,), arg="hcq").call(*call.src[1:], aux=call.arg.aux)
pm_callify_hcq = PatternMatcher([(UPat(Ops.CALL, tag="hcq_param", name="call"), callify_hcq)])
@track_rewrites(lambda _,ret: f"HCQ Schedule {pluralize('Kernel', len(ret.src))}")
@ -477,7 +486,7 @@ def hcq_schedule(linear:UOp) -> UOp:
linear = graph_rewrite(linear, pm_prep_runtime, name="prepare runtime")
linear = graph_rewrite(linear, pm_lower_ops, name="lower ops into hcq ir")
linear = graph_rewrite(linear, pm_schedule_inner_sync, ctx=(deps_ctx:=DepsCtx()), name="schedule inner sync")
linear = graph_rewrite(linear, pm_schedule_inner_sync, ctx=(deps_ctx:=DepsCtx()), walk=True, name="schedule inner sync")
linear = graph_rewrite(linear, pm_add_finalizer, ctx=deps_ctx, walk=True, name="add finalizer")
linear = graph_rewrite(linear, pm_add_inner_loads, ctx=(waited:=set()), walk=True, name="add loads", enter_calls=True)
linear = graph_rewrite(linear, pm_add_inner_stores, ctx=waited, walk=True, name="add stores", enter_calls=True)

View file

@ -3,12 +3,12 @@ from typing import cast
import os, ctypes, struct, hashlib, functools, importlib, mmap, errno, array, contextlib, sys, weakref, itertools, collections, atexit
assert sys.platform != 'win32'
from dataclasses import dataclass
from extra.hcq2.hcq2 import HCQ2Compiled, HCQAllocator, HCQ2Buffer, to_tuple, make_getaddr, make_ins, make_cmdbuf
from extra.hcq2.hcq2 import HCQ2Compiled, HCQAllocator, HCQ2Buffer, make_getaddr, make_ins, make_cmdbuf
from tinygrad.uop.ops import sint, UOp
from tinygrad.device import Compiled, BufferSpec, Buffer, Device
from tinygrad.dtype import dtypes
from tinygrad.helpers import getenv, round_up, data64_le, DEBUG, PROFILE, ProfileEvent, lo32, hi32, colored, prod, ContextVar, TracingKey
from tinygrad.helpers import VIZ, ceildiv, unwrap, pluralize
from tinygrad.helpers import VIZ, ceildiv, unwrap, pluralize, to_tuple
from tinygrad.renderer.cstyle import HIPRenderer, HIPCCRenderer
from tinygrad.renderer.llvmir import AMDLLVMRenderer
from tinygrad.runtime.autogen import kfd, hsa, sqtt, amdgpu_kd, amdgpu_drm

View file

@ -1,8 +1,8 @@
from __future__ import annotations
from typing import cast, Iterator, Any
from typing import cast, Iterator, Any, Sequence
import time, random, itertools, math, contextlib, weakref
from dataclasses import dataclass, replace, field
from tinygrad.helpers import colored, DEBUG, GlobalCounters, ansilen, all_int, TRACEMETA, prod, flatten, Context, getenv
from tinygrad.helpers import colored, DEBUG, GlobalCounters, ansilen, all_int, TRACEMETA, prod, flatten, Context, getenv, dedup, to_tuple
from tinygrad.helpers import BEAM, size_to_str, time_to_str, VALIDATE_WITH_CPU, PROFILE, ProfilePointEvent, cpu_events
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, buffers, graph_rewrite, ProgramInfo
@ -22,17 +22,19 @@ def get_call_outs_ins(call:UOp) -> tuple[tuple[int, ...], tuple[int, ...]]:
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "encdec": return (0,), tuple(range(1, len(get_call_arg_uops(call))))
return (), ()
def get_call_name(call:UOp, bufs:list[Buffer], var_vals:dict[str, int]|None=None) -> str:
def get_call_name(call:UOp, bufs:Sequence[Buffer|UOp], var_vals:dict[str, int]|None=None) -> str:
def _uop_sz_to_str(uop:UOp) -> str: return size_to_str(sym_infer(prod(uop.shape) * uop.dtype.itemsize, var_vals or {}))
def _dev_str(buf:Buffer|UOp) -> str: return ', '.join(d[:7] for d in to_tuple(buf.device))
ast, arg_uops = call.src[0], get_call_arg_uops(call)
if ast.op is Ops.PROGRAM: return ast.arg.name
if ast.op is Ops.SLICE:
offset = ast.src[1].arg * arg_uops[1].dtype.itemsize
return colored(f"view {_uop_sz_to_str(arg_uops[0]):>10} @ {offset:<10d}", "yellow")
if ast.op is Ops.COPY: return colored(f"copy {_uop_sz_to_str(arg_uops[0]):>10}, {bufs[0].device[:7]:>7s} <- {bufs[1].device[:7]:7s}", "yellow")
if ast.op is Ops.COPY: return colored(f"copy {_uop_sz_to_str(arg_uops[0]):>10}, {_dev_str(bufs[0]):>7s} <- {_dev_str(bufs[1]):7s}", "yellow")
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "encdec": return colored(f"enc/dec {_uop_sz_to_str(arg_uops[0])}", "yellow")
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph": return colored(f"batched {len(ast.src[0].src)}", "cyan")
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "hcq": return call.arg.aux.name
raise NotImplementedError("get_call_name is not implemented")
# **************** Stat ****************
@ -44,6 +46,7 @@ def estimate_uop(call:UOp) -> Estimates:
nbytes = prod(call.src[1].shape) * call.src[1].dtype.itemsize
return Estimates(lds=nbytes, mem=nbytes)
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph": return get_graph_runtime(ast).estimates
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "hcq": return call.arg.aux.estimates
return Estimates()
first_run_cache:set[bytes] = set()
@ -205,6 +208,13 @@ def exec_graph(ctx:ExecContext, call:UOp, ast:UOp) -> float|None:
with track_stats(ctx, call, rt.device, [], ctx.var_vals) as t: t[0] = rt(ctx.input_uops, ctx.var_vals, wait=ctx.wait) # type: ignore[call-arg]
return t[0]
def exec_hcq(ctx:ExecContext, call:UOp, ast:UOp) -> float|None:
  # `ast` is the PROGRAM nested inside the "hcq" CUSTOM_FUNCTION: rebuild the call around it and dispatch it
  # through pm_exec with update_stats switched off for that inner run
  inner_call = call.replace(src=(ast,) + call.src[1:])
  inner_ctx = replace(ctx, update_stats=False)
  pm_exec.rewrite(inner_call, inner_ctx)
  # one track_stats scope per distinct device among the resolved params; synchronize only when ctx.wait is set
  devices = dedup(flatten([to_tuple(u.device) for u in resolve_params(call, ctx.input_uops)]))
  for dev_name in devices:
    with track_stats(ctx, call, dev_name, [], ctx.var_vals):
      if ctx.wait: Device[dev_name].synchronize()
  return None
# flatten LINEAR-in-LINEAR: any nested LINEAR child gets inlined into its parent's src
pm_flatten_linear = PatternMatcher([
(UPat(Ops.LINEAR, custom_early_reject={Ops.LINEAR}, name="lin"),
@ -239,6 +249,7 @@ pm_exec = PatternMatcher([
(UPat(Ops.CALL, src=(UPat(Ops.PROGRAM, name="ast"),), name="call", allow_any_len=True), exec_kernel),
(UPat(Ops.CALL, src=(UPat(Ops.CUSTOM_FUNCTION, arg="encdec", name="ast"),), name="call", allow_any_len=True), exec_encdec),
(UPat(Ops.CALL, src=(UPat(Ops.CUSTOM_FUNCTION, arg="graph", name="ast"),), name="call", allow_any_len=True), exec_graph),
(UPat(Ops.CALL, src=(UPat(Ops.CUSTOM_FUNCTION, arg="hcq", src=(UPat(Ops.PROGRAM, name="ast"),)),), name="call", allow_any_len=True), exec_hcq),
(UPat(Ops.CALL, src=(UPat(Ops.CUSTOM_FUNCTION, arg="validate", name="ast"),), name="call", allow_any_len=True), exec_validate),
])

View file

@ -45,6 +45,7 @@ def size_to_str(s:int) -> str: return next((f"{s / d:.2f} {pr}" for d,pr in [(1<
def ansistrip(s:str):
  # remove every ESC[K and ESC[...m sequence (non-greedy, up to the first "m") from s
  ansi_pattern = '\x1b\\[(K|.*?m)'
  return re.sub(ansi_pattern, '', s)
def ansilen(s:str):
  # length of s once ansistrip has removed its escape sequences
  stripped = ansistrip(s)
  return len(stripped)
def make_tuple(x:int|Sequence[int], cnt:int) -> tuple[int, ...]:
  # an int is repeated cnt times; anything else is converted with tuple() and cnt is not consulted
  if not isinstance(x, int): return tuple(x)
  return (x,)*cnt
def to_tuple(x:T|tuple[T, ...]) -> tuple[T, ...]:
  # a tuple is returned as-is (same object); any other value is wrapped in a 1-tuple
  if isinstance(x, tuple): return x
  return (x,)
def flatten(l:Iterable[Iterable[T]]):
  # concatenate the items of each inner iterable into one new list, in order; only one level is removed
  flat = []
  for sublist in l: flat.extend(sublist)
  return flat
def fully_flatten(l):
if not (hasattr(l, "__len__") and hasattr(l, "__getitem__")) or isinstance(l, str): return [l]

View file

@ -1080,13 +1080,13 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
# opaque bodies stay as Ops.CALL; value-producing bodies become Ops.FUNCTION (wrapped in TUPLE)
_OPAQUE_CALL_BODIES = {Ops.SINK, Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.SLICE, Ops.CUSTOM_FUNCTION}
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(),
name:str|None=None, precompile:bool=False, precompile_backward:bool=False) -> UOp:
name:str|None=None, precompile:bool=False, precompile_backward:bool=False, aux:Any=None) -> UOp:
assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
if self.op in UOp._OPAQUE_CALL_BODIES:
return UOp(Ops.CALL, dtypes.void, (self,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward))
return UOp(Ops.CALL, dtypes.void, (self,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward, aux))
# value-producing bodies are always wrapped in TUPLE so FUNCTION dtype is always void
body = self if self.op is Ops.TUPLE else UOp.maketuple(self)
return UOp(Ops.FUNCTION, dtypes.void, (body,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward))
return UOp(Ops.FUNCTION, dtypes.void, (body,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward, aux))
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)]
@ -1159,8 +1159,9 @@ class CallInfo:
name: str|None = None
precompile: bool = False
precompile_backward: bool = False
aux: Any = None
# grad_fxn can't be pickled, but metadata can
def __reduce__(self): return (CallInfo, (None, self.metadata, self.name, self.precompile, self.precompile_backward))
def __reduce__(self): return (CallInfo, (None, self.metadata, self.name, self.precompile, self.precompile_backward, self.aux))
def __repr__(self):
  # grad_fxn is rendered as its id() (None when unset); the remaining fields are rendered positionally.
  # aux is appended only when it is set: the text for a CallInfo without aux stays exactly as before, while two
  # CallInfos that differ only in aux no longer print identically.
  gf = id(self.grad_fxn) if self.grad_fxn else None
  shown = f"{gf}, {self.metadata}, {repr(self.name)}, {self.precompile}, {self.precompile_backward}"
  if self.aux is not None: shown += f", aux={self.aux!r}"
  return f"CallInfo({shown})"