get_runner -> get_runtime (#15967)

* get_runner -> get_runtime

* do not use get_runner

* fix

* remove get_tunner

* remove

* fix

* x
This commit is contained in:
nimlgen 2026-04-29 18:29:49 +03:00 committed by GitHub
commit 7787f76dcc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 161 additions and 158 deletions

View file

@ -8,17 +8,14 @@ class TestMockGPUInvalidInstruction(unittest.TestCase):
test_code = '''
import struct
from tinygrad import Device, Tensor
from tinygrad.engine.realize import get_runner
from tinygrad.engine.realize import compile_linear
from tinygrad.runtime.ops_amd import AMDProgram
dev = Device["AMD"]
a = Tensor([1.0]).realize()
b = a + 1
si = b.schedule_linear().src[-1]
runner = get_runner(dev.device, si.src[0])
prg = runner._prg
lib = bytearray(prg.lib)
linear = compile_linear(b.schedule_linear())
lib = bytearray(linear.src[-1].src[0].src[4].arg)
# Find s_endpgm (0xBFB00000) and replace with V_MOVRELD_B32 (op=66) which has no pcode
# VOP1 encoding: bits[31:25]=0x7E, op=bits[16:9], so op=66 -> 66<<9 = 0x8400

View file

@ -3,7 +3,8 @@ from tinygrad import Device, Tensor, dtypes, TinyJit
from tinygrad.helpers import CI, DEV, Context, ProfileRangeEvent, cpu_profile, cpu_events, ProfilePointEvent, dedup
from tinygrad.device import Buffer, BufferSpec, Compiled, ProfileDeviceEvent, ProfileGraphEvent
from tinygrad.runtime.support.hcq import HCQCompiled
from tinygrad.engine.realize import get_runner
from tinygrad.engine.realize import get_runtime
from tinygrad.codegen import to_program
MOCKGPU = DEV.interface.startswith("MOCK")
def _dev_base(d):
@ -46,13 +47,15 @@ class TestProfiler(unittest.TestCase):
TestProfiler.b = self.a + 1
si = self.b.schedule_linear().src[-1]
TestProfiler.runner = get_runner(TestProfiler.d0.device, si.src[0])
TestProfiler.prg = to_program(si.src[0], TestProfiler.d0.renderer)
TestProfiler.runtime = get_runtime(TestProfiler.d0.device, TestProfiler.prg)
TestProfiler.b.uop.buffer.allocate()
def test_profile_kernel_run(self):
runner_name = TestProfiler.runner._prg.name
runner_name = TestProfiler.runtime.name
with helper_collect_profile(TestProfiler.d0) as profile:
TestProfiler.runner([TestProfiler.b.uop.buffer, TestProfiler.a.uop.buffer], var_vals={})
gs, ls = TestProfiler.prg.arg.launch_dims({})
TestProfiler.runtime(TestProfiler.b.uop.buffer._buf, TestProfiler.a.uop.buffer._buf, global_size=gs, local_size=ls)
profile, _ = helper_profile_filter_device(profile, TestProfiler.d0.device)
kernel_runs = [x for x in profile if isinstance(x, ProfileRangeEvent)]
@ -70,12 +73,13 @@ class TestProfiler(unittest.TestCase):
assert len(kernel_runs) == 1, "one kernel run is expected"
def test_profile_multiops(self):
runner_name = TestProfiler.runner._prg.name
runner_name = TestProfiler.runtime.name
buf1 = Buffer(Device.DEFAULT, 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated()
with helper_collect_profile(TestProfiler.d0) as profile:
buf1.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
TestProfiler.runner([buf1, TestProfiler.a.uop.buffer], var_vals={})
gs, ls = TestProfiler.prg.arg.launch_dims({})
TestProfiler.runtime(buf1._buf, TestProfiler.a.uop.buffer._buf, global_size=gs, local_size=ls)
buf1.copyout(memoryview(bytearray(buf1.nbytes)))
evs = [x for x in profile if isinstance(x, ProfileRangeEvent) and x.device.startswith(TestProfiler.d0.device)]

View file

@ -6,7 +6,7 @@ from tinygrad.device import Buffer, BufferSpec
from tinygrad.runtime.support.hcq import HCQCompiled, HCQBuffer
from tinygrad.runtime.autogen import libc
from tinygrad.runtime.support.system import PCIIfaceBase
from tinygrad.engine.realize import get_runner, CompiledRunner
from tinygrad.engine.realize import get_runtime
from tinygrad.codegen import to_program
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad import Variable
@ -22,11 +22,12 @@ class TestHCQ(unittest.TestCase):
TestHCQ.b = self.a + 1
si = self.b.schedule_linear().src[-1]
TestHCQ.runner = get_runner(TestHCQ.d0.device, si.src[0])
TestHCQ.prg = to_program(si.src[0], TestHCQ.d0.renderer)
TestHCQ.runtime = get_runtime(TestHCQ.d0.device, TestHCQ.prg)
TestHCQ.b.uop.buffer.allocate()
TestHCQ.kernargs_ba_ptr = TestHCQ.runner._prg.fill_kernargs([TestHCQ.b.uop.buffer._buf, TestHCQ.a.uop.buffer._buf])
TestHCQ.kernargs_ab_ptr = TestHCQ.runner._prg.fill_kernargs([TestHCQ.a.uop.buffer._buf, TestHCQ.b.uop.buffer._buf])
TestHCQ.kernargs_ba_ptr = TestHCQ.runtime.fill_kernargs([TestHCQ.b.uop.buffer._buf, TestHCQ.a.uop.buffer._buf])
TestHCQ.kernargs_ab_ptr = TestHCQ.runtime.fill_kernargs([TestHCQ.a.uop.buffer._buf, TestHCQ.b.uop.buffer._buf])
def setUp(self):
TestHCQ.d0.synchronize()
@ -114,7 +115,7 @@ class TestHCQ(unittest.TestCase):
# Test exec
def test_exec_one_kernel(self):
TestHCQ.d0.hw_compute_queue_t().exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
TestHCQ.d0.hw_compute_queue_t().exec(TestHCQ.runtime, TestHCQ.kernargs_ba_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
@ -128,8 +129,8 @@ class TestHCQ(unittest.TestCase):
q = TestHCQ.d0.hw_compute_queue_t()
q.wait(TestHCQ.d0.timeline_signal, virt_val - 1) \
.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ab_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
.exec(TestHCQ.runtime, TestHCQ.kernargs_ba_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size) \
.exec(TestHCQ.runtime, TestHCQ.kernargs_ab_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size) \
.signal(TestHCQ.d0.timeline_signal, virt_val)
for _ in range(100):
@ -141,11 +142,11 @@ class TestHCQ(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT in {"CPU"}, "No globals/locals on LLVM/CPU")
def test_exec_update(self):
sint_global = (Variable("sint_global", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.runner.p.global_size[1:])
sint_local = (Variable("sint_local", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.runner.p.local_size[1:])
sint_global = (Variable("sint_global", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.prg.arg.global_size[1:])
sint_local = (Variable("sint_local", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.prg.arg.local_size[1:])
q = TestHCQ.d0.hw_compute_queue_t()
q.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, sint_global, sint_local) \
q.exec(TestHCQ.runtime, TestHCQ.kernargs_ba_ptr, sint_global, sint_local) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q.submit(TestHCQ.d0, {sint_global[0].expr: 1, sint_local[0].expr: 1})
@ -166,17 +167,17 @@ class TestHCQ(unittest.TestCase):
b = a + 1
si = b.schedule_linear().src[-1]
runner = CompiledRunner(to_program(replace_opts(si.src[0], [Opt(op=OptOps.LOCAL, axis=0, arg=3) for _ in range(3)]), TestHCQ.d0.renderer),
Device.DEFAULT)
prg = to_program(replace_opts(si.src[0], [Opt(op=OptOps.LOCAL, axis=0, arg=3) for _ in range(3)]), TestHCQ.d0.renderer)
runtime = get_runtime(Device.DEFAULT, prg)
zb = Buffer(Device.DEFAULT, 3 * 3 * 3, dtypes.int, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated()
zt = Buffer(Device.DEFAULT, 3 * 3 * 3, dtypes.int, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated()
ctypes.memset(zb._buf.va_addr, 0, zb.nbytes)
kernargs = runner._prg.fill_kernargs([zt._buf, zb._buf])
kernargs = runtime.fill_kernargs([zt._buf, zb._buf])
q = TestHCQ.d0.hw_compute_queue_t()
q.memory_barrier() \
.exec(runner._prg, kernargs, (1,1,1), virt_local) \
.exec(runtime, kernargs, (1,1,1), virt_local) \
.signal(TestHCQ.d0.timeline_signal, virt_val)
for x in range(1, 4):
@ -330,7 +331,7 @@ class TestHCQ(unittest.TestCase):
def test_speed_exec_time(self):
sig_st, sig_en = TestHCQ.d0.new_signal(), TestHCQ.d0.new_signal()
TestHCQ.d0.hw_compute_queue_t().timestamp(sig_st) \
.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
.exec(TestHCQ.runtime, TestHCQ.kernargs_ba_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size) \
.timestamp(sig_en) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
@ -470,12 +471,13 @@ class TestHCQ(unittest.TestCase):
def test_memory_barrier(self):
a = Tensor([0, 1], device=Device.DEFAULT, dtype=dtypes.int8).realize()
b = a + 1
runner = get_runner(TestHCQ.d0.device, b.schedule_linear().src[-1].src[0])
prg = to_program(b.schedule_linear().src[-1].src[0], TestHCQ.d0.renderer)
runtime = get_runtime(TestHCQ.d0.device, prg)
buf1 = Buffer(Device.DEFAULT, 2, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
buf2 = Buffer(Device.DEFAULT, 2, dtypes.int8, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated()
kernargs_ptr = runner._prg.fill_kernargs([buf1._buf, buf2._buf])
kernargs_ptr = runtime.fill_kernargs([buf1._buf, buf2._buf])
for i in range(255):
ctypes.memset(buf2._buf.va_addr, i, 2)
@ -483,7 +485,7 @@ class TestHCQ(unittest.TestCase):
# Need memory_barrier after direct write to vram
TestHCQ.d0.hw_compute_queue_t().wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \
.memory_barrier() \
.exec(runner._prg, kernargs_ptr, runner.p.global_size, runner.p.local_size) \
.exec(runtime, kernargs_ptr, prg.arg.global_size, prg.arg.local_size) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1

View file

@ -2,7 +2,8 @@ import unittest, ctypes, struct, time, array
from tinygrad import Device, Tensor, dtypes
from tinygrad.helpers import to_mv, CI
from tinygrad.device import Buffer, BufferSpec
from tinygrad.engine.realize import get_runner
from tinygrad.engine.realize import get_runtime
from tinygrad.codegen import to_program
def _time_queue(q, d):
st = time.perf_counter()
@ -21,13 +22,14 @@ class TestHCQ(unittest.TestCase):
TestHCQ.a = Tensor([0.,1.], device=Device.DEFAULT).realize()
TestHCQ.b = self.a + 1
linear = self.b.schedule_linear()
TestHCQ.runner = get_runner(TestHCQ.d0.device, linear.src[-1].src[0])
TestHCQ.prg = to_program(linear.src[-1].src[0], TestHCQ.d0.renderer)
TestHCQ.runtime = get_runtime(TestHCQ.d0.device, TestHCQ.prg)
TestHCQ.b.uop.buffer.allocate()
# wow that's a lot of abstraction layers
TestHCQ.addr = struct.pack("QQ", TestHCQ.b.uop.buffer._buf, TestHCQ.a.uop.buffer._buf)
TestHCQ.addr2 = struct.pack("QQ", TestHCQ.a.uop.buffer._buf, TestHCQ.b.uop.buffer._buf)
TestHCQ.kernargs_off = TestHCQ.runner._prg.kernargs_offset
TestHCQ.kernargs_size = TestHCQ.runner._prg.kernargs_alloc_size
TestHCQ.kernargs_off = TestHCQ.runtime.kernargs_offset
TestHCQ.kernargs_size = TestHCQ.runtime.kernargs_alloc_size
ctypes.memmove(TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_off, TestHCQ.addr, len(TestHCQ.addr))
ctypes.memmove(TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size+TestHCQ.kernargs_off, TestHCQ.addr2, len(TestHCQ.addr2))
@ -38,8 +40,8 @@ class TestHCQ(unittest.TestCase):
elif Device.DEFAULT == "NV":
from tinygrad.runtime.ops_nv import HWQueue, HWQueue
# nv need to copy constbuffer there as well
to_mv(TestHCQ.d0.kernargs_ptr, 0x160).cast('I')[:] = array.array('I', TestHCQ.runner._prg.constbuffer_0)
to_mv(TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, 0x160).cast('I')[:] = array.array('I', TestHCQ.runner._prg.constbuffer_0)
to_mv(TestHCQ.d0.kernargs_ptr, 0x160).cast('I')[:] = array.array('I', TestHCQ.runtime.constbuffer_0)
to_mv(TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, 0x160).cast('I')[:] = array.array('I', TestHCQ.runtime.constbuffer_0)
TestHCQ.compute_queue = HWQueue
TestHCQ.copy_queue = HWQueue
@ -53,11 +55,11 @@ class TestHCQ(unittest.TestCase):
temp_signal, temp_value = TestHCQ.d0._alloc_signal(value=0), 0
q = TestHCQ.compute_queue()
for _ in range(1000):
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
q.signal(temp_signal, temp_value + 1).wait(temp_signal, temp_value + 1)
temp_value += 1
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
q.signal(temp_signal, temp_value + 1).wait(temp_signal, temp_value + 1)
temp_value += 1
@ -71,10 +73,10 @@ class TestHCQ(unittest.TestCase):
def test_run_1000_times(self):
temp_signal = TestHCQ.d0._alloc_signal(value=0)
q = TestHCQ.compute_queue()
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
q.signal(temp_signal, 2).wait(temp_signal, 2)
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.runner.p.global_size,
TestHCQ.runner.p.local_size)
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.prg.arg.global_size,
TestHCQ.prg.arg.local_size)
for _ in range(1000):
TestHCQ.d0._set_signal(temp_signal, 1)
q.submit(TestHCQ.d0)
@ -87,11 +89,11 @@ class TestHCQ(unittest.TestCase):
def test_run_to_3(self):
temp_signal = TestHCQ.d0._alloc_signal(value=0)
q = TestHCQ.compute_queue()
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
q.signal(temp_signal, 1).wait(temp_signal, 1)
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
q.signal(temp_signal, 2).wait(temp_signal, 2)
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@ -101,7 +103,7 @@ class TestHCQ(unittest.TestCase):
def test_update_exec(self):
q = TestHCQ.compute_queue()
exec_cmd_idx = len(q)
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
q.update_exec(exec_cmd_idx, (1,1,1), (1,1,1))
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
@ -115,10 +117,10 @@ class TestHCQ(unittest.TestCase):
def test_bind_run(self):
temp_signal = TestHCQ.d0._alloc_signal(value=0)
q = TestHCQ.compute_queue()
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
q.signal(temp_signal, 2).wait(temp_signal, 2)
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.runner.p.global_size,
TestHCQ.runner.p.local_size)
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.prg.arg.global_size,
TestHCQ.prg.arg.local_size)
q.bind(TestHCQ.d0)
for _ in range(1000):
TestHCQ.d0._set_signal(temp_signal, 1)
@ -133,7 +135,7 @@ class TestHCQ(unittest.TestCase):
def test_update_exec_binded(self):
q = TestHCQ.compute_queue()
exec_ptr = q.ptr()
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q.bind(TestHCQ.d0)
@ -170,7 +172,7 @@ class TestHCQ(unittest.TestCase):
def test_run_normal(self):
q = TestHCQ.compute_queue()
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@ -201,7 +203,7 @@ class TestHCQ(unittest.TestCase):
def test_run_signal(self):
q = TestHCQ.compute_queue()
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
@ -278,7 +280,7 @@ class TestHCQ(unittest.TestCase):
def test_interleave_compute_and_copy(self):
q = TestHCQ.compute_queue()
qc = TestHCQ.copy_queue()
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) # b = [1, 2]
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size) # b = [1, 2]
q.signal(sig:=TestHCQ.d0._alloc_signal(value=0), value=1)
qc.wait(sig, value=1)
qc.copy(TestHCQ.a.uop.buffer._buf, TestHCQ.b.uop.buffer._buf, 8)
@ -315,7 +317,7 @@ class TestHCQ(unittest.TestCase):
for _ in range(40):
q = TestHCQ.compute_queue()
q.wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1)
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1

View file

@ -5,7 +5,7 @@ from tinygrad.tensor import Tensor
from tinygrad import Device
from tinygrad.nn.state import get_state_dict
from tinygrad.device import Allocator, Compiled
from tinygrad.engine.realize import method_cache
from tinygrad.codegen import to_program_cache
from tinygrad.helpers import Profiling
class FakeProgram:
@ -31,8 +31,8 @@ class TestLLaMASpeed(unittest.TestCase):
for v in get_state_dict(model).values(): v.assign(Tensor.empty(*v.shape, dtype=v.dtype))
print("assigned empty tensors, doing warmup")
def run_llama(st, empty_method_cache=True):
if empty_method_cache: method_cache.clear()
def run_llama(st, empty_cache=True):
if empty_cache: to_program_cache.clear()
tms = [time.perf_counter()]
for i in range(5):
model(Tensor([[1,2,3,4]]), i).realize()

View file

@ -1,7 +1,6 @@
import gc
from tinygrad import Tensor, UOp, Device, nn
from tinygrad.schedule import schedule_cache
from tinygrad.engine.realize import method_cache
from tinygrad.codegen import to_program, to_program_cache
from tinygrad.schedule.indexing import apply_movement_op, _apply_reshape
from tinygrad.uop.divandmod import fold_divmod_general
@ -71,7 +70,6 @@ if __name__ == "__main__":
# these caches will keep uops alive
schedule_cache.clear()
method_cache.clear()
to_program_cache.clear()
apply_movement_op.cache_clear()
_apply_reshape.cache_clear()

View file

@ -53,10 +53,11 @@ class _MXCSRContext:
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.device import Buffer, BufferSpec
from tinygrad.device import Buffer, BufferSpec, Device
from tinygrad.runtime.autogen import hsa
from tinygrad.helpers import Context, DEBUG, PROFILE, colored
from tinygrad.engine.realize import get_runner
from tinygrad.engine.realize import get_runtime
from tinygrad.codegen import to_program
from tinygrad.renderer.amd import decode_inst
from tinygrad.runtime.autogen.amd.rdna3.str_pcode import PCODE as PCODE_RDNA3
@ -2045,18 +2046,18 @@ _INST_HANDLERS: dict[type, Callable[..., UOp]] = {
# PROGRAM DECODE AND COMPILATION
# ═══════════════════════════════════════════════════════════════════════════════
_canonical_runner_cache: list[tuple[type, int, int, int, object]] = [] # [(inst_type, base, mask, size, runner), ...]
_canonical_runner_cache: list[tuple[type, int, int, int, tuple[UOp, object]]] = [] # [(inst_type, base, mask, size, (prg, runtime)), ...]
@functools.cache
def _get_runner(inst_bytes: bytes, arch: str = "rdna3"):
"""Build and compile instruction to CompiledRunner. Cached by instruction bytes, with canonical dedup."""
"""Build and compile instruction to (prg, runtime). Cached by instruction bytes, with canonical dedup."""
inst = decode_inst(inst_bytes, arch)
inst_size = inst.size()
inst_int = int.from_bytes(inst_bytes[:inst_size], 'little')
# Check if instruction matches any cached canonical pattern (must also match instruction type to avoid variant conflicts)
for inst_type, base, mask, size, runner in _canonical_runner_cache:
if type(inst) is inst_type and inst_size == size and (inst_int & mask) == base: return runner
for inst_type, base, mask, size, entry in _canonical_runner_cache:
if type(inst) is inst_type and inst_size == size and (inst_int & mask) == base: return entry
# Look up handler by type, falling back to base classes for _LIT variants
handler = _INST_HANDLERS.get(type(inst))
@ -2075,9 +2076,10 @@ def _get_runner(inst_bytes: bytes, arch: str = "rdna3"):
# NOTE: renderer output is not reproducible because of _MXCSRContext. PROFILE=0 prevents emulator instruction runners from polluting profiling.
with Context(NOOPT=1, CHECK_OOB=0, TUPLE_ORDER=0, EMULATED_DTYPES="", CAPTURE_PROCESS_REPLAY=0, PROFILE=0):
runner = get_runner('CPU', sink)
_canonical_runner_cache.append((type(inst), base, mask, size, runner))
return runner
prg = to_program(sink, Device['CPU'].renderer)
runtime = get_runtime('CPU', prg)
_canonical_runner_cache.append((type(inst), base, mask, size, (prg, runtime)))
return prg, runtime
_BARRIER_OPS = {ir3.SOPPOp.S_BARRIER, irc.SOPPOp.S_BARRIER}
if hasattr(ir4.SOPPOp, 'S_BARRIER_WAIT'): _BARRIER_OPS.add(ir4.SOPPOp.S_BARRIER_WAIT)
@ -2208,10 +2210,10 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int,
def _ensure_compiled(pc: int) -> tuple[Callable, list[int], bool, Inst]:
if pc not in program:
prev_len = len(_canonical_runner_cache)
runner, inst = _decode_at(pc, arch)
(prg, runtime), inst = _decode_at(pc, arch)
is_barrier = (isinstance(inst, (ir3.SOPP, ir4.SOPP, irc.SOPP)) and inst.op in _BARRIER_OPS) or \
(isinstance(inst, (ir4.SOP1,)) and inst.op in _BARRIER_SOP1_OPS)
program[pc] = (runner._prg.fxn, runner.p.globals, is_barrier, inst)
program[pc] = (runtime.fxn, prg.arg.globals, is_barrier, inst)
if DEBUG >= 3:
msg = f"[emu] PC={pc - lib}: {inst!r}"
print(colored(msg, 'green') if len(_canonical_runner_cache) > prev_len else msg)

View file

@ -321,7 +321,6 @@ class TestVizGC(unittest.TestCase):
# VIZ integrates with other parts of tinygrad
from tinygrad import Tensor, Device
from tinygrad.engine.realize import get_runner
class TestVizIntegration(unittest.TestCase):
# codegen supports rendering of code blocks
@ -725,8 +724,8 @@ class TestCfg(unittest.TestCase):
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="NULL"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
with Context(DEV=f"NULL::{self.arch}"):
out = Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0]
runner = get_runner(out.device, out.schedule_linear().src[-1].src[0])
return amdgpu_cfg(runner.prg.src[4].arg, self.arch)
prg = to_program(out.schedule_linear().src[-1].src[0], Device[out.device].renderer)
return amdgpu_cfg(prg.src[4].arg, self.arch)
def test_simple(self):
k = Kernel(arch=self.arch)

View file

@ -5,7 +5,7 @@ from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv,
from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
from tinygrad.dtype import DType, dtypes
from tinygrad.uop.ops import UOp, PatternMatcher, Variable, sym_infer, Ops, buffers, track_rewrites, graph_rewrite
from tinygrad.engine.realize import capturing, CompiledRunner, Runner, Estimates, compile_linear, run_linear, get_runner, graph_cache, estimate_uop
from tinygrad.engine.realize import capturing, Runner, Estimates, compile_linear, run_linear, graph_cache, estimate_uop, get_runtime
from tinygrad.engine.realize import unwrap_multi, resolve_params
from tinygrad.schedule.memory import memory_plan_rewrite, _collect_bufs
from tinygrad.nn.state import get_parameters
@ -99,13 +99,13 @@ class GraphRunner(Runner):
def __init__(self, linear:UOp, input_uops:tuple[UOp, ...]=()):
self.linear = linear.src[0]
self.calls: list[tuple[int, UOp, list[Buffer], dict[str, int]]] = []
self.progs: list[CompiledRunner|None] = []
self.runtimes: list[Any|None] = []
self.uop_replace: list[list[tuple[int, int]]] = []
for call in self.linear.src:
replace = [(p, b.arg) for p, b in enumerate(b for b in call.src[1:] if b.op is not Ops.BIND) if b.op is Ops.PARAM]
for dev_idx, (bufs, device_vars) in enumerate(unwrap_multi(call, resolve_params(call, input_uops))):
self.calls.append((dev_idx, call.src[0], [b.ensure_allocated() for b in bufs], device_vars))
self.progs.append(get_runner(bufs[0].device, call.src[0]) if call.src[0].op is Ops.PROGRAM else None)
self.runtimes.append(get_runtime(bufs[0].device, call.src[0]) if call.src[0].op is Ops.PROGRAM else None)
self.uop_replace.append(replace)
self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
@ -114,20 +114,20 @@ class GraphRunner(Runner):
def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)
crs = [(j, p, self.calls[j][3]) for j,p in enumerate(self.progs) if isinstance(p, CompiledRunner)]
self.vars = sorted({v.expr for _,p,dv in crs for v in p.p.vars if v.expr not in dv | p.p.runtimevars})
self.symbolic_dims = dedup(tuple(d) for _,p,_ in crs for d in (p.p.local_size, p.p.global_size) if d and is_sym_dim(d))
crs = [(j, self.calls[j][1].arg, self.calls[j][3]) for j in range(len(self.calls)) if self.calls[j][1].op is Ops.PROGRAM]
self.vars = sorted({v.expr for _,p,dv in crs for v in p.vars if v.expr not in dv | p.runtimevars})
self.symbolic_dims = dedup(tuple(d) for _,p,_ in crs for d in (p.local_size, p.global_size) if d and is_sym_dim(d))
def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None
for j,p,dv in crs:
if (replace:=[(i, self.vars.index(v.expr)) for i, v in enumerate(p.p.vars) if v.expr not in dv | p.p.runtimevars]):
if (replace:=[(i, self.vars.index(v.expr)) for i, v in enumerate(p.vars) if v.expr not in dv | p.runtimevars]):
self.var_vals_replace[j] = replace
global_dim_idx, local_dim_idx = find_symbolic_dim(p.p.global_size), find_symbolic_dim(p.p.local_size)
global_dim_idx, local_dim_idx = find_symbolic_dim(p.global_size), find_symbolic_dim(p.local_size)
if global_dim_idx is not None or local_dim_idx is not None:
self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
assert p.p.local_size is not None
self.launch_dims_base[j] = (tuple(p.p.global_size), tuple(p.p.local_size))
assert p.local_size is not None
self.launch_dims_base[j] = (tuple(p.global_size), tuple(p.local_size))
estimates = sum((estimate_uop(call) for call in self.linear.src), Estimates())

View file

@ -1,9 +1,8 @@
from typing import cast, Iterator
from typing import cast, Iterator, Any
import time, random, itertools, math, contextlib, weakref
from dataclasses import dataclass, replace, field
from tinygrad.helpers import colored, DEBUG, GlobalCounters, ansilen, NOOPT, all_int, Metadata, TRACEMETA, TracingKey
from tinygrad.helpers import BEAM, DEVECTORIZE, size_to_str, time_to_str, VALIDATE_WITH_CPU, cpu_profile, PROFILE, ProfilePointEvent, cpu_events
from tinygrad.helpers import prod, EMULATED_DTYPES, flatten
from tinygrad.helpers import colored, DEBUG, GlobalCounters, ansilen, all_int, Metadata, TRACEMETA, TracingKey, prod, flatten
from tinygrad.helpers import BEAM, size_to_str, time_to_str, VALIDATE_WITH_CPU, cpu_profile, PROFILE, ProfilePointEvent, cpu_events
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, buffers, graph_rewrite, ProgramInfo
from tinygrad.device import Device, Buffer, MultiBuffer
@ -44,6 +43,22 @@ def update_stats(display_name:str, device:str, estimates:Estimates, var_vals:dic
("" if et is None else f" tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({flops_str} {mem_str})")+
f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in metadata] if metadata else ''}")
first_run_cache:set[bytes] = set()
@contextlib.contextmanager
def track_stats(ctx:"ExecContext", call:UOp, device:str, display_name:str, bufs:list[Buffer], var_vals:dict[str, int], outputs=(0,), inputs=(1,)):
if PROFILE: cpu_events.append(ProfilePointEvent(device, "exec", len(cpu_events), {"metadata": call.arg.metadata, "var_vals": var_vals,
"bufs": [b.trace_num for b in bufs], "name": display_name, "outputs": outputs, "inputs": inputs}))
timing: list[float|None] = [None]
if DEBUG >= 2: st = time.perf_counter()
yield timing
if not ctx.do_update_stats: return
if DEBUG >= 2 and timing[0] is None:
Device[device].synchronize()
timing[0] = time.perf_counter() - st
update_stats(display_name, device, estimate_uop(call), var_vals, timing[0], len(bufs), jit=ctx.jit, metadata=call.arg.metadata,
first_run=call.src[0].key not in first_run_cache)
first_run_cache.add(call.src[0].key)
# **************** Runners ****************
class Runner:
@ -102,19 +117,15 @@ class CompiledRunner(Runner):
# **************** method cache ****************
method_cache: dict[tuple[str, type, bytes, tuple, bool], CompiledRunner] = {}
def get_runner(device:str, ast:UOp) -> CompiledRunner:
# TODO: this should be all context relevant to rendering
context = (NOOPT.value, DEVECTORIZE.value, EMULATED_DTYPES.value)
ckey = (device, type(Device[device].compiler), ast.key, context, False)
if cret:=method_cache.get(ckey): return cret
bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True)
if bret:=method_cache.get(bkey):
method_cache[ckey] = ret = CompiledRunner(bret.prg, device)
else:
prg = to_program(ast, Device[device].renderer)
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(prg, device)
return ret
runtime_cache: dict[tuple[bytes, str], Any] = {}
def get_runtime(device:str, ast:UOp):
assert ast.op is Ops.PROGRAM and isinstance(ast.arg, ProgramInfo), "get_runtime should only be called with a PROGRAM ast"
if (runtime:=runtime_cache.get(key:=(ast.key, device))) is None:
if DEBUG >= 3 and ast.src[0].arg.applied_opts: print(ast.src[0].arg.applied_opts)
if DEBUG >= 4: print(ast.src[3].arg)
if DEBUG >= 7: Device[device].compiler.disassemble(ast.src[4].arg)
runtime = runtime_cache[key] = Device[device].runtime(ast.arg.function_name, ast.src[4].arg, *ast.arg.aux, runtimevars=ast.arg.runtimevars)
return runtime
# **************** run linear ****************
@ -132,20 +143,6 @@ def _resolve(b:UOp, inputs:tuple[UOp, ...]) -> UOp:
return inputs[b.arg] if b.op is Ops.PARAM else b
def resolve_params(call:UOp, inputs:tuple[UOp, ...]) -> list[UOp]: return [_resolve(b, inputs) for b in call.src[1:] if b.op is not Ops.BIND]
@contextlib.contextmanager
def track_stats(ctx:ExecContext, call:UOp, device:str, display_name:str, bufs:list[Buffer], var_vals:dict[str, int],
outputs=(0,), inputs=(1,), first_run=False):
if PROFILE: cpu_events.append(ProfilePointEvent(device, "exec", len(cpu_events), {"metadata": call.arg.metadata, "var_vals": var_vals,
"bufs": [b.trace_num for b in bufs], "name": display_name, "outputs": outputs, "inputs": inputs}))
timing: list[float|None] = [None]
if DEBUG >= 2: st = time.perf_counter()
yield timing
if not ctx.do_update_stats: return
if DEBUG >= 2 and timing[0] is None:
Device[device].synchronize()
timing[0] = time.perf_counter() - st
update_stats(display_name, device, estimate_uop(call), var_vals, timing[0], len(bufs), jit=ctx.jit, metadata=call.arg.metadata, first_run=first_run)
def unwrap_multi(call:UOp, resolved:list[UOp]) -> Iterator[tuple[list[Buffer], dict[str, int]]]:
bufs = [b.buffer for b in resolved]
if not any(isinstance(b, MultiBuffer) for b in bufs): yield cast(list[Buffer], bufs), {}
@ -178,21 +175,21 @@ def exec_copy(ctx:ExecContext, call, ast):
def exec_kernel(ctx:ExecContext, call, ast):
for bufs, device_vars in unwrap_multi(call, resolve_params(call, ctx.input_uops)):
var_vals = {**ctx.var_vals, **device_vars}
prg = get_runner(bufs[0].device, ast)
prg_bufs = [bufs[i].ensure_allocated() for i in prg.p.globals]
with track_stats(ctx, call, prg.device, prg.display_name, prg_bufs, var_vals,
outputs=tuple(prg.p.outs), inputs=tuple(prg.p.ins), first_run=prg.first_run) as timing:
timing[0] = prg(prg_bufs, var_vals, wait=DEBUG >= 2)
prg.first_run = False
prg_bufs = [bufs[i].ensure_allocated() for i in ast.arg.globals]
rt = get_runtime(device:=bufs[0].device, ast)
global_size, local_size = ast.arg.launch_dims(var_vals)
with track_stats(ctx, call, device, ast.arg.name, prg_bufs, var_vals, outputs=ast.arg.outs, inputs=ast.arg.ins) as tm:
tm[0] = rt(*[b._buf for b in prg_bufs], global_size=global_size, local_size=local_size, vals=ast.arg.vals(var_vals), wait=DEBUG>=2)
def exec_validate(ctx:ExecContext, call, ast):
import numpy as np
for bufs, device_vars in unwrap_multi(call, resolve_params(call, ctx.input_uops)):
cpu_bufs, dev_bufs = bufs[:len(bufs)//2], bufs[len(bufs)//2:]
cpu_prg = get_runner("CPU", ast.src[0])
cpu_prg([cpu_bufs[i].ensure_allocated() for i in cpu_prg.p.globals], {**ctx.var_vals, **device_vars}, wait=False)
for i in cpu_prg.p.outs: np.testing.assert_allclose(dev_bufs[i].ensure_allocated().numpy(), cpu_bufs[i].numpy(), rtol=1e-3, atol=1e-3)
bufs, dev_bufs = bufs[:len(bufs)//2], bufs[len(bufs)//2:]
var_vals = {**ctx.var_vals, **device_vars}
cpu_rt = get_runtime("CPU", prg:=to_program(ast.src[0], Device["CPU"].renderer))
global_size, local_size = prg.arg.launch_dims(var_vals)
cpu_rt(*[bufs[i].ensure_allocated()._buf for i in prg.arg.globals], global_size=global_size, local_size=local_size, vals=prg.arg.vals(var_vals))
for i in prg.arg.outs: np.testing.assert_allclose(dev_bufs[i].ensure_allocated().numpy(), bufs[i].numpy(), rtol=1e-3, atol=1e-3)
def exec_encdec(ctx:ExecContext, call, ast):
bufs = [cast(Buffer, b.buffer).ensure_allocated() for b in resolve_params(call, ctx.input_uops)]

View file

@ -14,14 +14,14 @@ class CUDAGraph(MultiGraphRunner):
self.nodes: list[tuple[Any, ...]] = [] # list of tuple(graph node, node params, c_args/context, is memcpy)
self.graph = init_c_var(cuda.CUgraph, lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
for (dev_idx, ast, bufs, device_vars), prg in zip(self.calls, self.progs):
for (dev_idx, ast, bufs, device_vars), runtime in zip(self.calls, self.runtimes):
if ast.op is Ops.PROGRAM:
assert prg is not None
global_size, local_size = prg.p.launch_dims({v: 0 for v in self.vars})
assert runtime is not None
global_size, local_size = ast.arg.launch_dims({v: 0 for v in self.vars})
c_deps, new_node = self.new_node([b.base for b in bufs], prg.p.outs)
c_args, vargs = encode_args([b._buf for b in bufs], [device_vars.get(x.expr, 0) for x in prg.p.vars])
kern_params = cuda.CUDA_KERNEL_NODE_PARAMS_v1(prg._prg.prg, *global_size, *local_size, 0,
c_deps, new_node = self.new_node([b.base for b in bufs], ast.arg.outs)
c_args, vargs = encode_args([b._buf for b in bufs], [device_vars.get(x.expr, 0) for x in ast.arg.vars])
kern_params = cuda.CUDA_KERNEL_NODE_PARAMS_v1(runtime.prg, *global_size, *local_size, 0,
ctypes.cast(0, ctypes.POINTER(ctypes.c_void_p)), vargs)
check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(c_deps or []), ctypes.byref(kern_params)))

View file

@ -27,19 +27,19 @@ class HCQGraph(MultiGraphRunner):
# Allocate kernel args.
kernargs_size: dict[Compiled, int] = collections.defaultdict(int)
for prg in self.progs:
if prg is None: continue
kernargs_size[prg.dev] += round_up(prg._prg.kernargs_alloc_size, 16)
for runtime in self.runtimes:
if runtime is None: continue
kernargs_size[runtime.dev] += round_up(runtime.kernargs_alloc_size, 16)
self.kernargs_bufs: dict[Compiled, HCQBuffer] = {d:d.allocator._alloc(max(sz, 1), BufferSpec(cpu_access=True)) for d,sz in kernargs_size.items()}
# Fill initial arguments.
self.ji_args: dict[int, HCQArgsState] = {}
kargs_alloc: dict[Compiled, BumpAllocator] = {dev:BumpAllocator(buf.size) for dev,buf in self.kernargs_bufs.items()}
for j, prg in enumerate(self.progs):
if prg is None: continue
argsbuf = self.kernargs_bufs[prg.dev].offset(kargs_alloc[prg.dev].alloc(prg._prg.kernargs_alloc_size, 16))
self.ji_args[j] = prg._prg.fill_kernargs(self.hcq_bufs[j], prg.p.vars, argsbuf)
for j, runtime in enumerate(self.runtimes):
if runtime is None: continue
argsbuf = self.kernargs_bufs[runtime.dev].offset(kargs_alloc[runtime.dev].alloc(runtime.kernargs_alloc_size, 16))
self.ji_args[j] = runtime.fill_kernargs(self.hcq_bufs[j], self.calls[j][1].arg.vars, argsbuf)
# Schedule Dependencies.
# There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
@ -83,13 +83,13 @@ class HCQGraph(MultiGraphRunner):
self.input_replace_map: dict[HCQCompiled, set[tuple[int, int]]] = collections.defaultdict(set)
self.device_vars: dict[HCQCompiled, dict[str, int]] = {}
for j, ((_, ast, bufs, device_vars), prg) in enumerate(zip(self.calls, self.progs)):
for j, ((_, ast, bufs, device_vars), runtime) in enumerate(zip(self.calls, self.runtimes)):
is_xfer = ast.op is Ops.COPY and hasattr(alc:=Device[bufs[0].device].allocator, '_transfer') and alc.supports_transfer \
and bufs[0].device.split(":")[0] == bufs[1].device.split(":")[0]
ji_devs = [cast(HCQCompiled, Device[b.device]) for b in bufs] if is_xfer else []
is_rdma = len(ji_devs) > 0 and not any(d._is_cpu() for d in ji_devs) and len(set(d.peer_group for d in ji_devs)) > 1
if prg is not None: enqueue_dev: HCQCompiled = prg.dev
if runtime is not None: enqueue_dev: HCQCompiled = runtime.dev
else:
# For copy ops prioritize enqeueuing on the src device, so reverse the buffers.
for b in bufs[::-1]:
@ -97,9 +97,9 @@ class HCQGraph(MultiGraphRunner):
# set any fixedvars on the device
self.device_vars[enqueue_dev] = merge_dicts([self.device_vars.get(enqueue_dev, {}), device_vars])
if prg is not None: self.device_vars[enqueue_dev] = merge_dicts([self.device_vars[enqueue_dev], prg.p.runtimevars])
if runtime is not None: self.device_vars[enqueue_dev] = merge_dicts([self.device_vars[enqueue_dev], ast.arg.runtimevars])
if prg is not None:
if runtime is not None:
enqueue_queue = self.comp_queues[enqueue_dev]
elif is_rdma:
enqueue_queue = self.comp_queues[enqueue_dev]
@ -125,10 +125,10 @@ class HCQGraph(MultiGraphRunner):
self.rdma_deps[j] = (peer_queue, peer_sync_signals + peer_opt_deps, peer_out_signal, j + 1)
self.last_j[peer_queue] = j
else:
sync_signals, opt_deps, rdeps = self._resolve_deps(bufs, prg.p.outs if prg is not None else [0], enqueue_queue,
sync_signals, opt_deps, rdeps = self._resolve_deps(bufs, ast.arg.outs if runtime is not None else [0], enqueue_queue,
enqueue_dev, out_signal, j, is_copy=is_xfer)
self.ji_schedule[j] = (enqueue_dev, enqueue_queue, sync_signals, opt_deps[::-1], out_signal, None if prg is not None else (j + 1))
self.ji_schedule[j] = (enqueue_dev, enqueue_queue, sync_signals, opt_deps[::-1], out_signal, None if runtime is not None else (j + 1))
# Collect profile information if profiling is enabled.
if PROFILE:
@ -136,9 +136,9 @@ class HCQGraph(MultiGraphRunner):
sig_st = prev_ji * 2 + 1 if len(opt_deps) == 0 and (prev_ji:=self.last_j[enqueue_queue]) is not None else j * 2
# Description based on the command.
prof_ji_desc = prg._prg.name if prg is not None else TracingKey(f"{bufs[1].device} -> {bufs[0].device}", ret=bufs[0].nbytes) # type: ignore
prof_ji_desc = runtime.name if runtime is not None else TracingKey(f"{bufs[1].device} -> {bufs[0].device}", ret=bufs[0].nbytes) # type: ignore
prof_name = enqueue_dev.device if prg is not None else f"{enqueue_dev.device}:SDMA:{queue_idx}"
prof_name = enqueue_dev.device if runtime is not None else f"{enqueue_dev.device}:SDMA:{queue_idx}"
self.prof_graph_entries.append(ProfileGraphEntry(prof_name, prof_ji_desc, sig_st, j * 2 + 1))
self.prof_graph_deps.append([d - 1 for _, d in rdeps])
@ -159,7 +159,7 @@ class HCQGraph(MultiGraphRunner):
self.comp_queues[dev].memory_barrier().wait(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev]) \
.wait(self.kick_signals[dev.peer_group], self.kickoff_var).signal(self.signals[dev], self.kickoff_var)
for j, ((dev_idx, ast, bufs, _), prg) in enumerate(zip(self.calls, self.progs)):
for j, ((dev_idx, ast, bufs, _), runtime) in enumerate(zip(self.calls, self.runtimes)):
enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]
# Lazy allocate signals
@ -171,8 +171,8 @@ class HCQGraph(MultiGraphRunner):
if PROFILE and j * 2 in self.prof_signal_is_used: enqueue_queue.timestamp(self.prof_signals[j * 2])
# Encode main commands based on ji type.
if prg is not None:
enqueue_queue.exec(prg._prg, self.ji_args[j], tuple(prg.p.global_size or (1,1,1)), tuple(prg.p.local_size or (1,1,1))) # type: ignore[arg-type]
if runtime is not None:
enqueue_queue.exec(runtime, self.ji_args[j], ast.arg.global_size or (1,1,1), ast.arg.local_size or (1,1,1)) # type: ignore[arg-type]
elif j in self.rdma_deps:
dest_queue, dest_deps, dest_out_signal, dest_out_val = self.rdma_deps[j]
for sig, val in dest_deps: dest_queue.wait(sig, val)

View file

@ -27,17 +27,17 @@ class MetalGraph(GraphRunner):
if len(self.vars): self.int_buf = self.dev.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
all_pipelines, all_resources = [], [self.int_buf.buf] if len(self.vars) else []
for j, ((_, _, bufs, _), prg, replace) in enumerate(zip(self.calls, self.progs, self.uop_replace)):
assert prg is not None
for j, ((_, ast, bufs, _), runtime, replace) in enumerate(zip(self.calls, self.runtimes, self.uop_replace)):
assert runtime is not None
icb_command = self.icb.indirectComputeCommandAtIndex(j).retained()
icb_command.setComputePipelineState(prg._prg.pipeline_state)
all_pipelines.append(prg._prg.pipeline_state)
icb_command.setComputePipelineState(runtime.pipeline_state)
all_pipelines.append(runtime.pipeline_state)
for i, b in enumerate(bufs):
if not any(pos == i for pos, _ in replace):
icb_command.setKernelBuffer_offset_atIndex(b._buf.buf, b._buf.offset, i)
all_resources.append(b._buf.buf)
for i, v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex(self.int_buf.buf, self.vars.index(v.expr)*4, len(bufs)+i)
global_size, local_size = prg.p.launch_dims({v: 0 for v in self.vars})
for i, v in enumerate(ast.arg.vars): icb_command.setKernelBuffer_offset_atIndex(self.int_buf.buf, self.vars.index(v.expr)*4, len(bufs)+i)
global_size, local_size = ast.arg.launch_dims({v: 0 for v in self.vars})
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup(metal.MTLSize(*global_size), metal.MTLSize(*local_size))
icb_command.setBarrier()
@ -97,7 +97,7 @@ class MetalGraph(GraphRunner):
def collect_timestamps(self):
# create a graph event and evenly space each program
st, en = decimal.Decimal(self.command_buffer.GPUStartTime()) * 1000000, decimal.Decimal(self.command_buffer.GPUEndTime()) * 1000000
ents = [ProfileGraphEntry(self.device, prg._prg.name, i, i+1) for i, prg in enumerate(self.progs) if prg is not None]
ents = [ProfileGraphEntry(self.device, rt.name, i, i+1) for i, rt in enumerate(self.runtimes) if rt is not None]
self.dev.profile_events += [ProfileGraphEvent(ents, [], [st + (en-st)/len(ents)*i for i in range(len(ents)+1)])]
def __del__(self):

View file

@ -983,11 +983,13 @@ class ProgramInfo:
@property
def runtimevars(self) -> dict[str, int]: return {v.expr: i for i, v in enumerate(self.vars) if v.expr == 'core_id'}
def launch_dims(self, var_vals:dict[str, int]):
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] # type: ignore[arg-type]
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
def launch_dims(self, var_vals:dict[str, int]) -> tuple[tuple[int, ...], tuple[int, ...]|None]:
global_size = tuple([sym_infer(sz, var_vals) for sz in self.global_size]) # type: ignore[arg-type]
local_size = tuple([sym_infer(sz, var_vals) for sz in self.local_size]) if self.local_size is not None else None
return global_size, local_size
def vals(self, var_vals:dict[str, int]): return tuple(var_vals[k.expr] if k.expr not in self.runtimevars else None for k in self.vars)
@staticmethod
def from_sink(sink:UOp, aux:tuple=()) -> ProgramInfo:
_vars: list[UOp] = []