mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
cbf4946ea6
commit
c0f77c2e1c
2 changed files with 69 additions and 47 deletions
|
|
@ -555,6 +555,21 @@ class TestMultiTensor(unittest.TestCase):
|
|||
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
assert jf.captured is not None
|
||||
|
||||
def test_multi_tensor_jit_graph_assign_updates_each_shard(self):
|
||||
@TinyJit
|
||||
def jf(out: Tensor) -> Tensor:
|
||||
tmp = (Tensor.arange(4, dtype=dtypes.float).shard(devices_2, 0) + 1).contiguous().realize()
|
||||
out.assign((tmp + 1).contiguous()).realize()
|
||||
return out
|
||||
|
||||
out = Tensor.full((4,), -1.0).shard(devices_2, 0).contiguous().realize()
|
||||
expected = np.arange(4, dtype=np.float32) + 2
|
||||
for _ in range(5):
|
||||
out.assign(Tensor.full((4,), -1.0).shard(devices_2, 0).contiguous()).realize()
|
||||
jf(out)
|
||||
np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-5)
|
||||
assert jf.captured is not None
|
||||
|
||||
def test_multi_tensor_jit_body(self):
|
||||
@TinyJit
|
||||
def jf() -> Tensor:
|
||||
|
|
|
|||
|
|
@ -2,45 +2,44 @@ import collections, time
|
|||
from typing import Any, cast
|
||||
from tinygrad.helpers import round_up, PROFILE, ALL2ALL, merge_dicts, getenv, suppress_finalizing, TracingKey, unwrap
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator, MMIOInterface
|
||||
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent
|
||||
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, MultiBuffer, ProfileGraphEntry, ProfileGraphEvent
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, Variable
|
||||
from tinygrad.engine.realize import BufferXfer, CompiledRunner, BufferCopy
|
||||
from tinygrad.engine.jit import GraphRunner, MultiGraphRunner
|
||||
from tinygrad.runtime.ops_rdma import RDMACopyQueue
|
||||
|
||||
class HCQGraph(MultiGraphRunner):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.devices = list(set(cast(HCQCompiled, d) for ji in self.jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
|
||||
self.devices = list({cast(HCQCompiled, Device[b.device]) for (_,_,bufs,_) in self.calls for b in bufs})
|
||||
|
||||
# CPU Device is always last
|
||||
self.devices = sorted(self.devices, key=lambda x: 1 if x._is_cpu() else 0)
|
||||
|
||||
# Replace input buffers with variables.
|
||||
self.hcq_bufs = [[cast(Buffer, x)._buf for x in ji.bufs] for ji in self.jit_cache]
|
||||
self.hcq_bufs = [[b._buf for b in bufs] for (_,_,bufs,_) in self.calls]
|
||||
self.input_replace_to_var: dict[tuple[int, int], Variable] = {}
|
||||
|
||||
for (j,i), input_idx in self.input_replace.items():
|
||||
x = self.input_replace_to_var.setdefault((j,i), UOp.variable(f"input_{input_idx}", 0, 0xffffffffffffffff, dtype=dtypes.uint64))
|
||||
self.hcq_bufs[j][i] = HCQBuffer(x, self.hcq_bufs[j][i].size) # Create fake buffer with variable
|
||||
for j, replace in enumerate(self.uop_replace):
|
||||
for pos, iidx in replace:
|
||||
x = self.input_replace_to_var.setdefault((j,pos), UOp.variable(f"inp_{iidx}_{self.calls[j][0]}", 0, 0xffffffffffffffff, dtype=dtypes.uint64))
|
||||
self.hcq_bufs[j][pos] = HCQBuffer(x, self.hcq_bufs[j][pos].size) # Create fake buffer with variable
|
||||
|
||||
# Allocate kernel args.
|
||||
kernargs_size: dict[Compiled, int] = collections.defaultdict(int)
|
||||
for ji in self.jit_cache:
|
||||
if not isinstance(ji.prg, CompiledRunner): continue
|
||||
kernargs_size[ji.prg.dev] += round_up(ji.prg._prg.kernargs_alloc_size, 16)
|
||||
for prg in self.progs:
|
||||
if prg is None: continue
|
||||
kernargs_size[prg.dev] += round_up(prg._prg.kernargs_alloc_size, 16)
|
||||
self.kernargs_bufs: dict[Compiled, HCQBuffer] = {d:d.allocator._alloc(max(sz, 1), BufferSpec(cpu_access=True)) for d,sz in kernargs_size.items()}
|
||||
|
||||
# Fill initial arguments.
|
||||
self.ji_args: dict[int, HCQArgsState] = {}
|
||||
|
||||
kargs_alloc: dict[Compiled, BumpAllocator] = {dev:BumpAllocator(buf.size) for dev,buf in self.kernargs_bufs.items()}
|
||||
for j,ji in enumerate(self.jit_cache):
|
||||
if not isinstance(ji.prg, CompiledRunner): continue
|
||||
|
||||
argsbuf = self.kernargs_bufs[ji.prg.dev].offset(kargs_alloc[ji.prg.dev].alloc(ji.prg._prg.kernargs_alloc_size, 16))
|
||||
self.ji_args[j] = ji.prg._prg.fill_kernargs(self.hcq_bufs[j], ji.prg.p.vars, argsbuf)
|
||||
for j, prg in enumerate(self.progs):
|
||||
if prg is None: continue
|
||||
argsbuf = self.kernargs_bufs[prg.dev].offset(kargs_alloc[prg.dev].alloc(prg._prg.kernargs_alloc_size, 16))
|
||||
self.ji_args[j] = prg._prg.fill_kernargs(self.hcq_bufs[j], prg.p.vars, argsbuf)
|
||||
|
||||
# Schedule Dependencies.
|
||||
# There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
|
||||
|
|
@ -81,32 +80,34 @@ class HCQGraph(MultiGraphRunner):
|
|||
|
||||
for dev, queue in self.comp_queues.items(): self.dev_access[queue].add(dev)
|
||||
|
||||
self.input_replace_map: dict[HCQCompiled, set[int]] = collections.defaultdict(set)
|
||||
self.input_replace_map: dict[HCQCompiled, set[tuple[int, int]]] = collections.defaultdict(set)
|
||||
self.device_vars: dict[HCQCompiled, dict[str, int]] = {}
|
||||
|
||||
for j,ji in enumerate(self.jit_cache):
|
||||
ji_devs = [cast(HCQCompiled, Device[cast(Buffer, b).device]) for b in ji.bufs] if isinstance(ji.prg, BufferXfer) else []
|
||||
for j, ((_, ast, bufs, device_vars), prg) in enumerate(zip(self.calls, self.progs)):
|
||||
is_xfer = ast.op is Ops.COPY and hasattr(alc:=Device[bufs[0].device].allocator, '_transfer') and alc.supports_transfer \
|
||||
and bufs[0].device.split(":")[0] == bufs[1].device.split(":")[0]
|
||||
ji_devs = [cast(HCQCompiled, Device[b.device]) for b in bufs] if is_xfer else []
|
||||
is_rdma = len(ji_devs) > 0 and not any(d._is_cpu() for d in ji_devs) and len(set(d.peer_group for d in ji_devs)) > 1
|
||||
|
||||
if is_exec_prg:=isinstance(ji.prg, CompiledRunner): enqueue_dev: HCQCompiled = ji.prg.dev
|
||||
if prg is not None: enqueue_dev: HCQCompiled = prg.dev
|
||||
else:
|
||||
# For copy ops prioritize enqeueuing on the src device, so reverse the buffers.
|
||||
for b in cast(list[Buffer], ji.bufs[::-1]):
|
||||
for b in bufs[::-1]:
|
||||
if (enqueue_dev:=cast(HCQCompiled, Device[b.device])).hw_copy_queue_t is not None: break
|
||||
|
||||
# set any fixedvars on the device
|
||||
self.device_vars[enqueue_dev] = merge_dicts([self.device_vars.get(enqueue_dev, {}), ji.fixedvars])
|
||||
if is_exec_prg: self.device_vars[enqueue_dev] = merge_dicts([self.device_vars[enqueue_dev], cast(CompiledRunner, ji.prg).p.runtimevars])
|
||||
self.device_vars[enqueue_dev] = merge_dicts([self.device_vars.get(enqueue_dev, {}), device_vars])
|
||||
if prg is not None: self.device_vars[enqueue_dev] = merge_dicts([self.device_vars[enqueue_dev], prg.p.runtimevars])
|
||||
|
||||
if is_exec_prg:
|
||||
if prg is not None:
|
||||
enqueue_queue = self.comp_queues[enqueue_dev]
|
||||
elif is_rdma:
|
||||
enqueue_queue = self.comp_queues[enqueue_dev]
|
||||
rdma_key = (cast(HCQCompiled, Device[cast(Buffer, ji.bufs[0]).device]).rdma_dev(), enqueue_dev.rdma_dev())
|
||||
rdma_key = (cast(HCQCompiled, Device[bufs[0].device]).rdma_dev(), enqueue_dev.rdma_dev())
|
||||
self.rdma_queues.setdefault(rdma_key, RDMACopyQueue(enqueue_dev.rdma_dev()))
|
||||
else:
|
||||
assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
|
||||
queue_idx = self.devices.index(cast(HCQCompiled, Device[cast(Buffer, ji.bufs[0]).device])) % self.num_copy_queues
|
||||
queue_idx = self.devices.index(cast(HCQCompiled, Device[bufs[0].device])) % self.num_copy_queues
|
||||
enqueue_queue = self.copy_queues.setdefault((enqueue_dev, queue_idx),
|
||||
enqueue_dev.hw_copy_queue_t(queue_idx=queue_idx).wait(self.kick_signals[enqueue_dev.peer_group], self.kickoff_var))
|
||||
|
||||
|
|
@ -115,19 +116,19 @@ class HCQGraph(MultiGraphRunner):
|
|||
# Get dependencies based on input and output buffers.
|
||||
if is_rdma:
|
||||
src_qp, dest_qp = rdma_key[1].iface.connect(rdma_key[0])[:2]
|
||||
sync_signals, opt_deps, rdeps = self._resolve_deps(ji.bufs[1:], [], enqueue_queue, enqueue_dev, out_signal, j,
|
||||
is_copy=isinstance(ji.prg, BufferXfer), rdma_qp=src_qp)
|
||||
peer_queue = self.comp_queues[peer_dev:=cast(HCQCompiled, Device[cast(Buffer, ji.bufs[0]).device])]
|
||||
sync_signals, opt_deps, rdeps = self._resolve_deps(bufs[1:], [], enqueue_queue, enqueue_dev, out_signal, j,
|
||||
is_copy=is_xfer, rdma_qp=src_qp)
|
||||
peer_queue = self.comp_queues[peer_dev:=cast(HCQCompiled, Device[bufs[0].device])]
|
||||
peer_out_signal = self.signals.setdefault(peer_queue, self.pg_dev[peer_dev.peer_group].new_signal(value=0))
|
||||
peer_sync_signals, peer_opt_deps, peer_rdeps = self._resolve_deps(ji.bufs[:1], [0], peer_queue, peer_dev, peer_out_signal, j,
|
||||
is_copy=isinstance(ji.prg, BufferXfer), rdma_qp=dest_qp)
|
||||
peer_sync_signals, peer_opt_deps, peer_rdeps = self._resolve_deps(bufs[:1], [0], peer_queue, peer_dev, peer_out_signal, j,
|
||||
is_copy=is_xfer, rdma_qp=dest_qp)
|
||||
self.rdma_deps[j] = (peer_queue, peer_sync_signals + peer_opt_deps, peer_out_signal, j + 1)
|
||||
self.last_j[peer_queue] = j
|
||||
else:
|
||||
sync_signals, opt_deps, rdeps = self._resolve_deps(ji.bufs, cast(CompiledRunner, ji.prg).p.outs if is_exec_prg else [0], enqueue_queue,
|
||||
enqueue_dev, out_signal, j, is_copy=isinstance(ji.prg, BufferXfer))
|
||||
sync_signals, opt_deps, rdeps = self._resolve_deps(bufs, prg.p.outs if prg is not None else [0], enqueue_queue,
|
||||
enqueue_dev, out_signal, j, is_copy=is_xfer)
|
||||
|
||||
self.ji_schedule[j] = (enqueue_dev, enqueue_queue, sync_signals, opt_deps[::-1], out_signal, None if is_exec_prg else (j + 1))
|
||||
self.ji_schedule[j] = (enqueue_dev, enqueue_queue, sync_signals, opt_deps[::-1], out_signal, None if prg is not None else (j + 1))
|
||||
|
||||
# Collect profile information if profiling is enabled.
|
||||
if PROFILE:
|
||||
|
|
@ -135,9 +136,9 @@ class HCQGraph(MultiGraphRunner):
|
|||
sig_st = prev_ji * 2 + 1 if len(opt_deps) == 0 and (prev_ji:=self.last_j[enqueue_queue]) is not None else j * 2
|
||||
|
||||
# Description based on the command.
|
||||
prof_ji_desc = ji.prg._prg.name if is_exec_prg else TracingKey(f"{ji.bufs[1].device} -> {ji.bufs[0].device}", ret=ji.bufs[0].nbytes) # type: ignore
|
||||
prof_ji_desc = prg._prg.name if prg is not None else TracingKey(f"{bufs[1].device} -> {bufs[0].device}", ret=bufs[0].nbytes) # type: ignore
|
||||
|
||||
prof_name = f"{enqueue_dev.device}:SDMA:{queue_idx}" if not is_exec_prg else enqueue_dev.device
|
||||
prof_name = enqueue_dev.device if prg is not None else f"{enqueue_dev.device}:SDMA:{queue_idx}"
|
||||
self.prof_graph_entries.append(ProfileGraphEntry(prof_name, prof_ji_desc, sig_st, j * 2 + 1))
|
||||
self.prof_graph_deps.append([d - 1 for _, d in rdeps])
|
||||
|
||||
|
|
@ -158,7 +159,7 @@ class HCQGraph(MultiGraphRunner):
|
|||
self.comp_queues[dev].memory_barrier().wait(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev]) \
|
||||
.wait(self.kick_signals[dev.peer_group], self.kickoff_var).signal(self.signals[dev], self.kickoff_var)
|
||||
|
||||
for j,ji in enumerate(self.jit_cache):
|
||||
for j, ((dev_idx, ast, bufs, _), prg) in enumerate(zip(self.calls, self.progs)):
|
||||
enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]
|
||||
|
||||
# Lazy allocate signals
|
||||
|
|
@ -170,13 +171,13 @@ class HCQGraph(MultiGraphRunner):
|
|||
if PROFILE and j * 2 in self.prof_signal_is_used: enqueue_queue.timestamp(self.prof_signals[j * 2])
|
||||
|
||||
# Encode main commands based on ji type.
|
||||
if isinstance(ji.prg, CompiledRunner):
|
||||
enqueue_queue.exec(ji.prg._prg, self.ji_args[j], tuple(ji.prg.p.global_size or (1,1,1)), tuple(ji.prg.p.local_size or (1,1,1)))
|
||||
elif isinstance(ji.prg, BufferXfer) and len(set(cast(HCQCompiled, Device[cast(Buffer, b).device]).peer_group for b in ji.bufs)) > 1:
|
||||
if prg is not None:
|
||||
enqueue_queue.exec(prg._prg, self.ji_args[j], tuple(prg.p.global_size or (1,1,1)), tuple(prg.p.local_size or (1,1,1)))
|
||||
elif j in self.rdma_deps:
|
||||
dest_queue, dest_deps, dest_out_signal, dest_out_val = self.rdma_deps[j]
|
||||
for sig, val in dest_deps: dest_queue.wait(sig, val)
|
||||
|
||||
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
|
||||
dest, src = bufs[0], bufs[1]
|
||||
dest_dev, src_dev = cast(HCQCompiled, Device[dest.device]), cast(HCQCompiled, Device[src.device])
|
||||
dest_rdma, src_rdma = dest_dev.rdma_dev(), src_dev.rdma_dev()
|
||||
|
||||
|
|
@ -194,10 +195,11 @@ class HCQGraph(MultiGraphRunner):
|
|||
|
||||
dest_queue.signal(dest_out_signal, dest_out_val)
|
||||
self.num_rdma_ops[(dest_rdma, src_rdma)] += 1
|
||||
elif isinstance(ji.prg, (BufferXfer, BufferCopy)):
|
||||
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
|
||||
for bufid, src in enumerate(cast(list[Buffer], ji.bufs)):
|
||||
if (inprep_idx:=self.input_replace.get((j, bufid))) is not None: self.input_replace_map[enqueue_dev].add(inprep_idx)
|
||||
elif ast.op is Ops.COPY:
|
||||
dest, src = bufs[0], bufs[1]
|
||||
uop_replace_j = dict(self.uop_replace[j])
|
||||
for bufid in range(len(bufs)):
|
||||
if (replace_iidx:=uop_replace_j.get(bufid)) is not None: self.input_replace_map[enqueue_dev].add((replace_iidx, dev_idx))
|
||||
else: cast(HCQAllocator, enqueue_dev.allocator).map(self.hcq_bufs[j][bufid])
|
||||
enqueue_queue.copy(self.hcq_bufs[j][0], self.hcq_bufs[j][1], dest.nbytes)
|
||||
self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))
|
||||
|
|
@ -261,7 +263,9 @@ class HCQGraph(MultiGraphRunner):
|
|||
def __call__(self, input_buffers: list[Buffer], var_vals: dict[str, int], wait=False, input_uops=None) -> float|None:
|
||||
# Map input buffers
|
||||
for dev in self.devices:
|
||||
for idx_to_map in self.input_replace_map[dev]: cast(HCQAllocator, dev.allocator).map(input_buffers[idx_to_map]._buf)
|
||||
for iidx, dev_idx in self.input_replace_map[dev]:
|
||||
buf = b.bufs[dev_idx] if isinstance(b:=input_uops[iidx].buffer, MultiBuffer) else b
|
||||
cast(HCQAllocator, dev.allocator).map(buf._buf)
|
||||
|
||||
# Wait and restore signals
|
||||
self.kickoff_value += 1
|
||||
|
|
@ -273,8 +277,11 @@ class HCQGraph(MultiGraphRunner):
|
|||
**{sig.base_buf.va_addr.expr: dev.timeline_signal.base_buf.va_addr for dev, sig in self.virt_timeline_signals.items()}}
|
||||
|
||||
# Update buffers
|
||||
for (j,i),input_idx in self.input_replace.items():
|
||||
hcq_var_vals[self.input_replace_to_var[(j,i)].expr] = input_buffers[input_idx]._buf.va_addr
|
||||
for j, replace in enumerate(self.uop_replace):
|
||||
dev_idx = self.calls[j][0]
|
||||
for pos, iidx in replace:
|
||||
buf = b.bufs[dev_idx] if isinstance(b:=input_uops[iidx].buffer, MultiBuffer) else b
|
||||
hcq_var_vals[self.input_replace_to_var[(j,pos)].expr] = buf._buf.va_addr
|
||||
|
||||
for (var, qp) in self.rdma_vars.values(): hcq_var_vals[var.expr] = qp.head
|
||||
for q in self.rdma_queues.values(): q.submit(q.dev, hcq_var_vals)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue