hcq graph to linear (#15888)

* hcq

* f

* f

* linter
This commit is contained in:
nimlgen 2026-04-24 12:42:49 +03:00 committed by GitHub
commit c0f77c2e1c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 69 additions and 47 deletions

View file

@ -555,6 +555,21 @@ class TestMultiTensor(unittest.TestCase):
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
assert jf.captured is not None
def test_multi_tensor_jit_graph_assign_updates_each_shard(self):
  # A jitted function assigns arange(4)+2 into a tensor sharded over devices_2.
  # Before every call the tensor is reset to -1, so the comparison below can only
  # pass if that call rewrote the whole tensor rather than leaving stale values.
  def sentinel() -> Tensor:
    # fresh (unrealized) sharded tensor filled with -1
    return Tensor.full((4,), -1.0).shard(devices_2, 0).contiguous()

  @TinyJit
  def jf(out: Tensor) -> Tensor:
    base = (Tensor.arange(4, dtype=dtypes.float).shard(devices_2, 0) + 1).contiguous().realize()
    out.assign((base + 1).contiguous()).realize()
    return out

  expected = np.arange(4, dtype=np.float32) + 2
  out = sentinel().realize()
  # five iterations so that later calls go through whatever the JIT captured
  for _ in range(5):
    out.assign(sentinel()).realize()
    jf(out)
    np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-5)
  # the JIT must report a captured state once the loop is done
  assert jf.captured is not None
def test_multi_tensor_jit_body(self):
@TinyJit
def jf() -> Tensor:

View file

@ -2,45 +2,44 @@ import collections, time
from typing import Any, cast
from tinygrad.helpers import round_up, PROFILE, ALL2ALL, merge_dicts, getenv, suppress_finalizing, TracingKey, unwrap
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator, MMIOInterface
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, MultiBuffer, ProfileGraphEntry, ProfileGraphEvent
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp, Ops, Variable
from tinygrad.engine.realize import BufferXfer, CompiledRunner, BufferCopy
from tinygrad.engine.jit import GraphRunner, MultiGraphRunner
from tinygrad.runtime.ops_rdma import RDMACopyQueue
class HCQGraph(MultiGraphRunner):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.devices = list(set(cast(HCQCompiled, d) for ji in self.jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
self.devices = list({cast(HCQCompiled, Device[b.device]) for (_,_,bufs,_) in self.calls for b in bufs})
# CPU Device is always last
self.devices = sorted(self.devices, key=lambda x: 1 if x._is_cpu() else 0)
# Replace input buffers with variables.
self.hcq_bufs = [[cast(Buffer, x)._buf for x in ji.bufs] for ji in self.jit_cache]
self.hcq_bufs = [[b._buf for b in bufs] for (_,_,bufs,_) in self.calls]
self.input_replace_to_var: dict[tuple[int, int], Variable] = {}
for (j,i), input_idx in self.input_replace.items():
x = self.input_replace_to_var.setdefault((j,i), UOp.variable(f"input_{input_idx}", 0, 0xffffffffffffffff, dtype=dtypes.uint64))
self.hcq_bufs[j][i] = HCQBuffer(x, self.hcq_bufs[j][i].size) # Create fake buffer with variable
for j, replace in enumerate(self.uop_replace):
for pos, iidx in replace:
x = self.input_replace_to_var.setdefault((j,pos), UOp.variable(f"inp_{iidx}_{self.calls[j][0]}", 0, 0xffffffffffffffff, dtype=dtypes.uint64))
self.hcq_bufs[j][pos] = HCQBuffer(x, self.hcq_bufs[j][pos].size) # Create fake buffer with variable
# Allocate kernel args.
kernargs_size: dict[Compiled, int] = collections.defaultdict(int)
for ji in self.jit_cache:
if not isinstance(ji.prg, CompiledRunner): continue
kernargs_size[ji.prg.dev] += round_up(ji.prg._prg.kernargs_alloc_size, 16)
for prg in self.progs:
if prg is None: continue
kernargs_size[prg.dev] += round_up(prg._prg.kernargs_alloc_size, 16)
self.kernargs_bufs: dict[Compiled, HCQBuffer] = {d:d.allocator._alloc(max(sz, 1), BufferSpec(cpu_access=True)) for d,sz in kernargs_size.items()}
# Fill initial arguments.
self.ji_args: dict[int, HCQArgsState] = {}
kargs_alloc: dict[Compiled, BumpAllocator] = {dev:BumpAllocator(buf.size) for dev,buf in self.kernargs_bufs.items()}
for j,ji in enumerate(self.jit_cache):
if not isinstance(ji.prg, CompiledRunner): continue
argsbuf = self.kernargs_bufs[ji.prg.dev].offset(kargs_alloc[ji.prg.dev].alloc(ji.prg._prg.kernargs_alloc_size, 16))
self.ji_args[j] = ji.prg._prg.fill_kernargs(self.hcq_bufs[j], ji.prg.p.vars, argsbuf)
for j, prg in enumerate(self.progs):
if prg is None: continue
argsbuf = self.kernargs_bufs[prg.dev].offset(kargs_alloc[prg.dev].alloc(prg._prg.kernargs_alloc_size, 16))
self.ji_args[j] = prg._prg.fill_kernargs(self.hcq_bufs[j], prg.p.vars, argsbuf)
# Schedule Dependencies.
# There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
@ -81,32 +80,34 @@ class HCQGraph(MultiGraphRunner):
for dev, queue in self.comp_queues.items(): self.dev_access[queue].add(dev)
self.input_replace_map: dict[HCQCompiled, set[int]] = collections.defaultdict(set)
self.input_replace_map: dict[HCQCompiled, set[tuple[int, int]]] = collections.defaultdict(set)
self.device_vars: dict[HCQCompiled, dict[str, int]] = {}
for j,ji in enumerate(self.jit_cache):
ji_devs = [cast(HCQCompiled, Device[cast(Buffer, b).device]) for b in ji.bufs] if isinstance(ji.prg, BufferXfer) else []
for j, ((_, ast, bufs, device_vars), prg) in enumerate(zip(self.calls, self.progs)):
is_xfer = ast.op is Ops.COPY and hasattr(alc:=Device[bufs[0].device].allocator, '_transfer') and alc.supports_transfer \
and bufs[0].device.split(":")[0] == bufs[1].device.split(":")[0]
ji_devs = [cast(HCQCompiled, Device[b.device]) for b in bufs] if is_xfer else []
is_rdma = len(ji_devs) > 0 and not any(d._is_cpu() for d in ji_devs) and len(set(d.peer_group for d in ji_devs)) > 1
if is_exec_prg:=isinstance(ji.prg, CompiledRunner): enqueue_dev: HCQCompiled = ji.prg.dev
if prg is not None: enqueue_dev: HCQCompiled = prg.dev
else:
# For copy ops prioritize enqueuing on the src device, so reverse the buffers.
for b in cast(list[Buffer], ji.bufs[::-1]):
for b in bufs[::-1]:
if (enqueue_dev:=cast(HCQCompiled, Device[b.device])).hw_copy_queue_t is not None: break
# set any fixedvars on the device
self.device_vars[enqueue_dev] = merge_dicts([self.device_vars.get(enqueue_dev, {}), ji.fixedvars])
if is_exec_prg: self.device_vars[enqueue_dev] = merge_dicts([self.device_vars[enqueue_dev], cast(CompiledRunner, ji.prg).p.runtimevars])
self.device_vars[enqueue_dev] = merge_dicts([self.device_vars.get(enqueue_dev, {}), device_vars])
if prg is not None: self.device_vars[enqueue_dev] = merge_dicts([self.device_vars[enqueue_dev], prg.p.runtimevars])
if is_exec_prg:
if prg is not None:
enqueue_queue = self.comp_queues[enqueue_dev]
elif is_rdma:
enqueue_queue = self.comp_queues[enqueue_dev]
rdma_key = (cast(HCQCompiled, Device[cast(Buffer, ji.bufs[0]).device]).rdma_dev(), enqueue_dev.rdma_dev())
rdma_key = (cast(HCQCompiled, Device[bufs[0].device]).rdma_dev(), enqueue_dev.rdma_dev())
self.rdma_queues.setdefault(rdma_key, RDMACopyQueue(enqueue_dev.rdma_dev()))
else:
assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
queue_idx = self.devices.index(cast(HCQCompiled, Device[cast(Buffer, ji.bufs[0]).device])) % self.num_copy_queues
queue_idx = self.devices.index(cast(HCQCompiled, Device[bufs[0].device])) % self.num_copy_queues
enqueue_queue = self.copy_queues.setdefault((enqueue_dev, queue_idx),
enqueue_dev.hw_copy_queue_t(queue_idx=queue_idx).wait(self.kick_signals[enqueue_dev.peer_group], self.kickoff_var))
@ -115,19 +116,19 @@ class HCQGraph(MultiGraphRunner):
# Get dependencies based on input and output buffers.
if is_rdma:
src_qp, dest_qp = rdma_key[1].iface.connect(rdma_key[0])[:2]
sync_signals, opt_deps, rdeps = self._resolve_deps(ji.bufs[1:], [], enqueue_queue, enqueue_dev, out_signal, j,
is_copy=isinstance(ji.prg, BufferXfer), rdma_qp=src_qp)
peer_queue = self.comp_queues[peer_dev:=cast(HCQCompiled, Device[cast(Buffer, ji.bufs[0]).device])]
sync_signals, opt_deps, rdeps = self._resolve_deps(bufs[1:], [], enqueue_queue, enqueue_dev, out_signal, j,
is_copy=is_xfer, rdma_qp=src_qp)
peer_queue = self.comp_queues[peer_dev:=cast(HCQCompiled, Device[bufs[0].device])]
peer_out_signal = self.signals.setdefault(peer_queue, self.pg_dev[peer_dev.peer_group].new_signal(value=0))
peer_sync_signals, peer_opt_deps, peer_rdeps = self._resolve_deps(ji.bufs[:1], [0], peer_queue, peer_dev, peer_out_signal, j,
is_copy=isinstance(ji.prg, BufferXfer), rdma_qp=dest_qp)
peer_sync_signals, peer_opt_deps, peer_rdeps = self._resolve_deps(bufs[:1], [0], peer_queue, peer_dev, peer_out_signal, j,
is_copy=is_xfer, rdma_qp=dest_qp)
self.rdma_deps[j] = (peer_queue, peer_sync_signals + peer_opt_deps, peer_out_signal, j + 1)
self.last_j[peer_queue] = j
else:
sync_signals, opt_deps, rdeps = self._resolve_deps(ji.bufs, cast(CompiledRunner, ji.prg).p.outs if is_exec_prg else [0], enqueue_queue,
enqueue_dev, out_signal, j, is_copy=isinstance(ji.prg, BufferXfer))
sync_signals, opt_deps, rdeps = self._resolve_deps(bufs, prg.p.outs if prg is not None else [0], enqueue_queue,
enqueue_dev, out_signal, j, is_copy=is_xfer)
self.ji_schedule[j] = (enqueue_dev, enqueue_queue, sync_signals, opt_deps[::-1], out_signal, None if is_exec_prg else (j + 1))
self.ji_schedule[j] = (enqueue_dev, enqueue_queue, sync_signals, opt_deps[::-1], out_signal, None if prg is not None else (j + 1))
# Collect profile information if profiling is enabled.
if PROFILE:
@ -135,9 +136,9 @@ class HCQGraph(MultiGraphRunner):
sig_st = prev_ji * 2 + 1 if len(opt_deps) == 0 and (prev_ji:=self.last_j[enqueue_queue]) is not None else j * 2
# Description based on the command.
prof_ji_desc = ji.prg._prg.name if is_exec_prg else TracingKey(f"{ji.bufs[1].device} -> {ji.bufs[0].device}", ret=ji.bufs[0].nbytes) # type: ignore
prof_ji_desc = prg._prg.name if prg is not None else TracingKey(f"{bufs[1].device} -> {bufs[0].device}", ret=bufs[0].nbytes) # type: ignore
prof_name = f"{enqueue_dev.device}:SDMA:{queue_idx}" if not is_exec_prg else enqueue_dev.device
prof_name = enqueue_dev.device if prg is not None else f"{enqueue_dev.device}:SDMA:{queue_idx}"
self.prof_graph_entries.append(ProfileGraphEntry(prof_name, prof_ji_desc, sig_st, j * 2 + 1))
self.prof_graph_deps.append([d - 1 for _, d in rdeps])
@ -158,7 +159,7 @@ class HCQGraph(MultiGraphRunner):
self.comp_queues[dev].memory_barrier().wait(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev]) \
.wait(self.kick_signals[dev.peer_group], self.kickoff_var).signal(self.signals[dev], self.kickoff_var)
for j,ji in enumerate(self.jit_cache):
for j, ((dev_idx, ast, bufs, _), prg) in enumerate(zip(self.calls, self.progs)):
enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]
# Lazy allocate signals
@ -170,13 +171,13 @@ class HCQGraph(MultiGraphRunner):
if PROFILE and j * 2 in self.prof_signal_is_used: enqueue_queue.timestamp(self.prof_signals[j * 2])
# Encode main commands based on ji type.
if isinstance(ji.prg, CompiledRunner):
enqueue_queue.exec(ji.prg._prg, self.ji_args[j], tuple(ji.prg.p.global_size or (1,1,1)), tuple(ji.prg.p.local_size or (1,1,1)))
elif isinstance(ji.prg, BufferXfer) and len(set(cast(HCQCompiled, Device[cast(Buffer, b).device]).peer_group for b in ji.bufs)) > 1:
if prg is not None:
enqueue_queue.exec(prg._prg, self.ji_args[j], tuple(prg.p.global_size or (1,1,1)), tuple(prg.p.local_size or (1,1,1)))
elif j in self.rdma_deps:
dest_queue, dest_deps, dest_out_signal, dest_out_val = self.rdma_deps[j]
for sig, val in dest_deps: dest_queue.wait(sig, val)
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
dest, src = bufs[0], bufs[1]
dest_dev, src_dev = cast(HCQCompiled, Device[dest.device]), cast(HCQCompiled, Device[src.device])
dest_rdma, src_rdma = dest_dev.rdma_dev(), src_dev.rdma_dev()
@ -194,10 +195,11 @@ class HCQGraph(MultiGraphRunner):
dest_queue.signal(dest_out_signal, dest_out_val)
self.num_rdma_ops[(dest_rdma, src_rdma)] += 1
elif isinstance(ji.prg, (BufferXfer, BufferCopy)):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
for bufid, src in enumerate(cast(list[Buffer], ji.bufs)):
if (inprep_idx:=self.input_replace.get((j, bufid))) is not None: self.input_replace_map[enqueue_dev].add(inprep_idx)
elif ast.op is Ops.COPY:
dest, src = bufs[0], bufs[1]
uop_replace_j = dict(self.uop_replace[j])
for bufid in range(len(bufs)):
if (replace_iidx:=uop_replace_j.get(bufid)) is not None: self.input_replace_map[enqueue_dev].add((replace_iidx, dev_idx))
else: cast(HCQAllocator, enqueue_dev.allocator).map(self.hcq_bufs[j][bufid])
enqueue_queue.copy(self.hcq_bufs[j][0], self.hcq_bufs[j][1], dest.nbytes)
self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))
@ -261,7 +263,9 @@ class HCQGraph(MultiGraphRunner):
def __call__(self, input_buffers: list[Buffer], var_vals: dict[str, int], wait=False, input_uops=None) -> float|None:
# Map input buffers
for dev in self.devices:
for idx_to_map in self.input_replace_map[dev]: cast(HCQAllocator, dev.allocator).map(input_buffers[idx_to_map]._buf)
for iidx, dev_idx in self.input_replace_map[dev]:
buf = b.bufs[dev_idx] if isinstance(b:=input_uops[iidx].buffer, MultiBuffer) else b
cast(HCQAllocator, dev.allocator).map(buf._buf)
# Wait and restore signals
self.kickoff_value += 1
@ -273,8 +277,11 @@ class HCQGraph(MultiGraphRunner):
**{sig.base_buf.va_addr.expr: dev.timeline_signal.base_buf.va_addr for dev, sig in self.virt_timeline_signals.items()}}
# Update buffers
for (j,i),input_idx in self.input_replace.items():
hcq_var_vals[self.input_replace_to_var[(j,i)].expr] = input_buffers[input_idx]._buf.va_addr
for j, replace in enumerate(self.uop_replace):
dev_idx = self.calls[j][0]
for pos, iidx in replace:
buf = b.bufs[dev_idx] if isinstance(b:=input_uops[iidx].buffer, MultiBuffer) else b
hcq_var_vals[self.input_replace_to_var[(j,pos)].expr] = buf._buf.va_addr
for (var, qp) in self.rdma_vars.values(): hcq_var_vals[var.expr] = qp.head
for q in self.rdma_queues.values(): q.submit(q.dev, hcq_var_vals)