mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
8baca185d5
commit
95d63d6c07
2 changed files with 173 additions and 187 deletions
|
|
@ -138,35 +138,28 @@ def unwrap_after(uop):
|
|||
while uop.op is Ops.AFTER: uop = uop.src[0]
|
||||
return uop
|
||||
|
||||
def make_getaddr(u, dev=None):
|
||||
if unwrap_after(u).op not in (Ops.BUFFER, Ops.SLICE, Ops.BINARY, Ops.MSTACK, Ops.MSELECT): return u
|
||||
return UOp(Ops.GETADDR, dtypes.uint64, src=(u, UOp(Ops.DEVICE, arg=dev or to_tuple(u.device)[0])))
|
||||
|
||||
def make_ins(op, *srcs):
|
||||
return UOp(Ops.INS, dtypes.void, tuple(UOp.const(dtypes.uint32, s) if isinstance(s, int) else s.cast(dtypes.uint32) for s in srcs), op)
|
||||
|
||||
def make_cmdbuf(lin, devs, tag):
|
||||
blob, patches = b'', []
|
||||
for s in (s for ins in lin.src for s in ins.src):
|
||||
if s.op is not Ops.CONST: patches.append((len(blob), s))
|
||||
blob += struct.pack(f'<{s.dtype.fmt}', s.arg if s.op is Ops.CONST else 0x0)
|
||||
buf = UOp.new_buffer(devs if len(devs) > 1 else devs[0], len(blob), dtypes.uint8).rtag(tag)
|
||||
stores = [buf.index(UOp.const(dtypes.int, off), dtype=buf.dtype.ptr()).cast(s.dtype.ptr()).store(s) for off, s in patches]
|
||||
return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *stores)
|
||||
|
||||
def make_mstack(uops): return uops[0] if len(uops) == 1 else UOp(Ops.MSTACK, uops[0].dtype, tuple(uops))
|
||||
|
||||
def make_signal(devs, queue=None, sentinel=False):
|
||||
return UOp.new_buffer(devs, 1, dtypes.uint64).rtag("sentinel_signal" if sentinel else (queue, "timeline_signal") if queue else "timeline_signal")
|
||||
def make_signal_value(devs, queue=None): return UOp.new_buffer(devs, 1, dtypes.uint64).rtag((queue, "timeline_value") if queue else "timeline_value")
|
||||
|
||||
class HCQEncoder:
|
||||
def __init__(self): self.blob, self.patches = b'', []
|
||||
|
||||
def get_dev_addr(self, uop:UOp) -> UOp:
|
||||
if unwrap_after(uop).op not in (Ops.BUFFER, Ops.SLICE, Ops.BINARY, Ops.MSTACK, Ops.MSELECT): return uop
|
||||
return UOp(Ops.GETADDR, dtypes.uint64, src=(uop, UOp(Ops.DEVICE, arg=self.dev.device)))
|
||||
|
||||
def append(self, *data, dtype=dtypes.uint32):
|
||||
for d in data:
|
||||
if isinstance(d, int): self.blob += struct.pack(f'<{dtype.fmt}', d)
|
||||
else:
|
||||
self.patches.append((len(self.blob), self.get_dev_addr(d), dtype))
|
||||
self.blob += struct.pack(f'<{dtype.fmt}', 0)
|
||||
|
||||
def q(self, *values): self.append(*values)
|
||||
|
||||
def uop(self, dev:str|tuple[str, ...], tag:str|None=None) -> UOp:
|
||||
buf = UOp.new_buffer(dev, len(self.blob), dtypes.uint8)
|
||||
if tag: buf = buf.rtag(tag)
|
||||
blob_uop = UOp(Ops.BINARY, dtypes.void, src=(), arg=self.blob)
|
||||
stores = [buf.index(UOp.const(dtypes.int, off), dtype=buf.dtype.ptr()).cast(dt.ptr()).store(val.cast(dt)) for off, val, dt in self.patches]
|
||||
return buf.after(buf.store(blob_uop), *stores)
|
||||
|
||||
# *****************
|
||||
# 0. helpers
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import cast
|
|||
import os, ctypes, struct, hashlib, functools, importlib, mmap, errno, array, contextlib, sys, weakref, itertools, collections, atexit
|
||||
assert sys.platform != 'win32'
|
||||
from dataclasses import dataclass
|
||||
from extra.hcq2.hcq2 import HCQ2Compiled, HCQAllocator, HCQ2Buffer, HCQEncoder, to_tuple
|
||||
from extra.hcq2.hcq2 import HCQ2Compiled, HCQAllocator, HCQ2Buffer, to_tuple, make_getaddr, make_ins, make_cmdbuf
|
||||
from tinygrad.uop.ops import sint, UOp
|
||||
from tinygrad.device import Compiled, BufferSpec, Buffer, Device
|
||||
from tinygrad.dtype import dtypes
|
||||
|
|
@ -24,131 +24,124 @@ from tinygrad.runtime.ops_amd import SQTT, SQTT_ITRACE_SE_MASK, SQTT_LIMIT_SE, S
|
|||
from tinygrad.runtime.ops_amd import EVENT_INDEX_PARTIAL_FLUSH, WAIT_REG_MEM_FUNCTION_EQ, WAIT_REG_MEM_FUNCTION_NEQ, WAIT_REG_MEM_FUNCTION_GEQ
|
||||
if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 # pylint: disable=unused-import
|
||||
|
||||
from tinygrad.engine.realize import get_runtime
|
||||
from tinygrad.engine.realize import get_runtime, pm_flatten_linear
|
||||
from tinygrad.uop import FastEnum, auto
|
||||
from tinygrad.uop.ops import Ops, UPat, PatternMatcher, graph_rewrite
|
||||
|
||||
class AMDComputeQueue(HCQEncoder):
|
||||
def __init__(self, dev:AMDDevice, devs:tuple[str, ...]|None=None):
|
||||
super().__init__()
|
||||
self.dev, self.devs = dev, devs or (dev.device,)
|
||||
self.pm4, self.gc, self.nbio, self.soc = dev.pm4, dev.gc, dev.nbio, dev.soc
|
||||
# *****************
|
||||
# PM4
|
||||
|
||||
def pkt3(self, cmd, *vals): self.q(self.pm4.PACKET3(cmd, len(vals) - 1), *vals)
|
||||
class PM4Ops(FastEnum):
|
||||
SET_SH_REG = auto(); SET_UCONFIG_REG = auto(); WAIT_REG_MEM = auto(); ACQUIRE_MEM = auto() # noqa: E702
|
||||
RELEASE_MEM = auto(); DISPATCH_DIRECT = auto(); EVENT_WRITE = auto() # noqa: E702
|
||||
|
||||
def wreg(self, reg:AMDReg, *args:sint, **kwargs:int):
|
||||
if bool(args) == bool(kwargs): raise RuntimeError('One (and only one) of *args or **kwargs must be specified')
|
||||
if self.pm4.PACKET3_SET_SH_REG_START <= reg.addr[0] < self.pm4.PACKET3_SET_SH_REG_END:
|
||||
set_packet, set_packet_start = self.pm4.PACKET3_SET_SH_REG, self.pm4.PACKET3_SET_SH_REG_START
|
||||
elif self.pm4.PACKET3_SET_UCONFIG_REG_START <= reg.addr[0] < self.pm4.PACKET3_SET_UCONFIG_REG_START + 2**16-1:
|
||||
set_packet, set_packet_start = self.pm4.PACKET3_SET_UCONFIG_REG, self.pm4.PACKET3_SET_UCONFIG_REG_START
|
||||
else: raise RuntimeError(f'Cannot set {reg.name} ({reg.addr[0]}) via pm4 packet')
|
||||
self.pkt3(set_packet, reg.addr[0] - set_packet_start, *(args or (reg.encode(**kwargs),)))
|
||||
def pkt3(ctx, op:PM4Ops, *vals): return make_ins(op, ctx.pm4.PACKET3(getattr(ctx.pm4, f"PACKET3_{op.name}"), len(vals) - 1), *vals)
|
||||
|
||||
def wait_reg_mem(self, value, mask=0xffffffff, mem=None, reg=None, reg_done=0, op=WAIT_REG_MEM_FUNCTION_GEQ):
|
||||
wrm_info_dw = self.pm4.WAIT_REG_MEM_MEM_SPACE(int(mem is not None)) | self.pm4.WAIT_REG_MEM_OPERATION(int(mem is None and reg_done > 0)) \
|
||||
| self.pm4.WAIT_REG_MEM_FUNCTION(op) | self.pm4.WAIT_REG_MEM_ENGINE(0)
|
||||
self.pkt3(self.pm4.PACKET3_WAIT_REG_MEM, wrm_info_dw, *(data64_le(mem) if mem is not None else (reg, reg_done)), value, mask, 4)
|
||||
def wreg(ctx, reg:AMDReg, *args:sint, **kwargs:int):
|
||||
if bool(args) == bool(kwargs): raise RuntimeError('One (and only one) of *args or **kwargs must be specified')
|
||||
if ctx.pm4.PACKET3_SET_SH_REG_START <= reg.addr[0] < ctx.pm4.PACKET3_SET_SH_REG_END:
|
||||
op, set_packet_start = PM4Ops.SET_SH_REG, ctx.pm4.PACKET3_SET_SH_REG_START
|
||||
elif ctx.pm4.PACKET3_SET_UCONFIG_REG_START <= reg.addr[0] < ctx.pm4.PACKET3_SET_UCONFIG_REG_START + 2**16-1:
|
||||
op, set_packet_start = PM4Ops.SET_UCONFIG_REG, ctx.pm4.PACKET3_SET_UCONFIG_REG_START
|
||||
else: raise RuntimeError(f'Cannot set {reg.name} ({reg.addr[0]}) via pm4 packet')
|
||||
return pkt3(ctx, op, reg.addr[0] - set_packet_start, *(args or (reg.encode(**kwargs),)))
|
||||
|
||||
def acquire_mem(self, addr=0x0, sz=(1 << 64)-1, gli=1, glm=1, glk=1, glv=1, gl1=1, gl2=1):
|
||||
if self.dev.target[0] != 9:
|
||||
cache_flags_dw = self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLI_INV(gli) \
|
||||
| self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLM_INV(glm) | self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLM_WB(glm) \
|
||||
| self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLK_INV(glk) | self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLK_WB(glk) \
|
||||
| self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLV_INV(glv) | self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GL1_INV(gl1) \
|
||||
| self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GL2_INV(gl2) | self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GL2_WB(gl2)
|
||||
self.pkt3(self.pm4.PACKET3_ACQUIRE_MEM, 0, *data64_le(sz), *data64_le(addr), 0, cache_flags_dw)
|
||||
else:
|
||||
cp_coher_cntl = self.pm4.PACKET3_ACQUIRE_MEM_CP_COHER_CNTL_SH_ICACHE_ACTION_ENA(gli) | \
|
||||
self.pm4.PACKET3_ACQUIRE_MEM_CP_COHER_CNTL_SH_KCACHE_ACTION_ENA(glk) | \
|
||||
self.pm4.PACKET3_ACQUIRE_MEM_CP_COHER_CNTL_TC_ACTION_ENA(gl2) | \
|
||||
self.pm4.PACKET3_ACQUIRE_MEM_CP_COHER_CNTL_TCL1_ACTION_ENA(gl1) | \
|
||||
self.pm4.PACKET3_ACQUIRE_MEM_CP_COHER_CNTL_TC_WB_ACTION_ENA(gl2)
|
||||
self.pkt3(self.pm4.PACKET3_ACQUIRE_MEM, cp_coher_cntl, *data64_le(sz), *data64_le(addr), 0x0000000A)
|
||||
def wait_reg_mem(ctx, value, mask=0xffffffff, mem=None, reg=None, reg_done=0, op=WAIT_REG_MEM_FUNCTION_GEQ):
|
||||
wrm_info_dw = ctx.pm4.WAIT_REG_MEM_MEM_SPACE(int(mem is not None)) | ctx.pm4.WAIT_REG_MEM_OPERATION(int(mem is None and reg_done > 0)) \
|
||||
| ctx.pm4.WAIT_REG_MEM_FUNCTION(op) | ctx.pm4.WAIT_REG_MEM_ENGINE(0)
|
||||
return pkt3(ctx, PM4Ops.WAIT_REG_MEM, wrm_info_dw, *(data64_le(mem) if mem is not None else (reg, reg_done)), value, mask, 4)
|
||||
|
||||
def release_mem(self, address=0x0, value=0, data_sel=0, int_sel=2, ctxid=0, cache_flush=False):
|
||||
if self.dev.target[0] != 9:
|
||||
cache_flags_dw = 0 if not cache_flush else (self.pm4.PACKET3_RELEASE_MEM_GCR_GLV_INV | self.pm4.PACKET3_RELEASE_MEM_GCR_GL1_INV \
|
||||
| self.pm4.PACKET3_RELEASE_MEM_GCR_GL2_INV | self.pm4.PACKET3_RELEASE_MEM_GCR_GLM_WB \
|
||||
| self.pm4.PACKET3_RELEASE_MEM_GCR_GLM_INV | self.pm4.PACKET3_RELEASE_MEM_GCR_GL2_WB | self.pm4.PACKET3_RELEASE_MEM_GCR_SEQ)
|
||||
event_dw = self.pm4.PACKET3_RELEASE_MEM_EVENT_TYPE(self.pm4.CACHE_FLUSH_AND_INV_TS_EVENT) \
|
||||
| self.pm4.PACKET3_RELEASE_MEM_EVENT_INDEX(self.pm4.event_index__mec_release_mem__end_of_pipe)
|
||||
memsel_dw = self.pm4.PACKET3_RELEASE_MEM_DATA_SEL(data_sel) | self.pm4.PACKET3_RELEASE_MEM_INT_SEL(int_sel) \
|
||||
| self.pm4.PACKET3_RELEASE_MEM_DST_SEL(0)
|
||||
else:
|
||||
cache_flags_dw = 0 if not cache_flush else (self.pm4.EOP_TC_WB_ACTION_EN | self.pm4.EOP_TC_NC_ACTION_EN)
|
||||
event_dw = self.pm4.EVENT_TYPE(self.pm4.CACHE_FLUSH_AND_INV_TS_EVENT) | self.pm4.EVENT_INDEX(self.pm4.event_index__mec_release_mem__end_of_pipe)
|
||||
memsel_dw = self.pm4.DATA_SEL(data_sel) | self.pm4.INT_SEL(int_sel)
|
||||
ctxid = 0
|
||||
self.pkt3(self.pm4.PACKET3_RELEASE_MEM, event_dw | cache_flags_dw, memsel_dw, *data64_le(address), *data64_le(value), ctxid)
|
||||
def acquire_mem(ctx, addr=0x0, sz=(1 << 64)-1, gli=1, glm=1, glk=1, glv=1, gl1=1, gl2=1):
|
||||
if ctx.target[0] != 9:
|
||||
cache_flags_dw = ctx.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLI_INV(gli) \
|
||||
| ctx.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLM_INV(glm) | ctx.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLM_WB(glm) \
|
||||
| ctx.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLK_INV(glk) | ctx.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLK_WB(glk) \
|
||||
| ctx.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLV_INV(glv) | ctx.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GL1_INV(gl1) \
|
||||
| ctx.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GL2_INV(gl2) | ctx.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GL2_WB(gl2)
|
||||
return pkt3(ctx, PM4Ops.ACQUIRE_MEM, 0, *data64_le(sz), *data64_le(addr), 0, cache_flags_dw)
|
||||
cp_coher_cntl = ctx.pm4.PACKET3_ACQUIRE_MEM_CP_COHER_CNTL_SH_ICACHE_ACTION_ENA(gli) | \
|
||||
ctx.pm4.PACKET3_ACQUIRE_MEM_CP_COHER_CNTL_SH_KCACHE_ACTION_ENA(glk) | \
|
||||
ctx.pm4.PACKET3_ACQUIRE_MEM_CP_COHER_CNTL_TC_ACTION_ENA(gl2) | \
|
||||
ctx.pm4.PACKET3_ACQUIRE_MEM_CP_COHER_CNTL_TCL1_ACTION_ENA(gl1) | \
|
||||
ctx.pm4.PACKET3_ACQUIRE_MEM_CP_COHER_CNTL_TC_WB_ACTION_ENA(gl2)
|
||||
return pkt3(ctx, PM4Ops.ACQUIRE_MEM, cp_coher_cntl, *data64_le(sz), *data64_le(addr), 0x0000000A)
|
||||
|
||||
def memory_barrier(self):
|
||||
pf = '' if self.nbio.version[0] == 2 else '0' if self.nbio.version[:2] != (7, 11) else '1'
|
||||
self.wait_reg_mem(reg=getattr(self.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_REQ').addr[0],
|
||||
reg_done=getattr(self.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_DONE').addr[0], value=0xffffffff)
|
||||
self.acquire_mem()
|
||||
def release_mem(ctx, address=0x0, value=0, data_sel=0, int_sel=2, ctxid=0, cache_flush=False):
|
||||
if ctx.target[0] != 9:
|
||||
cache_flags_dw = 0 if not cache_flush else (ctx.pm4.PACKET3_RELEASE_MEM_GCR_GLV_INV | ctx.pm4.PACKET3_RELEASE_MEM_GCR_GL1_INV \
|
||||
| ctx.pm4.PACKET3_RELEASE_MEM_GCR_GL2_INV | ctx.pm4.PACKET3_RELEASE_MEM_GCR_GLM_WB \
|
||||
| ctx.pm4.PACKET3_RELEASE_MEM_GCR_GLM_INV | ctx.pm4.PACKET3_RELEASE_MEM_GCR_GL2_WB | ctx.pm4.PACKET3_RELEASE_MEM_GCR_SEQ)
|
||||
event_dw = ctx.pm4.PACKET3_RELEASE_MEM_EVENT_TYPE(ctx.pm4.CACHE_FLUSH_AND_INV_TS_EVENT) \
|
||||
| ctx.pm4.PACKET3_RELEASE_MEM_EVENT_INDEX(ctx.pm4.event_index__mec_release_mem__end_of_pipe)
|
||||
memsel_dw = ctx.pm4.PACKET3_RELEASE_MEM_DATA_SEL(data_sel) | ctx.pm4.PACKET3_RELEASE_MEM_INT_SEL(int_sel) \
|
||||
| ctx.pm4.PACKET3_RELEASE_MEM_DST_SEL(0)
|
||||
else:
|
||||
cache_flags_dw = 0 if not cache_flush else (ctx.pm4.EOP_TC_WB_ACTION_EN | ctx.pm4.EOP_TC_NC_ACTION_EN)
|
||||
event_dw = ctx.pm4.EVENT_TYPE(ctx.pm4.CACHE_FLUSH_AND_INV_TS_EVENT) | ctx.pm4.EVENT_INDEX(ctx.pm4.event_index__mec_release_mem__end_of_pipe)
|
||||
memsel_dw = ctx.pm4.DATA_SEL(data_sel) | ctx.pm4.INT_SEL(int_sel)
|
||||
ctxid = 0
|
||||
return pkt3(ctx, PM4Ops.RELEASE_MEM, event_dw | cache_flags_dw, memsel_dw, *data64_le(address), *data64_le(value), ctxid)
|
||||
|
||||
def wait(self, x): self.wait_reg_mem(x.src[1], mem=self.get_dev_addr(x.src[0]))
|
||||
def memory_barrier(ctx):
|
||||
pf = '' if ctx.nbio.version[0] == 2 else '0' if ctx.nbio.version[:2] != (7, 11) else '1'
|
||||
return UOp(Ops.LINEAR, dtypes.void, (
|
||||
wait_reg_mem(ctx, reg=getattr(ctx.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_REQ').addr[0],
|
||||
reg_done=getattr(ctx.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_DONE').addr[0], value=0xffffffff),
|
||||
acquire_mem(ctx)))
|
||||
|
||||
def barrier(self, x): self.memory_barrier()
|
||||
def pm4_wait(ctx, dst, val): return wait_reg_mem(ctx, val, mem=make_getaddr(dst, ctx.device))
|
||||
|
||||
def store(self, x):
|
||||
self.release_mem(self.get_dev_addr(x.src[0]), x.src[1], self.pm4.data_sel__mec_release_mem__send_32_bit_low,
|
||||
self.pm4.int_sel__mec_release_mem__send_interrupt_after_write_confirm, cache_flush=True)
|
||||
def pm4_barrier(ctx): return memory_barrier(ctx)
|
||||
|
||||
def timestamp(self, x):
|
||||
self.release_mem(self.get_dev_addr(x.src[0]), 0, self.pm4.data_sel__mec_release_mem__send_gpu_clock_counter,
|
||||
self.pm4.int_sel__mec_release_mem__none)
|
||||
def pm4_store(ctx, dst, val):
|
||||
if val.op is Ops.BINARY: return None
|
||||
return release_mem(ctx, make_getaddr(dst, ctx.device), val, ctx.pm4.data_sel__mec_release_mem__send_32_bit_low,
|
||||
ctx.pm4.int_sel__mec_release_mem__send_interrupt_after_write_confirm, cache_flush=True)
|
||||
|
||||
def program(self, x):
|
||||
data, info = x.arg
|
||||
lib_gpu, args = x.src
|
||||
prog_addr = self.get_dev_addr(lib_gpu) + data.entry_point_offset
|
||||
def pm4_timestamp(ctx, dst):
|
||||
return release_mem(ctx, make_getaddr(dst, ctx.device), 0, ctx.pm4.data_sel__mec_release_mem__send_gpu_clock_counter,
|
||||
ctx.pm4.int_sel__mec_release_mem__none)
|
||||
|
||||
self.acquire_mem(gli=0, gl2=0)
|
||||
def pm4_program(ctx, prg):
|
||||
data, info = prg.arg
|
||||
lib_gpu, args = prg.src
|
||||
prog_addr = make_getaddr(lib_gpu, ctx.device) + data.entry_point_offset
|
||||
scratch_addr = make_getaddr(UOp.new_buffer(lib_gpu.device, data.private_segment_size, dtypes.uint8).rtag("scratch"), ctx.device)
|
||||
args_addr = make_getaddr(args, ctx.device)
|
||||
|
||||
scratch_addr = self.get_dev_addr(UOp.new_buffer(self.devs, data.private_segment_size, dtypes.uint8).rtag("scratch"))
|
||||
args_addr = self.get_dev_addr(args)
|
||||
user_regs = []
|
||||
if data.enable_private_segment_sgpr:
|
||||
scratch_hilo = data64_le(scratch_addr)
|
||||
user_regs = [scratch_hilo[0], scratch_hilo[1] | 1 << 31, 0xffffffff, 0x20c14000]
|
||||
if data.enable_dispatch_ptr: user_regs += [*data64_le(args_addr + data.kernargs_segment_size)]
|
||||
user_regs += [*data64_le(args_addr)]
|
||||
|
||||
user_regs = []
|
||||
if data.enable_private_segment_sgpr:
|
||||
scratch_hilo = data64_le(scratch_addr)
|
||||
user_regs = [scratch_hilo[0], scratch_hilo[1] | 1 << 31, 0xffffffff, 0x20c14000]
|
||||
if data.enable_dispatch_ptr: user_regs += [*data64_le(args_addr + data.kernargs_segment_size)]
|
||||
user_regs += [*data64_le(args_addr)]
|
||||
dispatch_init = ctx.gc.regCOMPUTE_DISPATCH_INITIATOR.encode(
|
||||
**({'cs_w32_en': int(data.wave32)} if ctx.target[0] != 9 else {}), force_start_at_000=1, compute_shader_en=1)
|
||||
ins = [acquire_mem(ctx, gli=0, gl2=0),
|
||||
wreg(ctx, ctx.gc.regCOMPUTE_PGM_LO, *data64_le(prog_addr >> 8)),
|
||||
wreg(ctx, ctx.gc.regCOMPUTE_PGM_RSRC1, data.rsrc1, data.rsrc2),
|
||||
wreg(ctx, ctx.gc.regCOMPUTE_PGM_RSRC3, data.rsrc3),
|
||||
wreg(ctx, ctx.gc.regCOMPUTE_TMPRING_SIZE, ctx.tmpring_size(data.private_segment_size))]
|
||||
ins += [wreg(ctx, ctx.gc.regCOMPUTE_DISPATCH_SCRATCH_BASE_LO, *data64_le((scratch_addr + data.private_segment_size // ctx.xccs * xcc_id) >> 8))
|
||||
for xcc_id in range(ctx.xccs)]
|
||||
ins += [wreg(ctx, ctx.gc.regCOMPUTE_RESTART_X, 0, 0, 0),
|
||||
wreg(ctx, ctx.gc.regCOMPUTE_USER_DATA_0, *user_regs),
|
||||
wreg(ctx, ctx.gc.regCOMPUTE_RESOURCE_LIMITS, ctx.gc.regCOMPUTE_RESOURCE_LIMITS.encode(waves_per_sh=getenv("WAVES_PER_SH"))),
|
||||
wreg(ctx, ctx.gc.regCOMPUTE_START_X, 0, 0, 0, *(info.local_size or (1, 1, 1)), 0, 0),
|
||||
pkt3(ctx, PM4Ops.DISPATCH_DIRECT, *info.global_size, dispatch_init),
|
||||
pkt3(ctx, PM4Ops.EVENT_WRITE, ctx.pm4.EVENT_TYPE(ctx.soc.CS_PARTIAL_FLUSH) | ctx.pm4.EVENT_INDEX(EVENT_INDEX_PARTIAL_FLUSH))]
|
||||
return UOp(Ops.LINEAR, dtypes.void, tuple(ins))
|
||||
|
||||
self.wreg(self.gc.regCOMPUTE_PGM_LO, *data64_le(prog_addr >> 8))
|
||||
self.wreg(self.gc.regCOMPUTE_PGM_RSRC1, data.rsrc1, data.rsrc2)
|
||||
self.wreg(self.gc.regCOMPUTE_PGM_RSRC3, data.rsrc3)
|
||||
self.wreg(self.gc.regCOMPUTE_TMPRING_SIZE, self.dev.tmpring_size(data.private_segment_size))
|
||||
|
||||
for xcc_id in range(self.dev.xccs):
|
||||
scratch_base = scratch_addr + (data.private_segment_size // self.dev.xccs * xcc_id)
|
||||
self.wreg(self.gc.regCOMPUTE_DISPATCH_SCRATCH_BASE_LO, *data64_le(scratch_base >> 8))
|
||||
|
||||
self.wreg(self.gc.regCOMPUTE_RESTART_X, 0, 0, 0)
|
||||
self.wreg(self.gc.regCOMPUTE_USER_DATA_0, *user_regs)
|
||||
self.wreg(self.gc.regCOMPUTE_RESOURCE_LIMITS, self.gc.regCOMPUTE_RESOURCE_LIMITS.encode(waves_per_sh=getenv("WAVES_PER_SH")))
|
||||
self.wreg(self.gc.regCOMPUTE_START_X, 0, 0, 0, *(info.local_size or (1, 1, 1)), 0, 0)
|
||||
|
||||
dispatch_init = self.gc.regCOMPUTE_DISPATCH_INITIATOR.encode(
|
||||
**({'cs_w32_en': int(data.wave32)} if self.dev.target[0] != 9 else {}), force_start_at_000=1, compute_shader_en=1)
|
||||
self.pkt3(self.pm4.PACKET3_DISPATCH_DIRECT, *info.global_size, dispatch_init)
|
||||
self.pkt3(self.pm4.PACKET3_EVENT_WRITE, self.pm4.EVENT_TYPE(self.soc.CS_PARTIAL_FLUSH) | self.pm4.EVENT_INDEX(EVENT_INDEX_PARTIAL_FLUSH))
|
||||
|
||||
amd_inner_pm = PatternMatcher([
|
||||
(UPat(Ops.LINEAR, src=(UPat(Ops.WAIT, name="x"),)), lambda ctx, x: ctx.wait(x)),
|
||||
(UPat(Ops.LINEAR, src=(UPat(Ops.BARRIER, name="x"),)), lambda ctx, x: ctx.barrier(x)),
|
||||
(UPat(Ops.LINEAR, src=(UPat(Ops.PROGRAM, name="x"),)), lambda ctx, x: ctx.program(x)),
|
||||
(UPat(Ops.LINEAR, src=(UPat(Ops.CUSTOM_FUNCTION, arg="timestamp", name="x"),)), lambda ctx, x: ctx.timestamp(x)),
|
||||
(UPat(Ops.LINEAR, src=(UPat(Ops.STORE, src=(UPat((Ops.BUFFER, Ops.PARAM)), UPat()), name="x"),)), lambda ctx, x: ctx.store(x)),
|
||||
pm_pm4_opsel = PatternMatcher([
|
||||
(UPat(Ops.WAIT, src=(UPat(name="dst"), UPat(name="val"))), pm4_wait),
|
||||
(UPat(Ops.BARRIER), pm4_barrier),
|
||||
(UPat(Ops.PROGRAM, name="prg"), pm4_program),
|
||||
(UPat(Ops.CUSTOM_FUNCTION, arg="timestamp", src=(UPat(name="dst"),)), pm4_timestamp),
|
||||
(UPat(Ops.STORE, src=(UPat((Ops.BUFFER, Ops.PARAM), name="dst"), UPat(name="val"))), pm4_store),
|
||||
])
|
||||
|
||||
def amd_lower_pm4(linear, devs):
|
||||
enc = AMDComputeQueue(Device[devs[0]], devs)
|
||||
graph_rewrite(linear.replace(src=tuple(UOp(Ops.LINEAR, dtypes.void, (cmd,)) for cmd in linear.src)), amd_inner_pm, ctx=enc, name="amd: encode")
|
||||
return enc.uop(dev=devs if len(devs) > 1 else devs[0], tag="compute")
|
||||
|
||||
def amd_submit_pm4(cmdbuf, devs):
|
||||
def pm4_submit(cmdbuf, devs):
|
||||
size, zero = UOp.const(dtypes.uint32, cmdbuf.src[0].arg // dtypes.uint32.itemsize), UOp.const(dtypes.int, 0)
|
||||
|
||||
# the compute queue's ring and its host-side ring/write/put pointers (placeholders, resolved in pm_bufferize)
|
||||
|
|
@ -172,49 +165,45 @@ def amd_submit_pm4(cmdbuf, devs):
|
|||
flush = UOp.barrier(copy_to_ring, bump_put_ptr, bump_wptr)
|
||||
return doorbell.after(flush).index(zero, dtype=doorbell.dtype.ptr()).store(next_put)
|
||||
|
||||
class AMDCopyQueue(HCQEncoder):
|
||||
def __init__(self, dev:AMDDevice, queue_idx=0):
|
||||
super().__init__()
|
||||
self.dev = dev
|
||||
self.sdma, self.queue_idx, self.max_copy_size = dev.sdma, queue_idx, dev.max_copy_size
|
||||
pm_pm4_submit = PatternMatcher([(UPat(Ops.LINEAR, name="lin"),
|
||||
lambda lin: pm4_submit(make_cmdbuf(lin, to_tuple(lin.arg[0]), "compute"), to_tuple(lin.arg[0])))])
|
||||
|
||||
def copy(self, x):
|
||||
dest, src, copy_size = self.get_dev_addr(x.src[0]), self.get_dev_addr(x.src[1]), x.arg
|
||||
copied = 0
|
||||
while copied < copy_size:
|
||||
step = min(copy_size - copied, self.max_copy_size)
|
||||
self.q(self.sdma.SDMA_OP_COPY | self.sdma.SDMA_PKT_COPY_LINEAR_HEADER_SUB_OP(self.sdma.SDMA_SUBOP_COPY_LINEAR),
|
||||
self.sdma.SDMA_PKT_COPY_LINEAR_COUNT_COUNT(step - 1), 0, *data64_le(src + copied), *data64_le(dest + copied))
|
||||
copied += step
|
||||
# *****************
|
||||
# SDMA
|
||||
|
||||
def wait(self, x):
|
||||
self.q(self.sdma.SDMA_OP_POLL_REGMEM | self.sdma.SDMA_PKT_POLL_REGMEM_HEADER_FUNC(WAIT_REG_MEM_FUNCTION_GEQ) | \
|
||||
self.sdma.SDMA_PKT_POLL_REGMEM_HEADER_MEM_POLL(1), *data64_le(self.get_dev_addr(x.src[0])), x.src[1], 0xffffffff,
|
||||
self.sdma.SDMA_PKT_POLL_REGMEM_DW5_INTERVAL(0x04) | self.sdma.SDMA_PKT_POLL_REGMEM_DW5_RETRY_COUNT(0xfff))
|
||||
class SDMAOps(FastEnum): COPY = auto(); POLL_REGMEM = auto(); FENCE = auto(); TRAP = auto(); TIMESTAMP = auto() # noqa: E702
|
||||
|
||||
def store(self, x):
|
||||
fence_flags = self.sdma.SDMA_PKT_FENCE_HEADER_MTYPE(3) if self.dev.target[0] != 9 else 0
|
||||
self.q(self.sdma.SDMA_OP_FENCE | fence_flags, *data64_le(self.get_dev_addr(x.src[0])), x.src[1])
|
||||
self.q(self.sdma.SDMA_OP_TRAP, 0)
|
||||
def sdma_copy(ctx, dst, src, copy):
|
||||
src_addr, dst_addr = make_getaddr(src, ctx.device), make_getaddr(dst, ctx.device)
|
||||
return UOp(Ops.LINEAR, dtypes.void, tuple([make_ins(SDMAOps.COPY,
|
||||
ctx.sdma.SDMA_OP_COPY | ctx.sdma.SDMA_PKT_COPY_LINEAR_HEADER_SUB_OP(ctx.sdma.SDMA_SUBOP_COPY_LINEAR),
|
||||
ctx.sdma.SDMA_PKT_COPY_LINEAR_COUNT_COUNT(min(copy.arg - off, ctx.max_copy_size) - 1), 0,
|
||||
*data64_le(src_addr + off), *data64_le(dst_addr + off)) for off in range(0, copy.arg, ctx.max_copy_size)]))
|
||||
|
||||
def timestamp(self, x):
|
||||
self.q(self.sdma.SDMA_OP_TIMESTAMP | self.sdma.SDMA_PKT_TIMESTAMP_GET_HEADER_SUB_OP(self.sdma.SDMA_SUBOP_TIMESTAMP_GET_GLOBAL),
|
||||
*data64_le(self.get_dev_addr(x.src[0])))
|
||||
def sdma_wait(ctx, dst, val):
|
||||
op = ctx.sdma.SDMA_OP_POLL_REGMEM | ctx.sdma.SDMA_PKT_POLL_REGMEM_HEADER_FUNC(WAIT_REG_MEM_FUNCTION_GEQ) \
|
||||
| ctx.sdma.SDMA_PKT_POLL_REGMEM_HEADER_MEM_POLL(1)
|
||||
return make_ins(SDMAOps.POLL_REGMEM, op, *data64_le(make_getaddr(dst, ctx.device)), val, 0xffffffff,
|
||||
ctx.sdma.SDMA_PKT_POLL_REGMEM_DW5_INTERVAL(0x04) | ctx.sdma.SDMA_PKT_POLL_REGMEM_DW5_RETRY_COUNT(0xfff))
|
||||
|
||||
def amd_lower_sdma(linear, devs):
|
||||
enc = AMDCopyQueue(Device[devs[0]])
|
||||
graph_rewrite(linear.replace(src=tuple(UOp(Ops.LINEAR, dtypes.void, (cmd,)) for cmd in linear.src)), amd_inner_sdma_pm, ctx=enc, name="amd: encode sdma")
|
||||
return enc.uop(dev=devs if len(devs) > 1 else devs[0], tag="copy")
|
||||
def sdma_store(ctx, dst, val):
|
||||
op = ctx.sdma.SDMA_OP_FENCE | (ctx.sdma.SDMA_PKT_FENCE_HEADER_MTYPE(3) if ctx.target[0] != 9 else 0)
|
||||
return UOp(Ops.LINEAR, dtypes.void, (
|
||||
make_ins(SDMAOps.FENCE, op, *data64_le(make_getaddr(dst, ctx.device)), val), make_ins(SDMAOps.TRAP, ctx.sdma.SDMA_OP_TRAP, 0)))
|
||||
|
||||
amd_inner_sdma_pm = PatternMatcher([
|
||||
(UPat(Ops.LINEAR, src=(UPat(Ops.WAIT, name="x"),)), lambda ctx, x: ctx.wait(x)),
|
||||
(UPat(Ops.LINEAR, src=(UPat(Ops.BARRIER, name="x"),)), lambda ctx, x: None),
|
||||
(UPat(Ops.LINEAR, src=(UPat(Ops.COPY, name="x"),)), lambda ctx, x: ctx.copy(x)),
|
||||
(UPat(Ops.LINEAR, src=(UPat(Ops.CUSTOM_FUNCTION, arg="timestamp", name="x"),)), lambda ctx, x: ctx.timestamp(x)),
|
||||
(UPat(Ops.LINEAR, src=(UPat(Ops.STORE, src=(UPat((Ops.BUFFER, Ops.PARAM)), UPat()), name="x"),)), lambda ctx, x: ctx.store(x)),
|
||||
def sdma_timestamp(ctx, dst):
|
||||
op = ctx.sdma.SDMA_OP_TIMESTAMP | ctx.sdma.SDMA_PKT_TIMESTAMP_GET_HEADER_SUB_OP(ctx.sdma.SDMA_SUBOP_TIMESTAMP_GET_GLOBAL)
|
||||
return make_ins(SDMAOps.TIMESTAMP, op, *data64_le(make_getaddr(dst, ctx.device)))
|
||||
|
||||
pm_sdma_opsel = PatternMatcher([
|
||||
(UPat(Ops.BARRIER), lambda: UOp(Ops.NOOP, dtypes.void, ())),
|
||||
(UPat(Ops.WAIT, src=(UPat(name="dst"), UPat(name="val"))), sdma_wait),
|
||||
(UPat(Ops.COPY, src=(UPat(name="dst"), UPat(name="src")), name="copy"), sdma_copy),
|
||||
(UPat(Ops.CUSTOM_FUNCTION, arg="timestamp", src=(UPat(name="dst"),)), sdma_timestamp),
|
||||
(UPat(Ops.STORE, src=(UPat((Ops.BUFFER, Ops.PARAM), name="dst"), UPat(name="val"))), sdma_store),
|
||||
])
|
||||
|
||||
def amd_submit_sdma(cmdbuf, devs):
|
||||
def sdma_submit(cmdbuf, devs):
|
||||
# the cmdbuf to submit + the patch writes that fill it
|
||||
size_dw, zero = cmdbuf.src[0].arg // dtypes.uint32.itemsize, UOp.const(dtypes.int, 0)
|
||||
|
||||
|
|
@ -246,6 +235,9 @@ def amd_submit_sdma(cmdbuf, devs):
|
|||
flush = UOp.barrier(zero_tail, copy_to_ring, bump_put_ptr, bump_wptr)
|
||||
return doorbell.after(flush).index(zero, dtype=doorbell.dtype.ptr()).store(next_put_b)
|
||||
|
||||
pm_sdma_submit = PatternMatcher([(UPat(Ops.LINEAR, name="lin"),
|
||||
lambda lin: sdma_submit(make_cmdbuf(lin, to_tuple(lin.arg[0]), "copy"), to_tuple(lin.arg[0])))])
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AMDProgramData:
|
||||
entry_point_offset:int; rsrc1:int; rsrc2:int; rsrc3:int; wave32:bool
|
||||
|
|
@ -527,7 +519,8 @@ def _mock(iface, name=None): return type(name or f"MOCK{iface.__name__}", (iface
|
|||
def encode_queue(q:UOp) -> UOp|None:
|
||||
if not (isinstance(q.arg, tuple) and len(q.arg) == 2 and isinstance(q.arg[1], str) and q.arg[1].startswith(("COMPUTE", "COPY"))): return None
|
||||
devs = to_tuple(q.arg[0])
|
||||
return amd_submit_pm4(amd_lower_pm4(q, devs), devs) if q.arg[1].startswith("COMPUTE") else amd_submit_sdma(amd_lower_sdma(q, devs), devs)
|
||||
opsel, submit = (pm_pm4_opsel, pm_pm4_submit) if q.arg[1].startswith("COMPUTE") else (pm_sdma_opsel, pm_sdma_submit)
|
||||
return submit.rewrite(graph_rewrite(q, opsel + pm_flatten_linear, walk=True, ctx=Device[devs[0]], name=f"{q.arg[1]} opsel"))
|
||||
|
||||
pm_lower = PatternMatcher([
|
||||
(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),)), encode_queue),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue