no predecode

This commit is contained in:
George Hotz 2026-02-13 03:26:12 +00:00
commit d89cb880b2
2 changed files with 24 additions and 34 deletions

View file

@ -709,7 +709,7 @@ jobs:
- name: Run RDNA4 emulator tests
run: MOCKGPU_ARCH=rdna4 python -m pytest test/test_tiny.py -v --durations 20
- name: Run CDNA4 emulator tests
run: MOCKGPU_ARCH=cdna4 python -m pytest test/test_tiny.py -v --durations 20
run: AMD_LLVM=1 MOCKGPU_ARCH=cdna4 python -m pytest test/test_tiny.py -v --durations 20
testnvidia:
strategy:

View file

@ -1161,7 +1161,7 @@ def _get_runner(inst_bytes: bytes, arch: str = "rdna3"):
# Check if instruction matches any cached canonical pattern
for base, mask, size, runner in _canonical_runner_cache:
if inst_size == size and (inst_int & mask) == base: return runner, False
if inst_size == size and (inst_int & mask) == base: return runner
# Look up handler by type, falling back to base classes for _LIT variants
handler = _INST_HANDLERS.get(type(inst))
@ -1181,30 +1181,19 @@ def _get_runner(inst_bytes: bytes, arch: str = "rdna3"):
with Context(NOOPT=1, CHECK_OOB=0, TUPLE_ORDER=0, EMULATED_DTYPES=""):
runner = get_runner('CPU', sink)
_canonical_runner_cache.append((base, mask, size, runner))
return runner, True
return runner
@functools.cache
def decode_program(data: bytes, arch: str = "rdna3") -> dict[int, tuple[str, Callable, list[int], Any]]:
  """Decode and compile every instruction in `data`, keyed by byte offset.

  Walks `data` from offset 0, decoding one instruction at a time with
  `decode_inst` and obtaining a runner for it from `_get_runner`. Each entry
  of the returned dict is
  (runner.p.function_name, runner._prg.fxn, runner.p.globals, runner).
  The walk ends when `data` is exhausted or when a decoded instruction's `op`
  is S_CODE_END in either the ir3 or ir4 SOPP enum. Results are memoized per
  (data, arch) by `functools.cache`.

  Raises:
    RuntimeError: chained from whatever exception the compile step (or the
      DEBUG logging around it) raised, with the offset and instruction text.
  """
  table: dict[int, tuple[str, Callable, list[int], Any]] = {}
  offset = 0
  while offset < len(data):
    inst = decode_inst(data[offset:], arch)
    # An S_CODE_END marker terminates the program; nothing past it is decoded.
    if hasattr(inst, 'op'):
      if inst.op in (ir3.SOPPOp.S_CODE_END, ir4.SOPPOp.S_CODE_END):
        break
    try:
      # The window handed to _get_runner is 4 bytes longer than the instruction
      # itself (presumably room for a trailing literal -- confirm in _get_runner).
      window = bytes(data[offset:offset + inst.size() + 4])
      runner, is_new = _get_runner(window, arch)
      if DEBUG >= 3:
        try:
          shown = repr(inst)
        except Exception:
          shown = f"<{type(inst).__name__} at PC={offset}>"
        line = f"[emu] PC={offset}: {shown}"
        # Runners reported as new by _get_runner are highlighted in green.
        print(colored(line, 'green') if is_new else line)
      table[offset] = (runner.p.function_name, runner._prg.fxn, runner.p.globals, runner)
    except Exception as e:
      # repr() of a decoded instruction can itself fail; fall back to the type name.
      try:
        shown = repr(inst)
      except Exception:
        shown = f"<{type(inst).__name__}>"
      raise RuntimeError(f"[emu] Failed to compile PC={offset} {shown}: {type(e).__name__}: {e}") from e
    offset += inst.size()
  return table
def _decode_at(pc: int, arch: str) -> tuple[Callable, list[int]]:
  """Decode the instruction stored at absolute address `pc` and compile it.

  Reads a fixed 16-byte window directly from process memory at `pc`, decodes
  one instruction from it with `decode_inst`, and asks `_get_runner` for a
  runner for that instruction.

  Returns:
    (runner._prg.fxn, runner.p.globals) for the instruction.

  Raises:
    RuntimeError: chained from any exception raised by `_get_runner`, with the
      instruction text (or its type name if repr() fails) in the message.
  """
  window_type = ctypes.c_char * 16
  raw = bytes(window_type.from_address(pc).raw)
  inst = decode_inst(raw, arch)
  try:
    # Only the instruction plus 4 following bytes are forwarded to _get_runner.
    runner = _get_runner(bytes(raw[:inst.size() + 4]), arch)
  except Exception as e:
    # repr() of a decoded instruction can itself fail; fall back to the type name.
    try:
      shown = repr(inst)
    except Exception:
      shown = f"<{type(inst).__name__}>"
    raise RuntimeError(f"[emu] Failed to compile {shown}: {type(e).__name__}: {e}") from e
  return runner._prg.fxn, runner.p.globals
# ═══════════════════════════════════════════════════════════════════════════════
# WAVE STATE
@ -1253,8 +1242,7 @@ class WaveState:
def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int = 0x19c,
scratch_size: int = 0, arch: str = "rdna3", user_data: list[int]|None = None) -> int:
"""Execute AMD assembly program. scratch_size is private_segment_fixed_size from kernel descriptor (per-lane)."""
program_raw = decode_program(bytes((ctypes.c_char * lib_sz).from_address(lib).raw), arch)
program = {lib + offset: val for offset, val in program_raw.items()} # Remap to actual addresses
program: dict[int, tuple[Callable, list[int]]] = {} # lazily populated: pc -> (fxn, globals)
lds_size = ((rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE_SHIFT) * 512
total_threads = lx * ly * lz
@ -1304,13 +1292,15 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int,
ctypes.c_uint64(vmem_buf._buf.va_addr), ctypes.c_uint64(lds_buf._buf.va_addr),
ctypes.c_uint64(scratch_buf._buf.va_addr if scratch_buf else 0)]
for inst_count in range(1_000_000):
if (pc := st.pc) == 0xFFFFFFFFFFFFFFFF or pc not in program: break
name, fxn, globals_list, _ = program[pc]
assert fxn is not None, f"[emu] No fxn for {name} at PC={pc}"
assert 4 not in globals_list or scratch_buf, f"SCRATCH instruction {name} but scratch_size=0"
if DEBUG >= 6:
inst = decode_inst(bytes((ctypes.c_char * 12).from_address(pc).raw), arch)
print(f"[emu] exec PC={pc:X}: {inst!r}")
if (pc := st.pc) == 0xFFFFFFFFFFFFFFFF: break
if pc not in program:
prev_len = len(_canonical_runner_cache)
program[pc] = _decode_at(pc, arch)
if DEBUG >= 3:
inst = decode_inst(bytes((ctypes.c_char * 16).from_address(pc).raw), arch)
msg = f"[emu] PC={pc - lib}: {inst!r}"
print(colored(msg, 'green') if len(_canonical_runner_cache) > prev_len else msg)
fxn, globals_list = program[pc]
fxn(*[c_bufs[g] for g in globals_list])
else: raise RuntimeError("exceeded 1M instructions, likely infinite loop")
return 0