hcq exec no embedded signal (#5142)

This commit is contained in:
nimlgen 2024-07-04 13:29:21 +03:00 committed by GitHub
commit 84b3e3bb6f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 29 additions and 22 deletions

View file

@ -44,7 +44,7 @@ class HCQGraph(MultiGraphRunner):
self.comp_queues: Dict[Compiled, Any] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
self.copy_queues: Dict[Compiled, Any] = {dev: dev.hw_copy_queue_t() for dev in self.devices}
self.signal_sched: Dict[int, Tuple[List, Optional[int], Optional[List]]] = {} # Dict[ji_idx, (deps, sigval, prof_info)]
self.signal_sched: Dict[int, Tuple[List, Any, Optional[int], Optional[List]]] = {} # Dict[ji_idx, (deps, signal, sigval, prof_info)]
self.signals = {q: self.devices[0]._alloc_signal(value=0) for q in list(self.comp_queues.values())+list(self.copy_queues.values())}
self.dev_kickoff_signal = {dev: self.devices[0]._alloc_signal(value=0) for dev in self.devices + ['CPU']} # Dict[dev, signal]
self.kickoff_value = 0
@ -74,11 +74,11 @@ class HCQGraph(MultiGraphRunner):
# Go through all dependencies and, if we need the signal from that ji, enable it by setting the signal value in the signal schedule.
for sig, val in deps:
if id(sig) in [id(x) for x in self.signals.values()]:
self.signal_sched[val - 1] = self.signal_sched[val - 1][:1] + (val,) + self.signal_sched[val - 1][2:]
self.signal_sched[val - 1] = self.signal_sched[val - 1][:2] + (val,) + self.signal_sched[val - 1][3:]
prof_ji_desc = ji.prg.clprg.name if isinstance(ji.prg, CompiledRunner) else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
prof_info = ([enqueue_dev._alloc_signal() for _ in range(2)] + [enqueue_dev, prof_ji_desc, isinstance(ji.prg, BufferXfer)]) if PROFILE else None
self.signal_sched[j] = (deps, None if isinstance(ji.prg, CompiledRunner) else (j + 1), prof_info)
self.signal_sched[j] = (deps, out_signal, None if isinstance(ji.prg, CompiledRunner) else (j + 1), prof_info)
self.last_ji[enqueue_queue] = j
# Build hardware queues.
@ -91,7 +91,7 @@ class HCQGraph(MultiGraphRunner):
.wait(self.dev_kickoff_signal['CPU'], self.kickoff_value).signal(self.dev_kickoff_signal[dev], self.kickoff_value)
for j,ji in enumerate(self.jit_cache):
deps, signal_value, prof_info = self.signal_sched[j]
deps, signal, signal_val, prof_info = self.signal_sched[j]
enqueue_queue = self.copy_queues[Device[ji.bufs[1].device]] if isinstance(ji.prg, BufferXfer) else self.comp_queues[ji.prg.device] #type:ignore
# Encode waits and start profile timestamp (if needed).
@ -102,22 +102,23 @@ class HCQGraph(MultiGraphRunner):
# Encode main commands based on ji type.
if isinstance(ji.prg, CompiledRunner):
enqueue_queue.exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals),
signal=self.signals[enqueue_queue] if signal_value is not None else None, signal_value=signal_value)
enqueue_queue.exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals))
self.exec_ptrs[j] = (enqueue_queue, len(enqueue_queue) - 1)
elif isinstance(ji.prg, BufferXfer):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
Device[src.device]._gpu_map(dest._buf) #type: ignore
enqueue_queue.copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes).signal(self.signals[enqueue_queue], signal_value)
enqueue_queue.copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes)
self.copy_to_devs[Device[dest.device]].add(Device[src.device])
if signal_val is not None: enqueue_queue.signal(signal, signal_val)
# Encode finish profile timestamp (if needed).
if prof_info: enqueue_queue.timestamp(prof_info[1])
for dev in self.devices:
for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
if (last_j:=self.last_ji[self.copy_queues[dep_dev]]) is None: continue
self.comp_queues[dev].wait(self.signals[self.copy_queues[dep_dev]], self.signal_sched[last_j][1])
self.comp_queues[dev].wait(self.signals[self.copy_queues[dep_dev]], self.signal_sched[last_j][2])
self.comp_queues[dev].signal(dev.timeline_signal, dev.timeline_value)
if hasattr(self.comp_queues[dev], 'bind'): self.comp_queues[dev].bind(dev)
@ -132,7 +133,7 @@ class HCQGraph(MultiGraphRunner):
self.devices[0]._set_signal(self.dev_kickoff_signal['CPU'], self.kickoff_value)
if PROFILE and self.kickoff_value > 1:
for _,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): #type: ignore
for _,_,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): #type: ignore
dev.raw_prof_records += [(dev._read_timestamp(st), dev._read_timestamp(en), desc, is_cp)]
# Update rawbuffers
@ -181,7 +182,7 @@ class HCQGraph(MultiGraphRunner):
# Graph is destructed. No need to keep signals any more, so return them as part of profiling.
if PROFILE and self.kickoff_value > 1:
for _,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): dev.sig_prof_records += [(st, en, desc, is_cp)] #type: ignore
for _,_,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): dev.sig_prof_records += [(st, en, desc, is_cp)] #type: ignore
map(self.devices[0]._free_signal, list(self.dev_kickoff_signal.values()) + list(self.signals.values()))
for dev, buf in self.kernargs_bufs.items(): dev.allocator._free(buf, BufferOptions(cpu_access=True))

View file

@ -104,7 +104,7 @@ class HWPM4Queue(HWQueue):
self._invalidate_cache()
return self._mark_command_end()
def exec(self, prg, kernargs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), signal=None, signal_value=0):
def exec(self, prg, kernargs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1)):
self._invalidate_cache()
user_data = [*data64_le(kernargs)]
@ -130,7 +130,6 @@ class HWPM4Queue(HWQueue):
self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_DISPATCH_DIRECT, 3), *global_size, CS_W32_EN | FORCE_START_AT_000 | COMPUTE_SHADER_EN]
self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_EVENT_WRITE, 0), amd_gpu.EVENT_TYPE(7) | amd_gpu.EVENT_INDEX(4)]
if signal is not None: self.signal(signal, signal_value)
return self._mark_command_end()
def update_exec(self, cmd_idx, global_size, local_size):

View file

@ -146,8 +146,8 @@ class HWQueue:
class HWComputeQueue(HWQueue):
def __init__(self):
super().__init__()
self.cmd_idx_to_qmd, self.cmd_idx_to_global_dims, self.cmd_idx_to_local_dims = {}, {}, {}
super().__init__()
def copy_from_cpu(self, gpuaddr, data):
self.q += [nvmethod(1, nv_gpu.NVC6C0_OFFSET_OUT_UPPER, 2), *nvdata64(gpuaddr)]
@ -156,7 +156,7 @@ class HWComputeQueue(HWQueue):
self.q += [nvmethod(1, nv_gpu.NVC6C0_LOAD_INLINE_DATA, len(data), typ=6)] + list(data)
return self._mark_command_end()
def exec(self, prg, kernargs, global_size=(1,1,1), local_size=(1,1,1), signal=None, signal_value=0):
def exec(self, prg, kernargs, global_size=(1,1,1), local_size=(1,1,1)):
ctypes.memmove(qmd_addr:=(kernargs + round_up(prg.constbuf_0_size, 1 << 8)), ctypes.addressof(prg.qmd), 0x40 * 4)
self.cmd_idx_to_qmd[len(self)] = qmd = qmd_struct_t.from_address(qmd_addr) # Save qmd for later update
self.cmd_idx_to_global_dims[len(self)] = to_mv(qmd_addr + nv_gpu.NVC6C0_QMDV03_00_CTA_RASTER_WIDTH[1] // 8, 12).cast('I')
@ -164,14 +164,7 @@ class HWComputeQueue(HWQueue):
qmd.cta_raster_width, qmd.cta_raster_height, qmd.cta_raster_depth = global_size
qmd.cta_thread_dimension0, qmd.cta_thread_dimension1, qmd.cta_thread_dimension2 = local_size
qmd.constant_buffer_addr_lower_0 = kernargs & 0xffffffff
qmd.constant_buffer_addr_upper_0 = kernargs >> 32
if signal is not None:
qmd.release0_address_lower = ctypes.addressof(from_mv(signal)) & 0xffffffff
qmd.release0_address_upper = ctypes.addressof(from_mv(signal)) >> 32
qmd.release0_payload_lower = signal_value & 0xffffffff
qmd.release0_payload_upper = signal_value >> 32
qmd.release0_enable = 1
qmd.constant_buffer_addr_upper_0, qmd.constant_buffer_addr_lower_0 = nvdata64(kernargs)
if (prev_qmd:=self.cmd_idx_to_qmd.get(len(self) - 1)) is None:
self.q += [nvmethod(1, nv_gpu.NVC6C0_INVALIDATE_SHADER_CACHES_NO_WFI, 1), (1 << 12) | (1 << 4) | (1 << 0)]
@ -189,6 +182,20 @@ class HWComputeQueue(HWQueue):
self.cmd_idx_to_global_dims[cmd_idx][:] = array.array('I', global_size)
self.cmd_idx_to_local_dims[cmd_idx][:] = array.array('H', local_size)
def signal(self, signal, value=0):
if (prev_qmd:=self.cmd_idx_to_qmd.get(len(self) - 1)) is None or prev_qmd.release0_enable == 1: return super().signal(signal, value)
prev_qmd.release0_address_upper, prev_qmd.release0_address_lower = nvdata64(ctypes.addressof(from_mv(signal)))
prev_qmd.release0_payload_upper, prev_qmd.release0_payload_lower = nvdata64(value)
prev_qmd.release0_enable = 1
self.cmd_idx_to_qmd[len(self)] = prev_qmd # this command is embedded into qmd.
return self._mark_command_end()
def update_signal(self, cmd_idx, signal=None, value=None):
if (qmd:=self.cmd_idx_to_qmd.get(cmd_idx)) is None: return super().update_signal(cmd_idx, signal, value)
if signal is not None: qmd.release0_address_upper, qmd.release0_address_lower = nvdata64(ctypes.addressof(from_mv(signal)))
if value is not None: qmd.release0_payload_upper, qmd.release0_payload_lower = nvdata64(value)
return self
def submit(self, dev:NVDevice): self._submit(dev, dev.compute_gpfifo)
class HWCopyQueue(HWQueue):