mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
get_runner -> get_runtime (#15967)
* get_runner -> get_runtime * do not use get_runner * fix * remove get_runner * remove * fix * x
This commit is contained in:
parent
fb188c3c23
commit
7787f76dcc
14 changed files with 161 additions and 158 deletions
|
|
@ -8,17 +8,14 @@ class TestMockGPUInvalidInstruction(unittest.TestCase):
|
|||
test_code = '''
|
||||
import struct
|
||||
from tinygrad import Device, Tensor
|
||||
from tinygrad.engine.realize import get_runner
|
||||
from tinygrad.engine.realize import compile_linear
|
||||
from tinygrad.runtime.ops_amd import AMDProgram
|
||||
|
||||
dev = Device["AMD"]
|
||||
a = Tensor([1.0]).realize()
|
||||
b = a + 1
|
||||
si = b.schedule_linear().src[-1]
|
||||
runner = get_runner(dev.device, si.src[0])
|
||||
|
||||
prg = runner._prg
|
||||
lib = bytearray(prg.lib)
|
||||
linear = compile_linear(b.schedule_linear())
|
||||
lib = bytearray(linear.src[-1].src[0].src[4].arg)
|
||||
|
||||
# Find s_endpgm (0xBFB00000) and replace with V_MOVRELD_B32 (op=66) which has no pcode
|
||||
# VOP1 encoding: bits[31:25]=0x7E, op=bits[16:9], so op=66 -> 66<<9 = 0x8400
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ from tinygrad import Device, Tensor, dtypes, TinyJit
|
|||
from tinygrad.helpers import CI, DEV, Context, ProfileRangeEvent, cpu_profile, cpu_events, ProfilePointEvent, dedup
|
||||
from tinygrad.device import Buffer, BufferSpec, Compiled, ProfileDeviceEvent, ProfileGraphEvent
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled
|
||||
from tinygrad.engine.realize import get_runner
|
||||
from tinygrad.engine.realize import get_runtime
|
||||
from tinygrad.codegen import to_program
|
||||
|
||||
MOCKGPU = DEV.interface.startswith("MOCK")
|
||||
def _dev_base(d):
|
||||
|
|
@ -46,13 +47,15 @@ class TestProfiler(unittest.TestCase):
|
|||
TestProfiler.b = self.a + 1
|
||||
si = self.b.schedule_linear().src[-1]
|
||||
|
||||
TestProfiler.runner = get_runner(TestProfiler.d0.device, si.src[0])
|
||||
TestProfiler.prg = to_program(si.src[0], TestProfiler.d0.renderer)
|
||||
TestProfiler.runtime = get_runtime(TestProfiler.d0.device, TestProfiler.prg)
|
||||
TestProfiler.b.uop.buffer.allocate()
|
||||
|
||||
def test_profile_kernel_run(self):
|
||||
runner_name = TestProfiler.runner._prg.name
|
||||
runner_name = TestProfiler.runtime.name
|
||||
with helper_collect_profile(TestProfiler.d0) as profile:
|
||||
TestProfiler.runner([TestProfiler.b.uop.buffer, TestProfiler.a.uop.buffer], var_vals={})
|
||||
gs, ls = TestProfiler.prg.arg.launch_dims({})
|
||||
TestProfiler.runtime(TestProfiler.b.uop.buffer._buf, TestProfiler.a.uop.buffer._buf, global_size=gs, local_size=ls)
|
||||
|
||||
profile, _ = helper_profile_filter_device(profile, TestProfiler.d0.device)
|
||||
kernel_runs = [x for x in profile if isinstance(x, ProfileRangeEvent)]
|
||||
|
|
@ -70,12 +73,13 @@ class TestProfiler(unittest.TestCase):
|
|||
assert len(kernel_runs) == 1, "one kernel run is expected"
|
||||
|
||||
def test_profile_multiops(self):
|
||||
runner_name = TestProfiler.runner._prg.name
|
||||
runner_name = TestProfiler.runtime.name
|
||||
buf1 = Buffer(Device.DEFAULT, 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated()
|
||||
|
||||
with helper_collect_profile(TestProfiler.d0) as profile:
|
||||
buf1.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
|
||||
TestProfiler.runner([buf1, TestProfiler.a.uop.buffer], var_vals={})
|
||||
gs, ls = TestProfiler.prg.arg.launch_dims({})
|
||||
TestProfiler.runtime(buf1._buf, TestProfiler.a.uop.buffer._buf, global_size=gs, local_size=ls)
|
||||
buf1.copyout(memoryview(bytearray(buf1.nbytes)))
|
||||
|
||||
evs = [x for x in profile if isinstance(x, ProfileRangeEvent) and x.device.startswith(TestProfiler.d0.device)]
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from tinygrad.device import Buffer, BufferSpec
|
|||
from tinygrad.runtime.support.hcq import HCQCompiled, HCQBuffer
|
||||
from tinygrad.runtime.autogen import libc
|
||||
from tinygrad.runtime.support.system import PCIIfaceBase
|
||||
from tinygrad.engine.realize import get_runner, CompiledRunner
|
||||
from tinygrad.engine.realize import get_runtime
|
||||
from tinygrad.codegen import to_program
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad import Variable
|
||||
|
|
@ -22,11 +22,12 @@ class TestHCQ(unittest.TestCase):
|
|||
TestHCQ.b = self.a + 1
|
||||
si = self.b.schedule_linear().src[-1]
|
||||
|
||||
TestHCQ.runner = get_runner(TestHCQ.d0.device, si.src[0])
|
||||
TestHCQ.prg = to_program(si.src[0], TestHCQ.d0.renderer)
|
||||
TestHCQ.runtime = get_runtime(TestHCQ.d0.device, TestHCQ.prg)
|
||||
TestHCQ.b.uop.buffer.allocate()
|
||||
|
||||
TestHCQ.kernargs_ba_ptr = TestHCQ.runner._prg.fill_kernargs([TestHCQ.b.uop.buffer._buf, TestHCQ.a.uop.buffer._buf])
|
||||
TestHCQ.kernargs_ab_ptr = TestHCQ.runner._prg.fill_kernargs([TestHCQ.a.uop.buffer._buf, TestHCQ.b.uop.buffer._buf])
|
||||
TestHCQ.kernargs_ba_ptr = TestHCQ.runtime.fill_kernargs([TestHCQ.b.uop.buffer._buf, TestHCQ.a.uop.buffer._buf])
|
||||
TestHCQ.kernargs_ab_ptr = TestHCQ.runtime.fill_kernargs([TestHCQ.a.uop.buffer._buf, TestHCQ.b.uop.buffer._buf])
|
||||
|
||||
def setUp(self):
|
||||
TestHCQ.d0.synchronize()
|
||||
|
|
@ -114,7 +115,7 @@ class TestHCQ(unittest.TestCase):
|
|||
|
||||
# Test exec
|
||||
def test_exec_one_kernel(self):
|
||||
TestHCQ.d0.hw_compute_queue_t().exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
|
||||
TestHCQ.d0.hw_compute_queue_t().exec(TestHCQ.runtime, TestHCQ.kernargs_ba_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size) \
|
||||
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
||||
|
||||
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
||||
|
|
@ -128,8 +129,8 @@ class TestHCQ(unittest.TestCase):
|
|||
|
||||
q = TestHCQ.d0.hw_compute_queue_t()
|
||||
q.wait(TestHCQ.d0.timeline_signal, virt_val - 1) \
|
||||
.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
|
||||
.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ab_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
|
||||
.exec(TestHCQ.runtime, TestHCQ.kernargs_ba_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size) \
|
||||
.exec(TestHCQ.runtime, TestHCQ.kernargs_ab_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size) \
|
||||
.signal(TestHCQ.d0.timeline_signal, virt_val)
|
||||
|
||||
for _ in range(100):
|
||||
|
|
@ -141,11 +142,11 @@ class TestHCQ(unittest.TestCase):
|
|||
|
||||
@unittest.skipIf(Device.DEFAULT in {"CPU"}, "No globals/locals on LLVM/CPU")
|
||||
def test_exec_update(self):
|
||||
sint_global = (Variable("sint_global", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.runner.p.global_size[1:])
|
||||
sint_local = (Variable("sint_local", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.runner.p.local_size[1:])
|
||||
sint_global = (Variable("sint_global", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.prg.arg.global_size[1:])
|
||||
sint_local = (Variable("sint_local", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.prg.arg.local_size[1:])
|
||||
|
||||
q = TestHCQ.d0.hw_compute_queue_t()
|
||||
q.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, sint_global, sint_local) \
|
||||
q.exec(TestHCQ.runtime, TestHCQ.kernargs_ba_ptr, sint_global, sint_local) \
|
||||
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||
|
||||
q.submit(TestHCQ.d0, {sint_global[0].expr: 1, sint_local[0].expr: 1})
|
||||
|
|
@ -166,17 +167,17 @@ class TestHCQ(unittest.TestCase):
|
|||
b = a + 1
|
||||
si = b.schedule_linear().src[-1]
|
||||
|
||||
runner = CompiledRunner(to_program(replace_opts(si.src[0], [Opt(op=OptOps.LOCAL, axis=0, arg=3) for _ in range(3)]), TestHCQ.d0.renderer),
|
||||
Device.DEFAULT)
|
||||
prg = to_program(replace_opts(si.src[0], [Opt(op=OptOps.LOCAL, axis=0, arg=3) for _ in range(3)]), TestHCQ.d0.renderer)
|
||||
runtime = get_runtime(Device.DEFAULT, prg)
|
||||
|
||||
zb = Buffer(Device.DEFAULT, 3 * 3 * 3, dtypes.int, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated()
|
||||
zt = Buffer(Device.DEFAULT, 3 * 3 * 3, dtypes.int, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated()
|
||||
ctypes.memset(zb._buf.va_addr, 0, zb.nbytes)
|
||||
kernargs = runner._prg.fill_kernargs([zt._buf, zb._buf])
|
||||
kernargs = runtime.fill_kernargs([zt._buf, zb._buf])
|
||||
|
||||
q = TestHCQ.d0.hw_compute_queue_t()
|
||||
q.memory_barrier() \
|
||||
.exec(runner._prg, kernargs, (1,1,1), virt_local) \
|
||||
.exec(runtime, kernargs, (1,1,1), virt_local) \
|
||||
.signal(TestHCQ.d0.timeline_signal, virt_val)
|
||||
|
||||
for x in range(1, 4):
|
||||
|
|
@ -330,7 +331,7 @@ class TestHCQ(unittest.TestCase):
|
|||
def test_speed_exec_time(self):
|
||||
sig_st, sig_en = TestHCQ.d0.new_signal(), TestHCQ.d0.new_signal()
|
||||
TestHCQ.d0.hw_compute_queue_t().timestamp(sig_st) \
|
||||
.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
|
||||
.exec(TestHCQ.runtime, TestHCQ.kernargs_ba_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size) \
|
||||
.timestamp(sig_en) \
|
||||
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
||||
|
||||
|
|
@ -470,12 +471,13 @@ class TestHCQ(unittest.TestCase):
|
|||
def test_memory_barrier(self):
|
||||
a = Tensor([0, 1], device=Device.DEFAULT, dtype=dtypes.int8).realize()
|
||||
b = a + 1
|
||||
runner = get_runner(TestHCQ.d0.device, b.schedule_linear().src[-1].src[0])
|
||||
prg = to_program(b.schedule_linear().src[-1].src[0], TestHCQ.d0.renderer)
|
||||
runtime = get_runtime(TestHCQ.d0.device, prg)
|
||||
|
||||
buf1 = Buffer(Device.DEFAULT, 2, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
|
||||
buf2 = Buffer(Device.DEFAULT, 2, dtypes.int8, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated()
|
||||
|
||||
kernargs_ptr = runner._prg.fill_kernargs([buf1._buf, buf2._buf])
|
||||
kernargs_ptr = runtime.fill_kernargs([buf1._buf, buf2._buf])
|
||||
|
||||
for i in range(255):
|
||||
ctypes.memset(buf2._buf.va_addr, i, 2)
|
||||
|
|
@ -483,7 +485,7 @@ class TestHCQ(unittest.TestCase):
|
|||
# Need memory_barrier after direct write to vram
|
||||
TestHCQ.d0.hw_compute_queue_t().wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \
|
||||
.memory_barrier() \
|
||||
.exec(runner._prg, kernargs_ptr, runner.p.global_size, runner.p.local_size) \
|
||||
.exec(runtime, kernargs_ptr, prg.arg.global_size, prg.arg.local_size) \
|
||||
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
||||
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
||||
TestHCQ.d0.timeline_value += 1
|
||||
|
|
|
|||
48
test/external/external_test_hcq.py
vendored
48
test/external/external_test_hcq.py
vendored
|
|
@ -2,7 +2,8 @@ import unittest, ctypes, struct, time, array
|
|||
from tinygrad import Device, Tensor, dtypes
|
||||
from tinygrad.helpers import to_mv, CI
|
||||
from tinygrad.device import Buffer, BufferSpec
|
||||
from tinygrad.engine.realize import get_runner
|
||||
from tinygrad.engine.realize import get_runtime
|
||||
from tinygrad.codegen import to_program
|
||||
|
||||
def _time_queue(q, d):
|
||||
st = time.perf_counter()
|
||||
|
|
@ -21,13 +22,14 @@ class TestHCQ(unittest.TestCase):
|
|||
TestHCQ.a = Tensor([0.,1.], device=Device.DEFAULT).realize()
|
||||
TestHCQ.b = self.a + 1
|
||||
linear = self.b.schedule_linear()
|
||||
TestHCQ.runner = get_runner(TestHCQ.d0.device, linear.src[-1].src[0])
|
||||
TestHCQ.prg = to_program(linear.src[-1].src[0], TestHCQ.d0.renderer)
|
||||
TestHCQ.runtime = get_runtime(TestHCQ.d0.device, TestHCQ.prg)
|
||||
TestHCQ.b.uop.buffer.allocate()
|
||||
# wow that's a lot of abstraction layers
|
||||
TestHCQ.addr = struct.pack("QQ", TestHCQ.b.uop.buffer._buf, TestHCQ.a.uop.buffer._buf)
|
||||
TestHCQ.addr2 = struct.pack("QQ", TestHCQ.a.uop.buffer._buf, TestHCQ.b.uop.buffer._buf)
|
||||
TestHCQ.kernargs_off = TestHCQ.runner._prg.kernargs_offset
|
||||
TestHCQ.kernargs_size = TestHCQ.runner._prg.kernargs_alloc_size
|
||||
TestHCQ.kernargs_off = TestHCQ.runtime.kernargs_offset
|
||||
TestHCQ.kernargs_size = TestHCQ.runtime.kernargs_alloc_size
|
||||
ctypes.memmove(TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_off, TestHCQ.addr, len(TestHCQ.addr))
|
||||
ctypes.memmove(TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size+TestHCQ.kernargs_off, TestHCQ.addr2, len(TestHCQ.addr2))
|
||||
|
||||
|
|
@ -38,8 +40,8 @@ class TestHCQ(unittest.TestCase):
|
|||
elif Device.DEFAULT == "NV":
|
||||
from tinygrad.runtime.ops_nv import HWQueue, HWQueue
|
||||
# nv need to copy constbuffer there as well
|
||||
to_mv(TestHCQ.d0.kernargs_ptr, 0x160).cast('I')[:] = array.array('I', TestHCQ.runner._prg.constbuffer_0)
|
||||
to_mv(TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, 0x160).cast('I')[:] = array.array('I', TestHCQ.runner._prg.constbuffer_0)
|
||||
to_mv(TestHCQ.d0.kernargs_ptr, 0x160).cast('I')[:] = array.array('I', TestHCQ.runtime.constbuffer_0)
|
||||
to_mv(TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, 0x160).cast('I')[:] = array.array('I', TestHCQ.runtime.constbuffer_0)
|
||||
TestHCQ.compute_queue = HWQueue
|
||||
TestHCQ.copy_queue = HWQueue
|
||||
|
||||
|
|
@ -53,11 +55,11 @@ class TestHCQ(unittest.TestCase):
|
|||
temp_signal, temp_value = TestHCQ.d0._alloc_signal(value=0), 0
|
||||
q = TestHCQ.compute_queue()
|
||||
for _ in range(1000):
|
||||
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
|
||||
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
|
||||
q.signal(temp_signal, temp_value + 1).wait(temp_signal, temp_value + 1)
|
||||
temp_value += 1
|
||||
|
||||
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
|
||||
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
|
||||
q.signal(temp_signal, temp_value + 1).wait(temp_signal, temp_value + 1)
|
||||
temp_value += 1
|
||||
|
||||
|
|
@ -71,10 +73,10 @@ class TestHCQ(unittest.TestCase):
|
|||
def test_run_1000_times(self):
|
||||
temp_signal = TestHCQ.d0._alloc_signal(value=0)
|
||||
q = TestHCQ.compute_queue()
|
||||
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
|
||||
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
|
||||
q.signal(temp_signal, 2).wait(temp_signal, 2)
|
||||
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.runner.p.global_size,
|
||||
TestHCQ.runner.p.local_size)
|
||||
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.prg.arg.global_size,
|
||||
TestHCQ.prg.arg.local_size)
|
||||
for _ in range(1000):
|
||||
TestHCQ.d0._set_signal(temp_signal, 1)
|
||||
q.submit(TestHCQ.d0)
|
||||
|
|
@ -87,11 +89,11 @@ class TestHCQ(unittest.TestCase):
|
|||
def test_run_to_3(self):
|
||||
temp_signal = TestHCQ.d0._alloc_signal(value=0)
|
||||
q = TestHCQ.compute_queue()
|
||||
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
|
||||
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
|
||||
q.signal(temp_signal, 1).wait(temp_signal, 1)
|
||||
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
|
||||
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
|
||||
q.signal(temp_signal, 2).wait(temp_signal, 2)
|
||||
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
|
||||
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
|
||||
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
||||
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||
TestHCQ.d0.timeline_value += 1
|
||||
|
|
@ -101,7 +103,7 @@ class TestHCQ(unittest.TestCase):
|
|||
def test_update_exec(self):
|
||||
q = TestHCQ.compute_queue()
|
||||
exec_cmd_idx = len(q)
|
||||
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
|
||||
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
|
||||
q.update_exec(exec_cmd_idx, (1,1,1), (1,1,1))
|
||||
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
||||
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||
|
|
@ -115,10 +117,10 @@ class TestHCQ(unittest.TestCase):
|
|||
def test_bind_run(self):
|
||||
temp_signal = TestHCQ.d0._alloc_signal(value=0)
|
||||
q = TestHCQ.compute_queue()
|
||||
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
|
||||
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
|
||||
q.signal(temp_signal, 2).wait(temp_signal, 2)
|
||||
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.runner.p.global_size,
|
||||
TestHCQ.runner.p.local_size)
|
||||
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.prg.arg.global_size,
|
||||
TestHCQ.prg.arg.local_size)
|
||||
q.bind(TestHCQ.d0)
|
||||
for _ in range(1000):
|
||||
TestHCQ.d0._set_signal(temp_signal, 1)
|
||||
|
|
@ -133,7 +135,7 @@ class TestHCQ(unittest.TestCase):
|
|||
def test_update_exec_binded(self):
|
||||
q = TestHCQ.compute_queue()
|
||||
exec_ptr = q.ptr()
|
||||
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
|
||||
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
|
||||
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||
q.bind(TestHCQ.d0)
|
||||
|
||||
|
|
@ -170,7 +172,7 @@ class TestHCQ(unittest.TestCase):
|
|||
|
||||
def test_run_normal(self):
|
||||
q = TestHCQ.compute_queue()
|
||||
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
|
||||
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
|
||||
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
||||
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||
TestHCQ.d0.timeline_value += 1
|
||||
|
|
@ -201,7 +203,7 @@ class TestHCQ(unittest.TestCase):
|
|||
|
||||
def test_run_signal(self):
|
||||
q = TestHCQ.compute_queue()
|
||||
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
|
||||
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
|
||||
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||
q.submit(TestHCQ.d0)
|
||||
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||
|
|
@ -278,7 +280,7 @@ class TestHCQ(unittest.TestCase):
|
|||
def test_interleave_compute_and_copy(self):
|
||||
q = TestHCQ.compute_queue()
|
||||
qc = TestHCQ.copy_queue()
|
||||
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) # b = [1, 2]
|
||||
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size) # b = [1, 2]
|
||||
q.signal(sig:=TestHCQ.d0._alloc_signal(value=0), value=1)
|
||||
qc.wait(sig, value=1)
|
||||
qc.copy(TestHCQ.a.uop.buffer._buf, TestHCQ.b.uop.buffer._buf, 8)
|
||||
|
|
@ -315,7 +317,7 @@ class TestHCQ(unittest.TestCase):
|
|||
for _ in range(40):
|
||||
q = TestHCQ.compute_queue()
|
||||
q.wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1)
|
||||
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
|
||||
q.exec(TestHCQ.runtime, TestHCQ.d0.kernargs_ptr, TestHCQ.prg.arg.global_size, TestHCQ.prg.arg.local_size)
|
||||
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
||||
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||
TestHCQ.d0.timeline_value += 1
|
||||
|
|
|
|||
6
test/external/external_test_speed_llama.py
vendored
6
test/external/external_test_speed_llama.py
vendored
|
|
@ -5,7 +5,7 @@ from tinygrad.tensor import Tensor
|
|||
from tinygrad import Device
|
||||
from tinygrad.nn.state import get_state_dict
|
||||
from tinygrad.device import Allocator, Compiled
|
||||
from tinygrad.engine.realize import method_cache
|
||||
from tinygrad.codegen import to_program_cache
|
||||
from tinygrad.helpers import Profiling
|
||||
|
||||
class FakeProgram:
|
||||
|
|
@ -31,8 +31,8 @@ class TestLLaMASpeed(unittest.TestCase):
|
|||
for v in get_state_dict(model).values(): v.assign(Tensor.empty(*v.shape, dtype=v.dtype))
|
||||
print("assigned empty tensors, doing warmup")
|
||||
|
||||
def run_llama(st, empty_method_cache=True):
|
||||
if empty_method_cache: method_cache.clear()
|
||||
def run_llama(st, empty_cache=True):
|
||||
if empty_cache: to_program_cache.clear()
|
||||
tms = [time.perf_counter()]
|
||||
for i in range(5):
|
||||
model(Tensor([[1,2,3,4]]), i).realize()
|
||||
|
|
|
|||
2
test/external/external_uop_gc.py
vendored
2
test/external/external_uop_gc.py
vendored
|
|
@ -1,7 +1,6 @@
|
|||
import gc
|
||||
from tinygrad import Tensor, UOp, Device, nn
|
||||
from tinygrad.schedule import schedule_cache
|
||||
from tinygrad.engine.realize import method_cache
|
||||
from tinygrad.codegen import to_program, to_program_cache
|
||||
from tinygrad.schedule.indexing import apply_movement_op, _apply_reshape
|
||||
from tinygrad.uop.divandmod import fold_divmod_general
|
||||
|
|
@ -71,7 +70,6 @@ if __name__ == "__main__":
|
|||
|
||||
# these caches will keep uops alive
|
||||
schedule_cache.clear()
|
||||
method_cache.clear()
|
||||
to_program_cache.clear()
|
||||
apply_movement_op.cache_clear()
|
||||
_apply_reshape.cache_clear()
|
||||
|
|
|
|||
|
|
@ -53,10 +53,11 @@ class _MXCSRContext:
|
|||
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
|
||||
from tinygrad.dtype import dtypes, AddrSpace
|
||||
from tinygrad.device import Buffer, BufferSpec
|
||||
from tinygrad.device import Buffer, BufferSpec, Device
|
||||
from tinygrad.runtime.autogen import hsa
|
||||
from tinygrad.helpers import Context, DEBUG, PROFILE, colored
|
||||
from tinygrad.engine.realize import get_runner
|
||||
from tinygrad.engine.realize import get_runtime
|
||||
from tinygrad.codegen import to_program
|
||||
|
||||
from tinygrad.renderer.amd import decode_inst
|
||||
from tinygrad.runtime.autogen.amd.rdna3.str_pcode import PCODE as PCODE_RDNA3
|
||||
|
|
@ -2045,18 +2046,18 @@ _INST_HANDLERS: dict[type, Callable[..., UOp]] = {
|
|||
# PROGRAM DECODE AND COMPILATION
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
_canonical_runner_cache: list[tuple[type, int, int, int, object]] = [] # [(inst_type, base, mask, size, runner), ...]
|
||||
_canonical_runner_cache: list[tuple[type, int, int, int, tuple[UOp, object]]] = [] # [(inst_type, base, mask, size, (prg, runtime)), ...]
|
||||
|
||||
@functools.cache
|
||||
def _get_runner(inst_bytes: bytes, arch: str = "rdna3"):
|
||||
"""Build and compile instruction to CompiledRunner. Cached by instruction bytes, with canonical dedup."""
|
||||
"""Build and compile instruction to (prg, runtime). Cached by instruction bytes, with canonical dedup."""
|
||||
inst = decode_inst(inst_bytes, arch)
|
||||
inst_size = inst.size()
|
||||
inst_int = int.from_bytes(inst_bytes[:inst_size], 'little')
|
||||
|
||||
# Check if instruction matches any cached canonical pattern (must also match instruction type to avoid variant conflicts)
|
||||
for inst_type, base, mask, size, runner in _canonical_runner_cache:
|
||||
if type(inst) is inst_type and inst_size == size and (inst_int & mask) == base: return runner
|
||||
for inst_type, base, mask, size, entry in _canonical_runner_cache:
|
||||
if type(inst) is inst_type and inst_size == size and (inst_int & mask) == base: return entry
|
||||
|
||||
# Look up handler by type, falling back to base classes for _LIT variants
|
||||
handler = _INST_HANDLERS.get(type(inst))
|
||||
|
|
@ -2075,9 +2076,10 @@ def _get_runner(inst_bytes: bytes, arch: str = "rdna3"):
|
|||
|
||||
# NOTE: renderer output is not reproducible because of _MXCSRContext. PROFILE=0 prevents emulator instruction runners from polluting profiling.
|
||||
with Context(NOOPT=1, CHECK_OOB=0, TUPLE_ORDER=0, EMULATED_DTYPES="", CAPTURE_PROCESS_REPLAY=0, PROFILE=0):
|
||||
runner = get_runner('CPU', sink)
|
||||
_canonical_runner_cache.append((type(inst), base, mask, size, runner))
|
||||
return runner
|
||||
prg = to_program(sink, Device['CPU'].renderer)
|
||||
runtime = get_runtime('CPU', prg)
|
||||
_canonical_runner_cache.append((type(inst), base, mask, size, (prg, runtime)))
|
||||
return prg, runtime
|
||||
|
||||
_BARRIER_OPS = {ir3.SOPPOp.S_BARRIER, irc.SOPPOp.S_BARRIER}
|
||||
if hasattr(ir4.SOPPOp, 'S_BARRIER_WAIT'): _BARRIER_OPS.add(ir4.SOPPOp.S_BARRIER_WAIT)
|
||||
|
|
@ -2208,10 +2210,10 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int,
|
|||
def _ensure_compiled(pc: int) -> tuple[Callable, list[int], bool, Inst]:
|
||||
if pc not in program:
|
||||
prev_len = len(_canonical_runner_cache)
|
||||
runner, inst = _decode_at(pc, arch)
|
||||
(prg, runtime), inst = _decode_at(pc, arch)
|
||||
is_barrier = (isinstance(inst, (ir3.SOPP, ir4.SOPP, irc.SOPP)) and inst.op in _BARRIER_OPS) or \
|
||||
(isinstance(inst, (ir4.SOP1,)) and inst.op in _BARRIER_SOP1_OPS)
|
||||
program[pc] = (runner._prg.fxn, runner.p.globals, is_barrier, inst)
|
||||
program[pc] = (runtime.fxn, prg.arg.globals, is_barrier, inst)
|
||||
if DEBUG >= 3:
|
||||
msg = f"[emu] PC={pc - lib}: {inst!r}"
|
||||
print(colored(msg, 'green') if len(_canonical_runner_cache) > prev_len else msg)
|
||||
|
|
|
|||
|
|
@ -321,7 +321,6 @@ class TestVizGC(unittest.TestCase):
|
|||
# VIZ integrates with other parts of tinygrad
|
||||
|
||||
from tinygrad import Tensor, Device
|
||||
from tinygrad.engine.realize import get_runner
|
||||
|
||||
class TestVizIntegration(unittest.TestCase):
|
||||
# codegen supports rendering of code blocks
|
||||
|
|
@ -725,8 +724,8 @@ class TestCfg(unittest.TestCase):
|
|||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="NULL"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
|
||||
with Context(DEV=f"NULL::{self.arch}"):
|
||||
out = Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0]
|
||||
runner = get_runner(out.device, out.schedule_linear().src[-1].src[0])
|
||||
return amdgpu_cfg(runner.prg.src[4].arg, self.arch)
|
||||
prg = to_program(out.schedule_linear().src[-1].src[0], Device[out.device].renderer)
|
||||
return amdgpu_cfg(prg.src[4].arg, self.arch)
|
||||
|
||||
def test_simple(self):
|
||||
k = Kernel(arch=self.arch)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv,
|
|||
from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, Variable, sym_infer, Ops, buffers, track_rewrites, graph_rewrite
|
||||
from tinygrad.engine.realize import capturing, CompiledRunner, Runner, Estimates, compile_linear, run_linear, get_runner, graph_cache, estimate_uop
|
||||
from tinygrad.engine.realize import capturing, Runner, Estimates, compile_linear, run_linear, graph_cache, estimate_uop, get_runtime
|
||||
from tinygrad.engine.realize import unwrap_multi, resolve_params
|
||||
from tinygrad.schedule.memory import memory_plan_rewrite, _collect_bufs
|
||||
from tinygrad.nn.state import get_parameters
|
||||
|
|
@ -99,13 +99,13 @@ class GraphRunner(Runner):
|
|||
def __init__(self, linear:UOp, input_uops:tuple[UOp, ...]=()):
|
||||
self.linear = linear.src[0]
|
||||
self.calls: list[tuple[int, UOp, list[Buffer], dict[str, int]]] = []
|
||||
self.progs: list[CompiledRunner|None] = []
|
||||
self.runtimes: list[Any|None] = []
|
||||
self.uop_replace: list[list[tuple[int, int]]] = []
|
||||
for call in self.linear.src:
|
||||
replace = [(p, b.arg) for p, b in enumerate(b for b in call.src[1:] if b.op is not Ops.BIND) if b.op is Ops.PARAM]
|
||||
for dev_idx, (bufs, device_vars) in enumerate(unwrap_multi(call, resolve_params(call, input_uops))):
|
||||
self.calls.append((dev_idx, call.src[0], [b.ensure_allocated() for b in bufs], device_vars))
|
||||
self.progs.append(get_runner(bufs[0].device, call.src[0]) if call.src[0].op is Ops.PROGRAM else None)
|
||||
self.runtimes.append(get_runtime(bufs[0].device, call.src[0]) if call.src[0].op is Ops.PROGRAM else None)
|
||||
self.uop_replace.append(replace)
|
||||
|
||||
self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
|
||||
|
|
@ -114,20 +114,20 @@ class GraphRunner(Runner):
|
|||
|
||||
def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)
|
||||
|
||||
crs = [(j, p, self.calls[j][3]) for j,p in enumerate(self.progs) if isinstance(p, CompiledRunner)]
|
||||
self.vars = sorted({v.expr for _,p,dv in crs for v in p.p.vars if v.expr not in dv | p.p.runtimevars})
|
||||
self.symbolic_dims = dedup(tuple(d) for _,p,_ in crs for d in (p.p.local_size, p.p.global_size) if d and is_sym_dim(d))
|
||||
crs = [(j, self.calls[j][1].arg, self.calls[j][3]) for j in range(len(self.calls)) if self.calls[j][1].op is Ops.PROGRAM]
|
||||
self.vars = sorted({v.expr for _,p,dv in crs for v in p.vars if v.expr not in dv | p.runtimevars})
|
||||
self.symbolic_dims = dedup(tuple(d) for _,p,_ in crs for d in (p.local_size, p.global_size) if d and is_sym_dim(d))
|
||||
|
||||
def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None
|
||||
|
||||
for j,p,dv in crs:
|
||||
if (replace:=[(i, self.vars.index(v.expr)) for i, v in enumerate(p.p.vars) if v.expr not in dv | p.p.runtimevars]):
|
||||
if (replace:=[(i, self.vars.index(v.expr)) for i, v in enumerate(p.vars) if v.expr not in dv | p.runtimevars]):
|
||||
self.var_vals_replace[j] = replace
|
||||
global_dim_idx, local_dim_idx = find_symbolic_dim(p.p.global_size), find_symbolic_dim(p.p.local_size)
|
||||
global_dim_idx, local_dim_idx = find_symbolic_dim(p.global_size), find_symbolic_dim(p.local_size)
|
||||
if global_dim_idx is not None or local_dim_idx is not None:
|
||||
self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
|
||||
assert p.p.local_size is not None
|
||||
self.launch_dims_base[j] = (tuple(p.p.global_size), tuple(p.p.local_size))
|
||||
assert p.local_size is not None
|
||||
self.launch_dims_base[j] = (tuple(p.global_size), tuple(p.local_size))
|
||||
|
||||
estimates = sum((estimate_uop(call) for call in self.linear.src), Estimates())
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
from typing import cast, Iterator
|
||||
from typing import cast, Iterator, Any
|
||||
import time, random, itertools, math, contextlib, weakref
|
||||
from dataclasses import dataclass, replace, field
|
||||
from tinygrad.helpers import colored, DEBUG, GlobalCounters, ansilen, NOOPT, all_int, Metadata, TRACEMETA, TracingKey
|
||||
from tinygrad.helpers import BEAM, DEVECTORIZE, size_to_str, time_to_str, VALIDATE_WITH_CPU, cpu_profile, PROFILE, ProfilePointEvent, cpu_events
|
||||
from tinygrad.helpers import prod, EMULATED_DTYPES, flatten
|
||||
from tinygrad.helpers import colored, DEBUG, GlobalCounters, ansilen, all_int, Metadata, TRACEMETA, TracingKey, prod, flatten
|
||||
from tinygrad.helpers import BEAM, size_to_str, time_to_str, VALIDATE_WITH_CPU, cpu_profile, PROFILE, ProfilePointEvent, cpu_events
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, buffers, graph_rewrite, ProgramInfo
|
||||
from tinygrad.device import Device, Buffer, MultiBuffer
|
||||
|
|
@ -44,6 +43,22 @@ def update_stats(display_name:str, device:str, estimates:Estimates, var_vals:dic
|
|||
("" if et is None else f" tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({flops_str} {mem_str})")+
|
||||
f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in metadata] if metadata else ''}")
|
||||
|
||||
first_run_cache:set[bytes] = set()
|
||||
@contextlib.contextmanager
def track_stats(ctx:"ExecContext", call:UOp, device:str, display_name:str, bufs:list[Buffer], var_vals:dict[str, int], outputs=(0,), inputs=(1,)):
  """Context manager wrapping one executed call with profiling and stats reporting.

  On entry, appends a ProfilePointEvent to cpu_events when PROFILE is set.
  Yields a one-element list; the body may store its own measured time in slot 0.
  On normal exit, if ctx.do_update_stats is set, calls update_stats for the call. When DEBUG >= 2 and the
  body left slot 0 as None, the device is synchronized and the wall-clock time since entry is used instead.
  first_run is reported as True the first time the key of call.src[0] is seen (tracked in first_run_cache).
  If the body raises, nothing after the yield runs, so no stats are recorded for that call.
  """
  if PROFILE:
    event_info = {"metadata": call.arg.metadata, "var_vals": var_vals, "bufs": [b.trace_num for b in bufs],
                  "name": display_name, "outputs": outputs, "inputs": inputs}
    cpu_events.append(ProfilePointEvent(device, "exec", len(cpu_events), event_info))
  timing: list[float|None] = [None]
  # always bind the start time: DEBUG is re-read on exit, and if it was raised inside the with body
  # the fallback branch below would otherwise read an unbound local
  st = time.perf_counter() if DEBUG >= 2 else None
  yield timing
  if not ctx.do_update_stats: return
  if DEBUG >= 2 and timing[0] is None and st is not None:
    # the body did not report a time, so wait for the device and fall back to wall-clock time
    Device[device].synchronize()
    timing[0] = time.perf_counter() - st
  program_key = call.src[0].key
  update_stats(display_name, device, estimate_uop(call), var_vals, timing[0], len(bufs), jit=ctx.jit, metadata=call.arg.metadata,
               first_run=program_key not in first_run_cache)
  first_run_cache.add(program_key)
|
||||
|
||||
# **************** Runners ****************
|
||||
|
||||
class Runner:
|
||||
|
|
@ -102,19 +117,15 @@ class CompiledRunner(Runner):
|
|||
|
||||
# **************** method cache ****************
|
||||
|
||||
method_cache: dict[tuple[str, type, bytes, tuple, bool], CompiledRunner] = {}
|
||||
def get_runner(device:str, ast:UOp) -> CompiledRunner:
|
||||
# TODO: this should be all context relevant to rendering
|
||||
context = (NOOPT.value, DEVECTORIZE.value, EMULATED_DTYPES.value)
|
||||
ckey = (device, type(Device[device].compiler), ast.key, context, False)
|
||||
if cret:=method_cache.get(ckey): return cret
|
||||
bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True)
|
||||
if bret:=method_cache.get(bkey):
|
||||
method_cache[ckey] = ret = CompiledRunner(bret.prg, device)
|
||||
else:
|
||||
prg = to_program(ast, Device[device].renderer)
|
||||
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(prg, device)
|
||||
return ret
|
||||
runtime_cache: dict[tuple[bytes, str], Any] = {}
|
||||
def get_runtime(device:str, ast:UOp):
  """Return the device runtime object for a PROGRAM uop, building and caching it on first use.

  The cache (runtime_cache) is keyed on (ast.key, device). On a cache miss, debug output is emitted
  according to DEBUG before the runtime is constructed:
    DEBUG >= 3 prints ast.src[0].arg.applied_opts when it is non-empty,
    DEBUG >= 4 prints ast.src[3].arg,
    DEBUG >= 7 passes ast.src[4].arg to the device compiler's disassemble.
  ast.src[4].arg is also what is handed to the runtime constructor as the program binary.
  """
  assert ast.op is Ops.PROGRAM and isinstance(ast.arg, ProgramInfo), "get_runtime should only be called with a PROGRAM ast"
  cache_key = (ast.key, device)
  cached = runtime_cache.get(cache_key)
  if cached is not None: return cached

  # cache miss: optional debug output, then build the runtime
  if DEBUG >= 3 and ast.src[0].arg.applied_opts: print(ast.src[0].arg.applied_opts)
  if DEBUG >= 4: print(ast.src[3].arg)
  lib = ast.src[4].arg
  if DEBUG >= 7: Device[device].compiler.disassemble(lib)
  info = ast.arg
  built = Device[device].runtime(info.function_name, lib, *info.aux, runtimevars=info.runtimevars)
  runtime_cache[cache_key] = built
  return built
|
||||
|
||||
# **************** run linear ****************
|
||||
|
||||
|
|
@ -132,20 +143,6 @@ def _resolve(b:UOp, inputs:tuple[UOp, ...]) -> UOp:
|
|||
return inputs[b.arg] if b.op is Ops.PARAM else b
|
||||
def resolve_params(call:UOp, inputs:tuple[UOp, ...]) -> list[UOp]: return [_resolve(b, inputs) for b in call.src[1:] if b.op is not Ops.BIND]
|
||||
|
||||
@contextlib.contextmanager
|
||||
def track_stats(ctx:ExecContext, call:UOp, device:str, display_name:str, bufs:list[Buffer], var_vals:dict[str, int],
|
||||
outputs=(0,), inputs=(1,), first_run=False):
|
||||
if PROFILE: cpu_events.append(ProfilePointEvent(device, "exec", len(cpu_events), {"metadata": call.arg.metadata, "var_vals": var_vals,
|
||||
"bufs": [b.trace_num for b in bufs], "name": display_name, "outputs": outputs, "inputs": inputs}))
|
||||
timing: list[float|None] = [None]
|
||||
if DEBUG >= 2: st = time.perf_counter()
|
||||
yield timing
|
||||
if not ctx.do_update_stats: return
|
||||
if DEBUG >= 2 and timing[0] is None:
|
||||
Device[device].synchronize()
|
||||
timing[0] = time.perf_counter() - st
|
||||
update_stats(display_name, device, estimate_uop(call), var_vals, timing[0], len(bufs), jit=ctx.jit, metadata=call.arg.metadata, first_run=first_run)
|
||||
|
||||
def unwrap_multi(call:UOp, resolved:list[UOp]) -> Iterator[tuple[list[Buffer], dict[str, int]]]:
|
||||
bufs = [b.buffer for b in resolved]
|
||||
if not any(isinstance(b, MultiBuffer) for b in bufs): yield cast(list[Buffer], bufs), {}
|
||||
|
|
@ -178,21 +175,21 @@ def exec_copy(ctx:ExecContext, call, ast):
|
|||
def exec_kernel(ctx:ExecContext, call, ast):
|
||||
for bufs, device_vars in unwrap_multi(call, resolve_params(call, ctx.input_uops)):
|
||||
var_vals = {**ctx.var_vals, **device_vars}
|
||||
prg = get_runner(bufs[0].device, ast)
|
||||
prg_bufs = [bufs[i].ensure_allocated() for i in prg.p.globals]
|
||||
|
||||
with track_stats(ctx, call, prg.device, prg.display_name, prg_bufs, var_vals,
|
||||
outputs=tuple(prg.p.outs), inputs=tuple(prg.p.ins), first_run=prg.first_run) as timing:
|
||||
timing[0] = prg(prg_bufs, var_vals, wait=DEBUG >= 2)
|
||||
prg.first_run = False
|
||||
prg_bufs = [bufs[i].ensure_allocated() for i in ast.arg.globals]
|
||||
rt = get_runtime(device:=bufs[0].device, ast)
|
||||
global_size, local_size = ast.arg.launch_dims(var_vals)
|
||||
with track_stats(ctx, call, device, ast.arg.name, prg_bufs, var_vals, outputs=ast.arg.outs, inputs=ast.arg.ins) as tm:
|
||||
tm[0] = rt(*[b._buf for b in prg_bufs], global_size=global_size, local_size=local_size, vals=ast.arg.vals(var_vals), wait=DEBUG>=2)
|
||||
|
||||
def exec_validate(ctx:ExecContext, call, ast):
|
||||
import numpy as np
|
||||
for bufs, device_vars in unwrap_multi(call, resolve_params(call, ctx.input_uops)):
|
||||
cpu_bufs, dev_bufs = bufs[:len(bufs)//2], bufs[len(bufs)//2:]
|
||||
cpu_prg = get_runner("CPU", ast.src[0])
|
||||
cpu_prg([cpu_bufs[i].ensure_allocated() for i in cpu_prg.p.globals], {**ctx.var_vals, **device_vars}, wait=False)
|
||||
for i in cpu_prg.p.outs: np.testing.assert_allclose(dev_bufs[i].ensure_allocated().numpy(), cpu_bufs[i].numpy(), rtol=1e-3, atol=1e-3)
|
||||
bufs, dev_bufs = bufs[:len(bufs)//2], bufs[len(bufs)//2:]
|
||||
var_vals = {**ctx.var_vals, **device_vars}
|
||||
cpu_rt = get_runtime("CPU", prg:=to_program(ast.src[0], Device["CPU"].renderer))
|
||||
global_size, local_size = prg.arg.launch_dims(var_vals)
|
||||
cpu_rt(*[bufs[i].ensure_allocated()._buf for i in prg.arg.globals], global_size=global_size, local_size=local_size, vals=prg.arg.vals(var_vals))
|
||||
for i in prg.arg.outs: np.testing.assert_allclose(dev_bufs[i].ensure_allocated().numpy(), bufs[i].numpy(), rtol=1e-3, atol=1e-3)
|
||||
|
||||
def exec_encdec(ctx:ExecContext, call, ast):
|
||||
bufs = [cast(Buffer, b.buffer).ensure_allocated() for b in resolve_params(call, ctx.input_uops)]
|
||||
|
|
|
|||
|
|
@ -14,14 +14,14 @@ class CUDAGraph(MultiGraphRunner):
|
|||
self.nodes: list[tuple[Any, ...]] = [] # list of tuple(graph node, node params, c_args/context, is memcpy)
|
||||
self.graph = init_c_var(cuda.CUgraph, lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
|
||||
|
||||
for (dev_idx, ast, bufs, device_vars), prg in zip(self.calls, self.progs):
|
||||
for (dev_idx, ast, bufs, device_vars), runtime in zip(self.calls, self.runtimes):
|
||||
if ast.op is Ops.PROGRAM:
|
||||
assert prg is not None
|
||||
global_size, local_size = prg.p.launch_dims({v: 0 for v in self.vars})
|
||||
assert runtime is not None
|
||||
global_size, local_size = ast.arg.launch_dims({v: 0 for v in self.vars})
|
||||
|
||||
c_deps, new_node = self.new_node([b.base for b in bufs], prg.p.outs)
|
||||
c_args, vargs = encode_args([b._buf for b in bufs], [device_vars.get(x.expr, 0) for x in prg.p.vars])
|
||||
kern_params = cuda.CUDA_KERNEL_NODE_PARAMS_v1(prg._prg.prg, *global_size, *local_size, 0,
|
||||
c_deps, new_node = self.new_node([b.base for b in bufs], ast.arg.outs)
|
||||
c_args, vargs = encode_args([b._buf for b in bufs], [device_vars.get(x.expr, 0) for x in ast.arg.vars])
|
||||
kern_params = cuda.CUDA_KERNEL_NODE_PARAMS_v1(runtime.prg, *global_size, *local_size, 0,
|
||||
ctypes.cast(0, ctypes.POINTER(ctypes.c_void_p)), vargs)
|
||||
check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(c_deps or []), ctypes.byref(kern_params)))
|
||||
|
||||
|
|
|
|||
|
|
@ -27,19 +27,19 @@ class HCQGraph(MultiGraphRunner):
|
|||
|
||||
# Allocate kernel args.
|
||||
kernargs_size: dict[Compiled, int] = collections.defaultdict(int)
|
||||
for prg in self.progs:
|
||||
if prg is None: continue
|
||||
kernargs_size[prg.dev] += round_up(prg._prg.kernargs_alloc_size, 16)
|
||||
for runtime in self.runtimes:
|
||||
if runtime is None: continue
|
||||
kernargs_size[runtime.dev] += round_up(runtime.kernargs_alloc_size, 16)
|
||||
self.kernargs_bufs: dict[Compiled, HCQBuffer] = {d:d.allocator._alloc(max(sz, 1), BufferSpec(cpu_access=True)) for d,sz in kernargs_size.items()}
|
||||
|
||||
# Fill initial arguments.
|
||||
self.ji_args: dict[int, HCQArgsState] = {}
|
||||
|
||||
kargs_alloc: dict[Compiled, BumpAllocator] = {dev:BumpAllocator(buf.size) for dev,buf in self.kernargs_bufs.items()}
|
||||
for j, prg in enumerate(self.progs):
|
||||
if prg is None: continue
|
||||
argsbuf = self.kernargs_bufs[prg.dev].offset(kargs_alloc[prg.dev].alloc(prg._prg.kernargs_alloc_size, 16))
|
||||
self.ji_args[j] = prg._prg.fill_kernargs(self.hcq_bufs[j], prg.p.vars, argsbuf)
|
||||
for j, runtime in enumerate(self.runtimes):
|
||||
if runtime is None: continue
|
||||
argsbuf = self.kernargs_bufs[runtime.dev].offset(kargs_alloc[runtime.dev].alloc(runtime.kernargs_alloc_size, 16))
|
||||
self.ji_args[j] = runtime.fill_kernargs(self.hcq_bufs[j], self.calls[j][1].arg.vars, argsbuf)
|
||||
|
||||
# Schedule Dependencies.
|
||||
# There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
|
||||
|
|
@ -83,13 +83,13 @@ class HCQGraph(MultiGraphRunner):
|
|||
self.input_replace_map: dict[HCQCompiled, set[tuple[int, int]]] = collections.defaultdict(set)
|
||||
self.device_vars: dict[HCQCompiled, dict[str, int]] = {}
|
||||
|
||||
for j, ((_, ast, bufs, device_vars), prg) in enumerate(zip(self.calls, self.progs)):
|
||||
for j, ((_, ast, bufs, device_vars), runtime) in enumerate(zip(self.calls, self.runtimes)):
|
||||
is_xfer = ast.op is Ops.COPY and hasattr(alc:=Device[bufs[0].device].allocator, '_transfer') and alc.supports_transfer \
|
||||
and bufs[0].device.split(":")[0] == bufs[1].device.split(":")[0]
|
||||
ji_devs = [cast(HCQCompiled, Device[b.device]) for b in bufs] if is_xfer else []
|
||||
is_rdma = len(ji_devs) > 0 and not any(d._is_cpu() for d in ji_devs) and len(set(d.peer_group for d in ji_devs)) > 1
|
||||
|
||||
if prg is not None: enqueue_dev: HCQCompiled = prg.dev
|
||||
if runtime is not None: enqueue_dev: HCQCompiled = runtime.dev
|
||||
else:
|
||||
# For copy ops prioritize enqueuing on the src device, so reverse the buffers.
|
||||
for b in bufs[::-1]:
|
||||
|
|
@ -97,9 +97,9 @@ class HCQGraph(MultiGraphRunner):
|
|||
|
||||
# set any fixedvars on the device
|
||||
self.device_vars[enqueue_dev] = merge_dicts([self.device_vars.get(enqueue_dev, {}), device_vars])
|
||||
if prg is not None: self.device_vars[enqueue_dev] = merge_dicts([self.device_vars[enqueue_dev], prg.p.runtimevars])
|
||||
if runtime is not None: self.device_vars[enqueue_dev] = merge_dicts([self.device_vars[enqueue_dev], ast.arg.runtimevars])
|
||||
|
||||
if prg is not None:
|
||||
if runtime is not None:
|
||||
enqueue_queue = self.comp_queues[enqueue_dev]
|
||||
elif is_rdma:
|
||||
enqueue_queue = self.comp_queues[enqueue_dev]
|
||||
|
|
@ -125,10 +125,10 @@ class HCQGraph(MultiGraphRunner):
|
|||
self.rdma_deps[j] = (peer_queue, peer_sync_signals + peer_opt_deps, peer_out_signal, j + 1)
|
||||
self.last_j[peer_queue] = j
|
||||
else:
|
||||
sync_signals, opt_deps, rdeps = self._resolve_deps(bufs, prg.p.outs if prg is not None else [0], enqueue_queue,
|
||||
sync_signals, opt_deps, rdeps = self._resolve_deps(bufs, ast.arg.outs if runtime is not None else [0], enqueue_queue,
|
||||
enqueue_dev, out_signal, j, is_copy=is_xfer)
|
||||
|
||||
self.ji_schedule[j] = (enqueue_dev, enqueue_queue, sync_signals, opt_deps[::-1], out_signal, None if prg is not None else (j + 1))
|
||||
self.ji_schedule[j] = (enqueue_dev, enqueue_queue, sync_signals, opt_deps[::-1], out_signal, None if runtime is not None else (j + 1))
|
||||
|
||||
# Collect profile information if profiling is enabled.
|
||||
if PROFILE:
|
||||
|
|
@ -136,9 +136,9 @@ class HCQGraph(MultiGraphRunner):
|
|||
sig_st = prev_ji * 2 + 1 if len(opt_deps) == 0 and (prev_ji:=self.last_j[enqueue_queue]) is not None else j * 2
|
||||
|
||||
# Description based on the command.
|
||||
prof_ji_desc = prg._prg.name if prg is not None else TracingKey(f"{bufs[1].device} -> {bufs[0].device}", ret=bufs[0].nbytes) # type: ignore
|
||||
prof_ji_desc = runtime.name if runtime is not None else TracingKey(f"{bufs[1].device} -> {bufs[0].device}", ret=bufs[0].nbytes) # type: ignore
|
||||
|
||||
prof_name = enqueue_dev.device if prg is not None else f"{enqueue_dev.device}:SDMA:{queue_idx}"
|
||||
prof_name = enqueue_dev.device if runtime is not None else f"{enqueue_dev.device}:SDMA:{queue_idx}"
|
||||
self.prof_graph_entries.append(ProfileGraphEntry(prof_name, prof_ji_desc, sig_st, j * 2 + 1))
|
||||
self.prof_graph_deps.append([d - 1 for _, d in rdeps])
|
||||
|
||||
|
|
@ -159,7 +159,7 @@ class HCQGraph(MultiGraphRunner):
|
|||
self.comp_queues[dev].memory_barrier().wait(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev]) \
|
||||
.wait(self.kick_signals[dev.peer_group], self.kickoff_var).signal(self.signals[dev], self.kickoff_var)
|
||||
|
||||
for j, ((dev_idx, ast, bufs, _), prg) in enumerate(zip(self.calls, self.progs)):
|
||||
for j, ((dev_idx, ast, bufs, _), runtime) in enumerate(zip(self.calls, self.runtimes)):
|
||||
enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]
|
||||
|
||||
# Lazy allocate signals
|
||||
|
|
@ -171,8 +171,8 @@ class HCQGraph(MultiGraphRunner):
|
|||
if PROFILE and j * 2 in self.prof_signal_is_used: enqueue_queue.timestamp(self.prof_signals[j * 2])
|
||||
|
||||
# Encode main commands based on ji type.
|
||||
if prg is not None:
|
||||
enqueue_queue.exec(prg._prg, self.ji_args[j], tuple(prg.p.global_size or (1,1,1)), tuple(prg.p.local_size or (1,1,1))) # type: ignore[arg-type]
|
||||
if runtime is not None:
|
||||
enqueue_queue.exec(runtime, self.ji_args[j], ast.arg.global_size or (1,1,1), ast.arg.local_size or (1,1,1)) # type: ignore[arg-type]
|
||||
elif j in self.rdma_deps:
|
||||
dest_queue, dest_deps, dest_out_signal, dest_out_val = self.rdma_deps[j]
|
||||
for sig, val in dest_deps: dest_queue.wait(sig, val)
|
||||
|
|
|
|||
|
|
@ -27,17 +27,17 @@ class MetalGraph(GraphRunner):
|
|||
if len(self.vars): self.int_buf = self.dev.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
|
||||
|
||||
all_pipelines, all_resources = [], [self.int_buf.buf] if len(self.vars) else []
|
||||
for j, ((_, _, bufs, _), prg, replace) in enumerate(zip(self.calls, self.progs, self.uop_replace)):
|
||||
assert prg is not None
|
||||
for j, ((_, ast, bufs, _), runtime, replace) in enumerate(zip(self.calls, self.runtimes, self.uop_replace)):
|
||||
assert runtime is not None
|
||||
icb_command = self.icb.indirectComputeCommandAtIndex(j).retained()
|
||||
icb_command.setComputePipelineState(prg._prg.pipeline_state)
|
||||
all_pipelines.append(prg._prg.pipeline_state)
|
||||
icb_command.setComputePipelineState(runtime.pipeline_state)
|
||||
all_pipelines.append(runtime.pipeline_state)
|
||||
for i, b in enumerate(bufs):
|
||||
if not any(pos == i for pos, _ in replace):
|
||||
icb_command.setKernelBuffer_offset_atIndex(b._buf.buf, b._buf.offset, i)
|
||||
all_resources.append(b._buf.buf)
|
||||
for i, v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex(self.int_buf.buf, self.vars.index(v.expr)*4, len(bufs)+i)
|
||||
global_size, local_size = prg.p.launch_dims({v: 0 for v in self.vars})
|
||||
for i, v in enumerate(ast.arg.vars): icb_command.setKernelBuffer_offset_atIndex(self.int_buf.buf, self.vars.index(v.expr)*4, len(bufs)+i)
|
||||
global_size, local_size = ast.arg.launch_dims({v: 0 for v in self.vars})
|
||||
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup(metal.MTLSize(*global_size), metal.MTLSize(*local_size))
|
||||
icb_command.setBarrier()
|
||||
|
||||
|
|
@ -97,7 +97,7 @@ class MetalGraph(GraphRunner):
|
|||
def collect_timestamps(self):
|
||||
# create a graph event and evenly space each program
|
||||
st, en = decimal.Decimal(self.command_buffer.GPUStartTime()) * 1000000, decimal.Decimal(self.command_buffer.GPUEndTime()) * 1000000
|
||||
ents = [ProfileGraphEntry(self.device, prg._prg.name, i, i+1) for i, prg in enumerate(self.progs) if prg is not None]
|
||||
ents = [ProfileGraphEntry(self.device, rt.name, i, i+1) for i, rt in enumerate(self.runtimes) if rt is not None]
|
||||
self.dev.profile_events += [ProfileGraphEvent(ents, [], [st + (en-st)/len(ents)*i for i in range(len(ents)+1)])]
|
||||
|
||||
def __del__(self):
|
||||
|
|
|
|||
|
|
@ -983,11 +983,13 @@ class ProgramInfo:
|
|||
@property
|
||||
def runtimevars(self) -> dict[str, int]: return {v.expr: i for i, v in enumerate(self.vars) if v.expr == 'core_id'}
|
||||
|
||||
def launch_dims(self, var_vals:dict[str, int]):
|
||||
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] # type: ignore[arg-type]
|
||||
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
|
||||
def launch_dims(self, var_vals:dict[str, int]) -> tuple[tuple[int, ...], tuple[int, ...]|None]:
  """Evaluate the launch dimensions against var_vals.

  Each entry of self.global_size (and of self.local_size, when it is not None) is passed through
  sym_infer with var_vals. Returns (global_size, local_size) as tuples; local_size is None when
  self.local_size is None.
  """
  def _resolve(dims): return tuple(sym_infer(sz, var_vals) for sz in dims)
  resolved_global = _resolve(self.global_size)  # type: ignore[arg-type]
  if self.local_size is None: return resolved_global, None
  return resolved_global, _resolve(self.local_size)
|
||||
|
||||
def vals(self, var_vals:dict[str, int]):
  """Return one value per entry of self.vars, in order.

  A variable whose expr is a key of self.runtimevars yields None; every other variable yields
  var_vals[expr] (a missing key raises KeyError).
  """
  runtime_names = self.runtimevars
  return tuple(None if v.expr in runtime_names else var_vals[v.expr] for v in self.vars)
|
||||
|
||||
@staticmethod
|
||||
def from_sink(sink:UOp, aux:tuple=()) -> ProgramInfo:
|
||||
_vars: list[UOp] = []
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue