mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
no predecode
This commit is contained in:
parent
f2894675c0
commit
d89cb880b2
2 changed files with 24 additions and 34 deletions
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
|
|
@ -709,7 +709,7 @@ jobs:
|
|||
- name: Run RDNA4 emulator tests
|
||||
run: MOCKGPU_ARCH=rdna4 python -m pytest test/test_tiny.py -v --durations 20
|
||||
- name: Run CDNA4 emulator tests
|
||||
run: MOCKGPU_ARCH=cdna4 python -m pytest test/test_tiny.py -v --durations 20
|
||||
run: AMD_LLVM=1 MOCKGPU_ARCH=cdna4 python -m pytest test/test_tiny.py -v --durations 20
|
||||
|
||||
testnvidia:
|
||||
strategy:
|
||||
|
|
|
|||
|
|
@ -1161,7 +1161,7 @@ def _get_runner(inst_bytes: bytes, arch: str = "rdna3"):
|
|||
|
||||
# Check if instruction matches any cached canonical pattern
|
||||
for base, mask, size, runner in _canonical_runner_cache:
|
||||
if inst_size == size and (inst_int & mask) == base: return runner, False
|
||||
if inst_size == size and (inst_int & mask) == base: return runner
|
||||
|
||||
# Look up handler by type, falling back to base classes for _LIT variants
|
||||
handler = _INST_HANDLERS.get(type(inst))
|
||||
|
|
@ -1181,30 +1181,19 @@ def _get_runner(inst_bytes: bytes, arch: str = "rdna3"):
|
|||
with Context(NOOPT=1, CHECK_OOB=0, TUPLE_ORDER=0, EMULATED_DTYPES=""):
|
||||
runner = get_runner('CPU', sink)
|
||||
_canonical_runner_cache.append((base, mask, size, runner))
|
||||
return runner, True
|
||||
return runner
|
||||
|
||||
@functools.cache
|
||||
def decode_program(data: bytes, arch: str = "rdna3") -> dict[int, tuple[str, Callable, list[int], Any]]:
|
||||
"""Decode program to {pc: (name, fxn, globals, runner)}."""
|
||||
result: dict[int, tuple[str, Callable, list[int], Any]] = {}
|
||||
i = 0
|
||||
while i < len(data):
|
||||
inst = decode_inst(data[i:], arch)
|
||||
if hasattr(inst, 'op') and inst.op in (ir3.SOPPOp.S_CODE_END, ir4.SOPPOp.S_CODE_END): break
|
||||
try:
|
||||
runner, is_new = _get_runner(bytes(data[i:i + inst.size() + 4]), arch)
|
||||
if DEBUG >= 3:
|
||||
try: inst_str = repr(inst)
|
||||
except Exception: inst_str = f"<{type(inst).__name__} at PC={i}>"
|
||||
msg = f"[emu] PC={i}: {inst_str}"
|
||||
print(colored(msg, 'green') if is_new else msg)
|
||||
result[i] = (runner.p.function_name, runner._prg.fxn, runner.p.globals, runner)
|
||||
except Exception as e:
|
||||
try: inst_str = repr(inst)
|
||||
except Exception: inst_str = f"<{type(inst).__name__}>"
|
||||
raise RuntimeError(f"[emu] Failed to compile PC={i} {inst_str}: {type(e).__name__}: {e}") from e
|
||||
i += inst.size()
|
||||
return result
|
||||
def _decode_at(pc: int, arch: str) -> tuple[Callable, list[int]]:
|
||||
"""Decode and compile instruction at absolute address pc. Returns (fxn, globals)."""
|
||||
inst_bytes = bytes((ctypes.c_char * 16).from_address(pc).raw)
|
||||
inst = decode_inst(inst_bytes, arch)
|
||||
try:
|
||||
runner = _get_runner(bytes(inst_bytes[:inst.size() + 4]), arch)
|
||||
except Exception as e:
|
||||
try: inst_str = repr(inst)
|
||||
except Exception: inst_str = f"<{type(inst).__name__}>"
|
||||
raise RuntimeError(f"[emu] Failed to compile {inst_str}: {type(e).__name__}: {e}") from e
|
||||
return runner._prg.fxn, runner.p.globals
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# WAVE STATE
|
||||
|
|
@ -1253,8 +1242,7 @@ class WaveState:
|
|||
def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int = 0x19c,
|
||||
scratch_size: int = 0, arch: str = "rdna3", user_data: list[int]|None = None) -> int:
|
||||
"""Execute AMD assembly program. scratch_size is private_segment_fixed_size from kernel descriptor (per-lane)."""
|
||||
program_raw = decode_program(bytes((ctypes.c_char * lib_sz).from_address(lib).raw), arch)
|
||||
program = {lib + offset: val for offset, val in program_raw.items()} # Remap to actual addresses
|
||||
program: dict[int, tuple[Callable, list[int]]] = {} # lazily populated: pc -> (fxn, globals)
|
||||
lds_size = ((rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE_SHIFT) * 512
|
||||
total_threads = lx * ly * lz
|
||||
|
||||
|
|
@ -1304,13 +1292,15 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int,
|
|||
ctypes.c_uint64(vmem_buf._buf.va_addr), ctypes.c_uint64(lds_buf._buf.va_addr),
|
||||
ctypes.c_uint64(scratch_buf._buf.va_addr if scratch_buf else 0)]
|
||||
for inst_count in range(1_000_000):
|
||||
if (pc := st.pc) == 0xFFFFFFFFFFFFFFFF or pc not in program: break
|
||||
name, fxn, globals_list, _ = program[pc]
|
||||
assert fxn is not None, f"[emu] No fxn for {name} at PC={pc}"
|
||||
assert 4 not in globals_list or scratch_buf, f"SCRATCH instruction {name} but scratch_size=0"
|
||||
if DEBUG >= 6:
|
||||
inst = decode_inst(bytes((ctypes.c_char * 12).from_address(pc).raw), arch)
|
||||
print(f"[emu] exec PC={pc:X}: {inst!r}")
|
||||
if (pc := st.pc) == 0xFFFFFFFFFFFFFFFF: break
|
||||
if pc not in program:
|
||||
prev_len = len(_canonical_runner_cache)
|
||||
program[pc] = _decode_at(pc, arch)
|
||||
if DEBUG >= 3:
|
||||
inst = decode_inst(bytes((ctypes.c_char * 16).from_address(pc).raw), arch)
|
||||
msg = f"[emu] PC={pc - lib}: {inst!r}"
|
||||
print(colored(msg, 'green') if len(_canonical_runner_cache) > prev_len else msg)
|
||||
fxn, globals_list = program[pc]
|
||||
fxn(*[c_bufs[g] for g in globals_list])
|
||||
else: raise RuntimeError("exceeded 1M instructions, likely infinite loop")
|
||||
return 0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue