assembly/amd: mypy+ruff passes (#14701)

* assembly/amd: mypy+ruff passes

* touchups
This commit is contained in:
George Hotz 2026-02-12 16:59:42 +08:00 committed by GitHub
commit d5fc3ea1ba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 593 additions and 377 deletions

View file

@ -20,13 +20,13 @@ test_llvm.py tests asm/disasm on the LLVM tests, confirming it behaves the same
tinygrad's dtype tests should pass with and without LLVM. they run in about 12 seconds.
`PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=0 pytest -n=12 test/test_dtype_alu.py test/test_dtype.py`
`PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=1 pytest -n=12 test/test_dtype_alu.py test/test_dtype.py`
`PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=0 pytest -n=12 test/backend/test_dtype_alu.py test/backend/test_dtype.py`
`PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=1 pytest -n=12 test/backend/test_dtype_alu.py test/backend/test_dtype.py`
The ops tests also pass, but they are very slow, so you should run them one at a time.
`SKIP_SLOW_TEST=1 PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=0 pytest -n=12 test/test_ops.py`
`SKIP_SLOW_TEST=1 PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=1 pytest -n=12 test/test_ops.py`
`SKIP_SLOW_TEST=1 PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=0 pytest -n=12 test/backend/test_ops.py`
`SKIP_SLOW_TEST=1 PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=1 pytest -n=12 test/backend/test_ops.py`
When something is caught by main tinygrad tests, a local regression test should be added to `extra/assembly/amd/test`.
While working with tinygrad, you can dump the assembly with `DEBUG=7`. These tests all pass on real hardware
@ -34,6 +34,6 @@ If a test is failing with `AMD=1 PYTHON_REMU=1 MOCKGPU=1` it's because an instru
You can test without `MOCKGPU=1` to test on real hardware, if it works on real hardware there's a bug in the emulator.
IMPORTANT: if a test is failing in the emulator, it's an instruction bug. Use DEBUG=7, get the instructions, and debug.
Currently, only RDNA3 is well supported, but when finished, this will support RDNA3+RDNA4+CDNA in ~2000 lines.
Currently, only RDNA3 is well supported, but when finished, this will support RDNA3+RDNA4+CDNA in ~3000 lines.
Get line count with `cloc --by-file extra/assembly/amd/*.py`

View file

@ -1,7 +1,7 @@
# autogenerated from AMD ISA XML - do not edit
# ruff: noqa: F401,F403
from extra.assembly.amd.dsl import *
from extra.assembly.amd.autogen.cdna.enum import *
# ruff: noqa: E501,F401
from extra.assembly.amd.dsl import BitField, DPP, DPP16, EXEC, EXECZ, EXEC_HI, EXEC_LO, EnumBitField, FixedBitField, INV_2PI, Inst, LIT, M0, NULL, OFF, SBaseField, SCC, SDWA, SGPRField, SRC_LDS_DIRECT, SRsrcField, SSrcField, SrcField, VCC, VCCZ, VCC_HI, VCC_LO, VGPRField, s, src, ttmp, v
from extra.assembly.amd.autogen.cdna.enum import DSOp, FLATOp, GLOBALOp, MTBUFOp, MUBUFOp, SCRATCHOp, SMEMOp, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3POp, VOP3PX2Op, VOP3SDOp, VOPCOp, HWREG
import functools
class DS(Inst):

View file

@ -1,6 +1,6 @@
# autogenerated from AMD ISA XML - do not edit
from extra.assembly.amd.autogen.common import Fmt, OpType
from extra.assembly.amd.autogen.cdna.enum import *
from extra.assembly.amd.autogen.cdna.enum import DSOp, FLATOp, GLOBALOp, MTBUFOp, MUBUFOp, SCRATCHOp, SMEMOp, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3POp, VOP3PX2Op, VOP3SDOp, VOPCOp
# instruction operand info: {Op: {field: (Fmt, size_bits, OpType)}}
OPERANDS = {

View file

@ -1,7 +1,7 @@
# autogenerated from AMD ISA XML - do not edit
# ruff: noqa: F401,F403
from extra.assembly.amd.dsl import *
from extra.assembly.amd.autogen.rdna3.enum import *
# ruff: noqa: E501,F401
from extra.assembly.amd.dsl import BitField, DPP, DPP16, EXEC, EXECZ, EXEC_HI, EXEC_LO, EnumBitField, FixedBitField, INV_2PI, Inst, LIT, M0, NULL, OFF, SBaseField, SCC, SDWA, SGPRField, SRC_LDS_DIRECT, SRsrcField, SSrcField, SrcField, VCC, VCCZ, VCC_HI, VCC_LO, VDSTYField, VGPRField, s, src, ttmp, v
from extra.assembly.amd.autogen.rdna3.enum import DSOp, EXPOp, FLATOp, GLOBALOp, LDSDIROp, MIMGOp, MTBUFOp, MUBUFOp, SCRATCHOp, SMEMOp, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VINTERPOp, VOP1Op, VOP2Op, VOP3Op, VOP3POp, VOP3SDOp, VOPCOp, VOPDOp, HWREG, MSG
import functools
class DS(Inst):
@ -593,9 +593,6 @@ flat_load_d16_hi_i8 = functools.partial(FLAT, FLATOp.FLAT_LOAD_D16_HI_I8)
flat_load_d16_hi_b16 = functools.partial(FLAT, FLATOp.FLAT_LOAD_D16_HI_B16)
flat_store_d16_hi_b8 = functools.partial(FLAT, FLATOp.FLAT_STORE_D16_HI_B8)
flat_store_d16_hi_b16 = functools.partial(FLAT, FLATOp.FLAT_STORE_D16_HI_B16)
global_load_addtid_b32 = functools.partial(FLAT, FLATOp.GLOBAL_LOAD_ADDTID_B32)
global_store_addtid_b32 = functools.partial(FLAT, FLATOp.GLOBAL_STORE_ADDTID_B32)
global_load_lds_addtid_b32 = functools.partial(FLAT, FLATOp.GLOBAL_LOAD_LDS_ADDTID_B32)
flat_atomic_swap_b32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_SWAP_B32)
flat_atomic_cmpswap_b32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_CMPSWAP_B32)
flat_atomic_add_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_ADD_U32)

View file

@ -1,6 +1,6 @@
# autogenerated from AMD ISA XML - do not edit
from extra.assembly.amd.autogen.common import Fmt, OpType
from extra.assembly.amd.autogen.rdna3.enum import *
from extra.assembly.amd.autogen.rdna3.enum import DSOp, EXPOp, FLATOp, GLOBALOp, LDSDIROp, MIMGOp, MTBUFOp, MUBUFOp, SCRATCHOp, SMEMOp, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VINTERPOp, VOP1Op, VOP2Op, VOP3Op, VOP3POp, VOP3SDOp, VOPCOp, VOPDOp
# instruction operand info: {Op: {field: (Fmt, size_bits, OpType)}}
OPERANDS = {

View file

@ -1,7 +1,7 @@
# autogenerated from AMD ISA XML - do not edit
# ruff: noqa: F401,F403
from extra.assembly.amd.dsl import *
from extra.assembly.amd.autogen.rdna4.enum import *
# ruff: noqa: E501,F401
from extra.assembly.amd.dsl import BitField, DPP, DPP16, EXEC, EXECZ, EXEC_HI, EXEC_LO, EnumBitField, FixedBitField, INV_2PI, Inst, LIT, M0, NULL, OFF, SBaseField, SCC, SDWA, SGPRField, SRC_LDS_DIRECT, SSrcField, SrcField, VCC, VCCZ, VCC_HI, VCC_LO, VDSTYField, VGPRField, s, src, ttmp, v
from extra.assembly.amd.autogen.rdna4.enum import DSOp, SMEMOp, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VBUFFEROp, VDSDIROp, VEXPORTOp, VFLATOp, VGLOBALOp, VIMAGEOp, VINTERPOp, VOP1Op, VOP2Op, VOP3Op, VOP3POp, VOP3SDOp, VOPCOp, VOPDOp, VSAMPLEOp, VSCRATCHOp, HWREG, MSG
import functools
class DS(Inst):
@ -973,8 +973,6 @@ flat_load_d16_hi_i8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_HI_I8)
flat_load_d16_hi_b16 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_HI_B16)
flat_store_d16_hi_b8 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_D16_HI_B8)
flat_store_d16_hi_b16 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_D16_HI_B16)
global_load_addtid_b32 = functools.partial(VFLAT, VFLATOp.GLOBAL_LOAD_ADDTID_B32)
global_store_addtid_b32 = functools.partial(VFLAT, VFLATOp.GLOBAL_STORE_ADDTID_B32)
flat_atomic_swap_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_SWAP_B32)
flat_atomic_cmpswap_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_CMPSWAP_B32)
flat_atomic_add_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_ADD_U32)

View file

@ -1,6 +1,6 @@
# autogenerated from AMD ISA XML - do not edit
from extra.assembly.amd.autogen.common import Fmt, OpType
from extra.assembly.amd.autogen.rdna4.enum import *
from extra.assembly.amd.autogen.rdna4.enum import DSOp, SMEMOp, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VBUFFEROp, VDSDIROp, VEXPORTOp, VFLATOp, VGLOBALOp, VIMAGEOp, VINTERPOp, VOP1Op, VOP2Op, VOP3Op, VOP3POp, VOP3SDOp, VOPCOp, VOPDOp, VSAMPLEOp, VSCRATCHOp
# instruction operand info: {Op: {field: (Fmt, size_bits, OpType)}}
OPERANDS = {

View file

@ -44,11 +44,15 @@ class Reg:
def fmt(self, sz=None, parens=False, upper=False) -> str:
o, sz = self.offset, sz or self.sz
l, r = ("[", "]") if parens or sz > 1 else ("", "") # brackets for multi-reg or when parens=True
if 256 <= o < 512: idx = o - 256; base = f"v{l}{idx}{r}" if sz == 1 else f"v[{idx}:{idx + sz - 1}]"
if 256 <= o < 512:
idx = o - 256
base = f"v{l}{idx}{r}" if sz == 1 else f"v[{idx}:{idx + sz - 1}]"
elif o < 106: base = f"s{l}{o}{r}" if sz == 1 else f"s[{o}:{o + sz - 1}]"
elif sz == 2 and o in self._PAIRS: base = self._PAIRS[o] if upper else self._PAIRS[o].lower()
elif o in self._NAMES: base = self._NAMES[o] if upper else self._NAMES[o].lower() # special regs (any sz)
elif 108 <= o < 124: idx = o - 108; base = f"ttmp{l}{idx}{r}" if sz == 1 else f"ttmp[{idx}:{idx + sz - 1}]"
elif 108 <= o < 124:
idx = o - 108
base = f"ttmp{l}{idx}{r}" if sz == 1 else f"ttmp[{idx}:{idx + sz - 1}]"
elif 128 <= o <= 192: base = str(o - 128) # inline int constants (0-64)
elif 193 <= o <= 208: base = str(-(o - 192)) # inline negative int constants (-1 to -16)
else: raise RuntimeError(f"unknown register: offset={o}, sz={sz}")
@ -151,7 +155,8 @@ class SrcField(BitField):
expected_size = self._valid_range[1] - self._valid_range[0] + 1
actual_size = 1 << (hi - lo + 1)
if actual_size != expected_size:
raise RuntimeError(f"{self.__class__.__name__}: field size {hi - lo + 1} bits ({actual_size}) doesn't match range {self._valid_range} ({expected_size})")
raise RuntimeError(f"{self.__class__.__name__}: field size {hi - lo + 1} bits ({actual_size}) "
f"doesn't match range {self._valid_range} ({expected_size})")
def encode(self, val) -> int:
"""Encode value. Returns 255 (literal marker) for out-of-range values."""
@ -271,7 +276,7 @@ class Inst:
inherited = {}
for base in reversed(cls.__mro__[1:]):
if hasattr(base, '_fields'):
inherited.update({name: field for name, field in base._fields})
inherited.update(dict(base._fields))
inherited.update({name: val for name, val in cls.__dict__.items() if isinstance(val, BitField)})
cls._fields = list(inherited.items())
cls._base_size = (max(f.hi for _, f in cls._fields) + 8) // 8

View file

@ -70,7 +70,8 @@ def _split64(val: UOp) -> tuple[UOp, UOp]:
v64 = val.bitcast(dtypes.uint64) if val.dtype == dtypes.float64 else val.cast(dtypes.uint64) if val.dtype != dtypes.uint64 else val
return v64.cast(dtypes.uint32), (v64 >> UOp.const(dtypes.uint64, 32)).cast(dtypes.uint32)
_SRC_MOD_TYPES = {16: (dtypes.uint16, dtypes.half, 0x7FFF), 64: (dtypes.uint64, dtypes.float64, 0x7FFFFFFFFFFFFFFF), 32: (dtypes.uint32, dtypes.float32, 0x7FFFFFFF)}
_SRC_MOD_TYPES = {16: (dtypes.uint16, dtypes.half, 0x7FFF), 32: (dtypes.uint32, dtypes.float32, 0x7FFFFFFF),
64: (dtypes.uint64, dtypes.float64, 0x7FFFFFFFFFFFFFFF)}
def _apply_src_mods(val: UOp, mod_bit: int, abs_bits: int, neg_bits: int, bits: int = 32) -> UOp:
"""Apply abs/neg modifiers to source value based on bit width (16, 32, or 64)."""
if not (abs_bits & (1 << mod_bit)) and not (neg_bits & (1 << mod_bit)): return val
@ -163,21 +164,30 @@ def get_pcode(op) -> str:
(f'1.0 / S1.{dt} == DENORM.{dt}', '0'), (f'S1.{dt} == DENORM.{dt}', f'isDENORM(S1.{dt})'),
(f'D0.{dt} = NAN.{dt}', f'VCC = 0x1LL;\nD0.{dt} = NAN.{dt}'),
(f'elsif isDENORM(S1.{dt}) then\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})', f'elsif 1 == 0 then\nD0.{dt} = S0.{dt}'),
(f'elsif exponent(S2.{dt}) <= {exp_lim} then\n// Numerator is tiny\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})',
f'elsif exponent(S2.{dt}) <= {exp_lim} then\nVCC = 0x1LL;\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})'),
(f'elsif divWouldBeDenorm(S2.{dt}, S1.{dt}) then\nVCC = 0x1LL;\nif S0.{dt} == S2.{dt} then\n// Only scale the numerator\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nendif',
f'elsif divWouldBeDenorm(S2.{dt}, S1.{dt}) then\nVCC = 0x1LL;\nD0.{dt} = S0.{dt}'),
(f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nendif\nelsif', f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nelse\nD0.{dt} = S0.{dt}\nendif\nelsif')]:
(f'elsif exponent(S2.{dt}) <= {exp_lim} then\n// Numerator is tiny\n'
f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})',
f'elsif exponent(S2.{dt}) <= {exp_lim} then\nVCC = 0x1LL;\n'
f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})'),
(f'elsif divWouldBeDenorm(S2.{dt}, S1.{dt}) then\nVCC = 0x1LL;\n'
f'if S0.{dt} == S2.{dt} then\n// Only scale the numerator\n'
f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nendif',
f'elsif divWouldBeDenorm(S2.{dt}, S1.{dt}) then\n'
f'VCC = 0x1LL;\nD0.{dt} = S0.{dt}'),
(f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nendif\nelsif',
f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nelse\n'
f'D0.{dt} = S0.{dt}\nendif\nelsif')]:
pcode = pcode.replace(old, new)
lines = pcode.rstrip().split('\n')
for i in range(len(lines) - 1, -1, -1):
if lines[i].strip() == 'endif': lines.insert(i, f'else\nD0.{dt} = S0.{dt}'); break
if lines[i].strip() == 'endif':
lines.insert(i, f'else\nD0.{dt} = S0.{dt}')
break
pcode = '\n'.join(lines) + f';\nif isDENORM(S1.{dt}) then\nD0.{dt} = NAN.{dt}\nendif'
pcode = pcode.replace('VCC = 0x0LL', 'VCC.u64[laneId] = 0').replace('VCC = 0x1LL', 'VCC.u64[laneId] = 1')
return pcode
def parse_pcode(pcode: str, srcs: dict[str, UOp] | None = None) -> tuple[dict, list[tuple[str, UOp]]]:
vars: dict = srcs.copy() if srcs else {}
env: dict = srcs.copy() if srcs else {}
assigns: list[tuple[str, UOp]] = []
raw_lines = [l.strip().rstrip(';') for l in pcode.split('\n') if l.strip() and not l.strip().startswith('//')]
# TODO: pcode.py should tokenize full pcode string instead of line-by-line, then this hack can be removed
@ -185,15 +195,17 @@ def parse_pcode(pcode: str, srcs: dict[str, UOp] | None = None) -> tuple[dict, l
for l in raw_lines:
if lines and lines[-1].endswith('&&'): lines[-1] = lines[-1] + ' ' + l
else: lines.append(l)
_, final, _ = parse_block(lines, 0, vars, assigns=assigns)
_, final, _ = parse_block(lines, 0, env, assigns=assigns)
sliced = set(d.split('[')[0] for d, _ in assigns if '[' in d)
for var, val in final.items():
if var in ['D0', 'SCC', 'VCC', 'EXEC', 'PC', 'RETURN_DATA', 'VDATA'] and isinstance(val, UOp):
if var in sliced and not any(re.match(rf'{var}\.\w+\s*=', l) for l in lines): continue
for l in lines:
if (m := re.match(rf'{var}\.(\w+(?:\[\w+\])?)', l)): assigns.append((f'{var}.{m.group(1)}', val)); break
if (m := re.match(rf'{var}\.(\w+(?:\[\w+\])?)', l)):
assigns.append((f'{var}.{m.group(1)}', val))
break
else: assigns.append((var, val))
return vars, assigns
return env, assigns
def _write_64bit(val: UOp, wfn, reg_or_addr, is_mem: bool, *args) -> list[UOp]:
"""Write a 64-bit value as two 32-bit writes. args passed to wfn after reg/addr and lo/hi value."""
@ -329,7 +341,8 @@ class _Ctx:
# Dynamic register access (takes UOp index instead of int)
def rsgpr_dyn(self, reg: UOp, valid: UOp | None = None) -> UOp:
"""Read SGPR with dynamic register index."""
return self.sgpr.index(reg.cast(dtypes.int), valid, ptr=True).load() if valid is not None else self.sgpr.index(reg.cast(dtypes.int), ptr=True).load()
if valid is not None: return self.sgpr.index(reg.cast(dtypes.int), valid, ptr=True).load()
return self.sgpr.index(reg.cast(dtypes.int), ptr=True).load()
def wsgpr_dyn(self, reg: UOp, val: UOp) -> UOp:
"""Write SGPR with dynamic register index. Writes to NULL (124) are discarded."""
@ -475,7 +488,8 @@ class _Ctx:
if hi_bit != 31 or lo_bit != 0:
width, slice_mask = hi_bit - lo_bit + 1, (1 << (hi_bit - lo_bit + 1)) - 1
val_bits = val.bitcast(dtypes.uint16).cast(dtypes.uint32) if val.dtype == dtypes.half else \
val.cast(dtypes.uint32) if val.dtype in (dtypes.uint16, dtypes.int16) else val.cast(dtypes.uint32) & UOp.const(dtypes.uint32, slice_mask)
val.cast(dtypes.uint32) if val.dtype in (dtypes.uint16, dtypes.int16) else \
val.cast(dtypes.uint32) & UOp.const(dtypes.uint32, slice_mask)
raw_stores.append(('vgpr_slice', (lo_bit, width, val_bits)))
continue
# For integer ops with clamp, use pre-computed saturated value; for floats, clamp to [0,1]
@ -484,7 +498,8 @@ class _Ctx:
val = val.maximum(UOp.const(val.dtype, 0.0)).minimum(UOp.const(val.dtype, 1.0))
if val.dtype in (dtypes.uint64, dtypes.int64, dtypes.float64):
lo, hi = _split64(val)
raw_stores.extend([('vgpr', self.wvgpr_dyn(vdst_reg, lane, lo, exec_mask)), ('vgpr', self.wvgpr_dyn(vdst_reg + _c(1), lane, hi, exec_mask))])
raw_stores.extend([('vgpr', self.wvgpr_dyn(vdst_reg, lane, lo, exec_mask)),
('vgpr', self.wvgpr_dyn(vdst_reg + _c(1), lane, hi, exec_mask))])
elif val.dtype in (dtypes.half, dtypes.uint16, dtypes.int16):
result, old_val = _val_to_u32(val), self.rvgpr_dyn(vdst_reg, lane)
hi_result = (old_val & UOp.const(dtypes.uint32, 0xFFFF)) | (result << UOp.const(dtypes.uint32, 16))
@ -507,7 +522,7 @@ class _Ctx:
if lane_stores: stores.append(UOp.sink(*lane_stores).end(lane))
for mask_val, reg in [(vcc_val, vcc_reg), (exec_val, EXEC_LO.offset)]:
if mask_val is None: continue
get_bit = lambda l, v=mask_val: (_to_u32(v.substitute({lane: l})) & _c(1)).cast(dtypes.uint32)
def get_bit(l, v=mask_val): return (_to_u32(v.substitute({lane: l})) & _c(1)).cast(dtypes.uint32)
stores.append(self.wsgpr_dyn(_c(reg), self.unroll_lanes(get_bit, exec_mask, apply_exec=False)))
stores.extend(scalar_stores)
return UOp.sink(*stores, *self.inc_pc())
@ -546,7 +561,7 @@ def _compile_smem(inst: ir3.SMEM | ir4.SMEM, ctx: _Ctx) -> UOp:
# Dynamic sdata field (bits 12:6) - destination SGPR
sdata_reg = ctx.inst_field(type(inst).sdata)
# RDNA4 uses 'ioffset', RDNA3 uses 'offset' - use type(inst) to get correct field
offset_field = type(inst).ioffset if hasattr(type(inst), 'ioffset') else type(inst).offset
offset_field = type(inst).ioffset if hasattr(type(inst), 'ioffset') else type(inst).offset # type: ignore[union-attr]
offset = ctx.inst_field_signed(offset_field) # signed immediate
# Dynamic soffset field - SGPR for additional offset (NULL=124 reads as 0)
soffset = ctx.inst_field(type(inst).soffset)
@ -561,7 +576,7 @@ def _compile_smem(inst: ir3.SMEM | ir4.SMEM, ctx: _Ctx) -> UOp:
def _compile_sop(inst: ir3.SOP1 | ir3.SOP2 | ir3.SOPC | ir3.SOPK | ir4.SOP1 | ir4.SOP2 | ir4.SOPC | ir4.SOPK, ctx: _Ctx) -> UOp:
bits = inst.canonical_op_bits
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None # type: ignore[union-attr]
if isinstance(inst, (ir3.SOPK, ir4.SOPK)):
sdst_off = ctx.inst_field(type(inst).sdst)
@ -598,7 +613,7 @@ def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VO
op_name = _op_name(inst)
if op_name in ('V_READFIRSTLANE_B32_E32', 'V_PERMLANE64_B32_E32'): return ctx.compile_lane_pcode(inst.op, inst)
lane, exec_mask, bits = ctx.range(), ctx.rsgpr_dyn(_c(EXEC_LO.offset)), inst.canonical_op_bits
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None # type: ignore[union-attr]
vdst_reg = ctx.inst_field(type(inst).vdst)
write_hi_half = bits['d'] == 16 and (vdst_reg >= _c(128))
if isinstance(write_hi_half, UOp): vdst_reg = write_hi_half.where(vdst_reg - _c(128), vdst_reg)
@ -641,7 +656,7 @@ def _compile_vopc(inst: ir3.VOPC | ir3.VOP3 | ir4.VOPC | ir4.VOP3, ctx: _Ctx, op
# Handle both VOPC (vsrc1) and VOP3 (src1) instruction formats - read operands dynamically
if is_vopc:
src0_off = ctx.inst_field(type(inst).src0)
vsrc1_off = ctx.inst_field(type(inst).vsrc1)
vsrc1_off = ctx.inst_field(type(inst).vsrc1) # type: ignore[union-attr]
# For 16-bit ops, vsrc1 >= 128 means hi-half of v[vsrc1-128]
if bits['s0'] == 16:
vsrc1_hi = vsrc1_off >= _c(128)
@ -651,16 +666,17 @@ def _compile_vopc(inst: ir3.VOPC | ir3.VOP3 | ir4.VOPC | ir4.VOP3, ctx: _Ctx, op
src1_off = _c(256) + vsrc1_off
else:
src0_off = ctx.inst_field(type(inst).src0)
src1_off = ctx.inst_field(type(inst).src1)
dst_off = ctx.inst_field(type(inst).vdst)
src1_off = ctx.inst_field(type(inst).src1) # type: ignore[union-attr]
dst_off = ctx.inst_field(type(inst).vdst) # type: ignore[union-attr]
vsrc1_hi = False
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None # type: ignore[union-attr]
is_float, is_f64, pcode = any(x in op_name for x in ('_F32', '_F64', '_F16')), '_F64' in op_name, get_pcode(inst.op)
def get_cmp_bit(lane) -> UOp:
lc = lane.cast(dtypes.int) if isinstance(lane, UOp) else _c(lane, dtypes.int)
s0 = ctx.rsrc_dyn(src0_off, lc, bits['s0'], literal, is_f64)
s1 = _cond_hi16(vsrc1_hi, ctx.rsrc_dyn(src1_off, lc, bits['s1'], literal, is_f64)) if bits['s0'] == 16 else ctx.rsrc_dyn(src1_off, lc, bits['s1'], literal, is_f64)
s1 = _cond_hi16(vsrc1_hi, ctx.rsrc_dyn(src1_off, lc, bits['s1'], literal, is_f64)) if bits['s0'] == 16 \
else ctx.rsrc_dyn(src1_off, lc, bits['s1'], literal, is_f64)
if bits['s0'] == 16 and opsel: s0, s1 = _apply_opsel(s0, 0, opsel), _apply_opsel(s1, 1, opsel)
if is_float:
s0 = _apply_src_mods(s0, 0, abs_bits, neg_bits, bits['s0'])
@ -701,7 +717,7 @@ def _compile_vop3(inst: ir3.VOP3 | ir4.VOP3, ctx: _Ctx) -> UOp:
# Regular VOP3 - read operands dynamically
lane = ctx.range()
vdst_reg = ctx.inst_field(type(inst).vdst)
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None # type: ignore[union-attr]
ops = inst.canonical_operands
src0 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src0), lane, bits['s0'], literal, 's0' in ops and ops['s0'][0] == Fmt.FMT_NUM_F64)
src1 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src1), lane, bits['s1'], literal, 's1' in ops and ops['s1'][0] == Fmt.FMT_NUM_F64)
@ -728,7 +744,7 @@ def _compile_vop3sd(inst: ir3.VOP3SD | ir4.VOP3SD, ctx: _Ctx) -> UOp:
# Read operands dynamically from instruction encoding
vdst_reg, sdst_off = ctx.inst_field(type(inst).vdst), ctx.inst_field(type(inst).sdst)
src0_off, src1_off, src2_off = ctx.inst_field(type(inst).src0), ctx.inst_field(type(inst).src1), ctx.inst_field(type(inst).src2)
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None # type: ignore[union-attr]
has_carry_in = 's2' in ops and ops['s2'][2] == OpType.OPR_SREG
vcc_in_off = src2_off if has_carry_in else sdst_off
@ -853,7 +869,9 @@ def _compile_vop3p(inst: ir3.VOP3P | ir4.VOP3P, ctx: _Ctx) -> UOp:
if apply_neg: bits = bits.cast(dtypes.uint16).bitcast(dtypes.half).neg().bitcast(dtypes.uint16).cast(dtypes.uint32)
return bits
def build_remapped_src(src: UOp, opsel_lo_bit: int, opsel_hi_bit: int, neg_lo_bit: int, neg_hi_bit: int) -> UOp:
return get_half_bits(src, bool(opsel_lo_bit), bool(neg_lo_bit)) | (get_half_bits(src, bool(opsel_hi_bit), bool(neg_hi_bit)) << UOp.const(dtypes.uint32, 16))
lo = get_half_bits(src, bool(opsel_lo_bit), bool(neg_lo_bit))
hi = get_half_bits(src, bool(opsel_hi_bit), bool(neg_hi_bit))
return lo | (hi << UOp.const(dtypes.uint32, 16))
# DOT IU instructions use NEG bits for signed/unsigned selection, not fp16 negation
is_dot_iu = 'DOT' in op_name and 'IU' in op_name
n0, n1, n2, nh0, nh1, nh2 = (0, 0, 0, 0, 0, 0) if is_dot_iu else (neg & 1, neg & 2, neg & 4, neg_hi & 1, neg_hi & 2, neg_hi & 4)
@ -908,11 +926,11 @@ def _compile_mem_op(inst: ir3.DS | ir3.FLAT | ir3.GLOBAL | ir3.SCRATCH | ir4.DS
# Extract register info - all dynamic for deduplication
if is_lds:
addr_reg = ctx.inst_field(type(inst).addr)
vdata_reg = ctx.inst_field(type(inst).data0)
addr_reg = ctx.inst_field(type(inst).addr) # type: ignore[union-attr]
vdata_reg = ctx.inst_field(type(inst).data0) # type: ignore[union-attr]
vdst_reg = ctx.inst_field(type(inst).vdst)
offset0 = ctx.inst_field(type(inst).offset0)
offset1 = ctx.inst_field(type(inst).offset1)
offset0 = ctx.inst_field(type(inst).offset0) # type: ignore[union-attr]
offset1 = ctx.inst_field(type(inst).offset1) # type: ignore[union-attr]
offset = offset0 # DS uses offset0 as primary offset
saddr_reg = None
elif isinstance(inst, (ir4.VGLOBAL, ir4.VSCRATCH, ir4.VFLAT)): # RDNA4: vaddr, vsrc, ioffset
@ -923,18 +941,18 @@ def _compile_mem_op(inst: ir3.DS | ir3.FLAT | ir3.GLOBAL | ir3.SCRATCH | ir4.DS
offset0, offset1 = _c(0), _c(0)
saddr_reg = ctx.inst_field(type(inst).saddr) if hasattr(type(inst), 'saddr') else None
else: # RDNA3: addr, data, offset
addr_reg = ctx.inst_field(type(inst).addr)
vdata_reg = ctx.inst_field(type(inst).data)
addr_reg = ctx.inst_field(type(inst).addr) # type: ignore[union-attr]
vdata_reg = ctx.inst_field(type(inst).data) # type: ignore[union-attr]
vdst_reg = ctx.inst_field(type(inst).vdst)
offset = ctx.inst_field_signed(type(inst).offset)
offset = ctx.inst_field_signed(type(inst).offset) # type: ignore[union-attr]
offset0, offset1 = _c(0), _c(0)
saddr_reg = ctx.inst_field(type(inst).saddr) if hasattr(type(inst), 'saddr') else None
saddr_reg = ctx.inst_field(type(inst).saddr) if hasattr(type(inst), 'saddr') else None # type: ignore[union-attr]
# Data width from canonical_op_bits (32/64/96/128), default to 32 for untyped ops
data_bits_mem = inst.canonical_op_bits.get('data', 32)
is_atomic, glc = 'ATOMIC' in op_name, getattr(inst, 'glc', 0)
has_data1 = is_lds and hasattr(inst, 'data1') and inst.data1 is not None
data1_reg = ctx.inst_field(type(inst).data1) if is_lds else _c(0)
data1_reg = ctx.inst_field(type(inst).data1) if is_lds else _c(0) # type: ignore[union-attr]
# DS_PERMUTE/DS_BPERMUTE: cross-lane VGPR access via pcode
if is_lds and 'PERMUTE' in op_name:
@ -958,7 +976,8 @@ def _compile_mem_op(inst: ir3.DS | ir3.FLAT | ir3.GLOBAL | ir3.SCRATCH | ir4.DS
vaddr = ctx.rvgpr_dyn(addr_reg, lane).cast(dtypes.uint64)
addr_offset = vaddr if sve == 1 else UOp.const(dtypes.uint64, 0)
# Add saddr value only if use_saddr is true (saddr < 124)
saddr_contrib = use_saddr.where(ctx.rsgpr_dyn(saddr_reg).cast(dtypes.uint64), UOp.const(dtypes.uint64, 0)) if saddr_reg is not None else UOp.const(dtypes.uint64, 0)
saddr_contrib = use_saddr.where(ctx.rsgpr_dyn(saddr_reg).cast(dtypes.uint64), UOp.const(dtypes.uint64, 0)) \
if saddr_reg is not None else UOp.const(dtypes.uint64, 0)
return base + addr_offset + saddr_contrib + offset64
# FLAT/GLOBAL: choose between SGPR base (saddr) or VGPR pair (addr) based on saddr validity
saddr_base = _u64(ctx.rsgpr_dyn(saddr_reg), ctx.rsgpr_dyn(saddr_reg + _c(1))) if saddr_reg is not None else UOp.const(dtypes.uint64, 0)
@ -1000,12 +1019,18 @@ def _compile_mem_op(inst: ir3.DS | ir3.FLAT | ir3.GLOBAL | ir3.SCRATCH | ir4.DS
vaddr_lo = ctx.rvgpr_dyn(addr_reg, lane).cast(dtypes.uint64)
vaddr_base = use_saddr.where(vaddr_lo + ioffset64, vaddr_full + ioffset64)
if is_atomic:
return {'ADDR': addr, 'DATA': _u64(ctx.rvgpr_dyn(vdata_reg, lane), ctx.rvgpr_dyn(vdata_reg + _c(1), lane)) if data_bits_mem == 64 else ctx.rvgpr_dyn(vdata_reg, lane),
'_vmem': mem, '_active': active, 'laneId': lane, 'v_addr': vaddr_base, 's_saddr': saddr_base}
vdata = ctx.rvgpr_dyn(vdata_reg, lane).cast(dtypes.uint64) if 'STORE' in op_name else ctx.rvgpr_dyn(vdst_reg, lane) if 'D16' in op_name else UOp.const(dtypes.uint32, 0)
if 'STORE' in op_name and data_bits_mem >= 64: vdata = vdata | (ctx.rvgpr_dyn(vdata_reg + _c(1), lane).cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32))
srcs = {'ADDR': addr, 'VDATA': vdata, '_vmem': mem, '_active': active, 'laneId': lane, 'v_addr': vaddr_base, 's_saddr': saddr_base}
for i in range(data_bits_mem // 32): srcs[f'VDATA{i}'] = ctx.rvgpr_dyn(vdata_reg + _c(i), lane) if 'STORE' in op_name else UOp.const(dtypes.uint32, 0)
atomic_data = _u64(ctx.rvgpr_dyn(vdata_reg, lane), ctx.rvgpr_dyn(vdata_reg + _c(1), lane)) \
if data_bits_mem == 64 else ctx.rvgpr_dyn(vdata_reg, lane)
return {'ADDR': addr, 'DATA': atomic_data, '_vmem': mem, '_active': active,
'laneId': lane, 'v_addr': vaddr_base, 's_saddr': saddr_base}
vdata = ctx.rvgpr_dyn(vdata_reg, lane).cast(dtypes.uint64) if 'STORE' in op_name \
else ctx.rvgpr_dyn(vdst_reg, lane) if 'D16' in op_name else UOp.const(dtypes.uint32, 0)
if 'STORE' in op_name and data_bits_mem >= 64:
vdata = vdata | (ctx.rvgpr_dyn(vdata_reg + _c(1), lane).cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32))
srcs = {'ADDR': addr, 'VDATA': vdata, '_vmem': mem, '_active': active,
'laneId': lane, 'v_addr': vaddr_base, 's_saddr': saddr_base}
for i in range(data_bits_mem // 32):
srcs[f'VDATA{i}'] = ctx.rvgpr_dyn(vdata_reg + _c(i), lane) if 'STORE' in op_name else UOp.const(dtypes.uint32, 0)
return srcs
def make_stores(dest: str, val: UOp, lane: UOp, active: UOp, writes_return_data: bool) -> list[UOp]:
@ -1199,7 +1224,9 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int,
for enabled, gid in [(hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_X, gidx),
(hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Y, gidy),
(hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Z, gidz)]:
if rsrc2 & enabled: st._write_sgpr(sgpr_idx, gid); sgpr_idx += 1
if rsrc2 & enabled:
st._write_sgpr(sgpr_idx, gid)
sgpr_idx += 1
# RDNA4 uses TTMP registers for workgroup IDs: ttmp[9]=gidx, ttmp[10]=gidy, ttmp[11]=gidz
if arch == "rdna4":

View file

@ -1,7 +1,7 @@
# AMD ISA code generator - generates enum.py, ins.py, operands.py, str_pcode.py
# Sources: XML from https://gpuopen.com/download/machine-readable-isa/latest/
# PDF manuals from AMD documentation
import re, zlib, xml.etree.ElementTree as ET, zipfile
import re, zlib, xml.etree.ElementTree as ET, zipfile, pathlib
from tinygrad.helpers import fetch
# ═══════════════════════════════════════════════════════════════════════════════
@ -77,8 +77,13 @@ def parse_xml(filename: str):
for ot in root.findall(".//OperandTypes/OperandType"):
ot_name = ot.findtext("OperandTypeName")
for field in ot.findall(".//Field"):
if (enum_name := op_enum_map.get((ot_name, field.findtext("FieldName")))):
enums[enum_name] = {int(pv.findtext("Value")): pv.findtext("Name").upper() for pv in field.findall(".//PredefinedValue")}
key = (ot_name, field.findtext("FieldName"))
if (enum_name := op_enum_map.get(key)): # type: ignore[arg-type]
def _pv_val(pv: ET.Element) -> tuple[int, str]:
v, n = pv.findtext("Value"), pv.findtext("Name")
assert v is not None and n is not None
return int(v), n.upper()
enums[enum_name] = dict(_pv_val(pv) for pv in field.findall(".//PredefinedValue"))
# Extract DataFormats with BitCount
for df in root.findall("ISA/DataFormats/DataFormat"):
name, bits = df.findtext("DataFormatName"), df.findtext("BitCount")
@ -86,17 +91,26 @@ def parse_xml(filename: str):
# Extract encoding definitions
for enc in root.findall("ISA/Encodings/Encoding"):
name = enc.findtext("EncodingName")
assert name is not None
is_base = name.startswith("ENC_") or name in ("VOP3_SDST_ENC", "VOPDXY")
is_variant = any(sfx in name for sfx in _ENC_SUFFIX_MAP)
if not is_base and not is_variant: continue
if any(s in name for s in _SKIP_ENCODINGS): continue
fields = [(_norm_field(f.findtext("FieldName").lower()), int(f.find("BitLayout/Range").findtext("BitOffset") or 0) + int(f.find("BitLayout/Range").findtext("BitCount") or 0) - 1,
int(f.find("BitLayout/Range").findtext("BitOffset") or 0))
for f in enc.findall(".//MicrocodeFormat/BitMap/Field") if f.find("BitLayout/Range") is not None]
ident = (enc.findall("EncodingIdentifiers/EncodingIdentifier") or [None])[0]
fields: list[tuple[str, int, int]] = []
for f in enc.findall(".//MicrocodeFormat/BitMap/Field"):
br = f.find("BitLayout/Range")
if br is None: continue
fn = f.findtext("FieldName")
assert fn is not None
fields.append((_norm_field(fn.lower()),
int(br.findtext("BitOffset") or 0) + int(br.findtext("BitCount") or 0) - 1, int(br.findtext("BitOffset") or 0)))
ident_list = enc.findall("EncodingIdentifiers/EncodingIdentifier")
ident = ident_list[0] if ident_list else None
enc_field = next((f for f in fields if f[0] == "encoding"), None)
# For multi-dword formats, encoding field may be in higher dword but identifier pattern is always in dword0; use % 32
enc_bits = "".join(ident.text[len(ident.text)-1-b] for b in range(enc_field[1] % 32, (enc_field[2] % 32)-1, -1)) if ident is not None and enc_field else None
# For multi-dword formats, encoding field may be in higher dword but identifier is always in dword0; use % 32
enc_bits: str | None = None
if ident is not None and ident.text is not None and enc_field:
enc_bits = "".join(ident.text[len(ident.text)-1-b] for b in range(enc_field[1] % 32, (enc_field[2] % 32)-1, -1))
base_name = _strip_enc(name)
encodings[NAME_MAP.get(base_name, base_name)] = (fields, enc_bits)
# Extract instruction opcodes and operand info
@ -104,9 +118,12 @@ def parse_xml(filename: str):
opcode_encs: dict[str, dict[int, set[str]]] = {} # {base_fmt: {opcode: {enc_names}}}
for instr in root.findall("ISA/Instructions/Instruction"):
name = instr.findtext("InstructionName")
assert name is not None
for enc in instr.findall("InstructionEncodings/InstructionEncoding"):
if enc.findtext("EncodingCondition") != "default": continue
base, opcode = _map_flat(_strip_enc(enc.findtext("EncodingName")), name), int(enc.findtext("Opcode") or 0)
enc_enc_name = enc.findtext("EncodingName")
assert enc_enc_name is not None
base, opcode = _map_flat(_strip_enc(enc_enc_name), name), int(enc.findtext("Opcode") or 0)
enc_name = NAME_MAP.get(base, base)
# Encoding variants use the same Op enum as the base format
base_enum = enc_name
@ -120,8 +137,10 @@ def parse_xml(filename: str):
elif base == "VGLOBAL": enums.setdefault("VFLAT", {})[opcode] = name
enums.setdefault(base_enum, {})[opcode] = name
# Extract operand info
op_info = {op.findtext("FieldName").lower(): (op.findtext("DataFormatName"), int(op.findtext("OperandSize") or 0), op.findtext("OperandType"))
for op in enc.findall("Operands/Operand") if op.findtext("FieldName")}
op_info: dict[str, tuple[str | None, int, str | None]] = {}
for op in enc.findall("Operands/Operand"):
fn = op.findtext("FieldName")
if fn: op_info[fn.lower()] = (op.findtext("DataFormatName"), int(op.findtext("OperandSize") or 0), op.findtext("OperandType"))
for fmt, _, otype in op_info.values():
if fmt and fmt not in fmts: fmts[fmt] = 0
if otype: op_types_set.add(otype)
@ -143,7 +162,9 @@ def extract_pdf_text(url: str) -> list[list[tuple[float, float, str, str]]]:
data = fetch(url).read_bytes()
# Parse xref table to locate objects
xref: dict[int, int] = {}
pos = int(re.search(rb'startxref\s+(\d+)', data).group(1)) + 4
xref_match = re.search(rb'startxref\s+(\d+)', data)
assert xref_match is not None
pos = int(xref_match.group(1)) + 4
while data[pos:pos+7] != b'trailer':
while data[pos:pos+1] in b' \r\n': pos += 1
line_end = data.find(b'\n', pos)
@ -164,14 +185,19 @@ def extract_pdf_text(url: str) -> list[list[tuple[float, float, str, str]]]:
if not (m := re.search(rb'/Contents (\d+) 0 R', data[xref[n]:xref[n]+500])): continue
stream = get_stream(int(m.group(1))).decode('latin-1')
elements, font = [], ''
_RE_BT = (r'(/F[\d.]+) [\d.]+ Tf|([\d.+-]+) ([\d.+-]+) Td|[\d.+-]+ [\d.+-]+ [\d.+-]+ [\d.+-]+ ([\d.+-]+) ([\d.+-]+) Tm'
r'|<([0-9A-Fa-f]+)>.*?Tj|\[([^\]]+)\] TJ')
for bt in re.finditer(r'BT(.*?)ET', stream, re.S):
x, y = 0.0, 0.0
for m in re.finditer(r'(/F[\d.]+) [\d.]+ Tf|([\d.+-]+) ([\d.+-]+) Td|[\d.+-]+ [\d.+-]+ [\d.+-]+ [\d.+-]+ ([\d.+-]+) ([\d.+-]+) Tm|<([0-9A-Fa-f]+)>.*?Tj|\[([^\]]+)\] TJ', bt.group(1)):
if m.group(1): font = m.group(1)
elif m.group(2): x, y = x + float(m.group(2)), y + float(m.group(3))
elif m.group(4): x, y = float(m.group(4)), float(m.group(5))
elif m.group(6) and (t := bytes.fromhex(m.group(6)).decode('latin-1')).strip(): elements.append((x, y, t, font))
elif m.group(7) and (t := ''.join(bytes.fromhex(h).decode('latin-1') for h in re.findall(r'<([0-9A-Fa-f]+)>', m.group(7)))).strip(): elements.append((x, y, t, font))
for sm in re.finditer(_RE_BT, bt.group(1)):
if sm.group(1): font = sm.group(1)
elif sm.group(2): x, y = x + float(sm.group(2)), y + float(sm.group(3))
elif sm.group(4): x, y = float(sm.group(4)), float(sm.group(5))
elif sm.group(6) and (t := bytes.fromhex(sm.group(6)).decode('latin-1')).strip():
elements.append((x, y, t, font))
elif sm.group(7):
t = ''.join(bytes.fromhex(h).decode('latin-1') for h in re.findall(r'<([0-9A-Fa-f]+)>', sm.group(7)))
if t.strip(): elements.append((x, y, t, font))
pages.append(sorted(elements, key=lambda e: (-e[1], e[0])))
return pages
@ -197,7 +223,7 @@ def extract_pcode(pages: list[list[tuple[float, float, str, str]]], name_to_op:
else:
next_page, next_y = page_idx, 0
# Collect F6 text from current position to next instruction (pseudocode is at x ≈ 69)
lines = []
lines: list[tuple[int, float, str]] = []
for p in range(page_idx, next_page + 1):
start_y = y if p == page_idx else 800
end_y = next_y if p == next_page else 0
@ -220,8 +246,8 @@ def extract_pcode(pages: list[list[tuple[float, float, str, str]]], name_to_op:
# Code generation
# ═══════════════════════════════════════════════════════════════════════════════
def write_common(all_fmts, all_op_types, path):
lines = ["# autogenerated from AMD ISA XML - do not edit", "from enum import Enum, auto", ""]
def write_common(all_fmts: dict[str, int], all_op_types: set[str], path: pathlib.Path) -> None:
lines: list[str] = ["# autogenerated from AMD ISA XML - do not edit", "from enum import Enum, auto", ""]
lines.append("class ReprEnum(Enum):")
lines.append(' """Enum with clean repr that roundtrips with eval()."""')
lines.append(' def __repr__(self): return f"{type(self).__name__}.{self.name}"')
@ -238,7 +264,8 @@ def write_common(all_fmts, all_op_types, path):
with open(path, "w") as f: f.write("\n".join(lines))
def write_enum(enums, path):
lines = ["# autogenerated from AMD ISA XML - do not edit", "from extra.assembly.amd.autogen.common import ReprEnum, Fmt, FMT_BITS, OpType # noqa: F401", ""]
lines: list[str] = ["# autogenerated from AMD ISA XML - do not edit",
"from extra.assembly.amd.autogen.common import ReprEnum, Fmt, FMT_BITS, OpType # noqa: F401", ""]
for name, ops in sorted(enums.items()):
if not ops: continue
suffix = "_E32" if name in ("VOP1", "VOP2", "VOPC") else "_E64" if name == "VOP3" else ""
@ -286,7 +313,7 @@ def write_ins(encodings, enums, suffix_only_ops, types, arch, path):
'dpp', 'fi', 'bc', 'row_mask', 'bank_mask', 'src0_neg', 'src0_abs', 'src1_neg', 'src1_abs',
'cbsz', 'abid', 'acc_cd', 'acc', 'blgp', 'lane_sel_0', 'lane_sel_1', 'lane_sel_2', 'lane_sel_3',
'lane_sel_4', 'lane_sel_5', 'lane_sel_6', 'lane_sel_7', 'dst_sel', 'dst_unused', 'src0_sel', 'src1_sel']
sort_fields = lambda fields: sorted(fields, key=lambda f: (ORDER.index(f[0]) if f[0] in ORDER else 999, f[2]))
def sort_fields(fields): return sorted(fields, key=lambda f: (ORDER.index(f[0]) if f[0] in ORDER else 999, f[2]))
# Separate base encodings from variants
base_encodings, variant_encodings = {}, {}
@ -296,15 +323,29 @@ def write_ins(encodings, enums, suffix_only_ops, types, arch, path):
else: variant_encodings[enc_name] = data
# Build sets of ops by their vdst type from operand metadata
sdst_opcodes = {} # ops where vdst is OPR_SREG (writes to SGPR)
sdst_opcodes: dict[str, set[int]] = {} # ops where vdst is OPR_SREG (writes to SGPR)
for fmt, ops in enums.items():
for op, name in ops.items():
op_types = types.get((name, fmt), {})
vdst_type = op_types.get("vdst", (None, None, None))[2]
if vdst_type == "OPR_SREG": sdst_opcodes.setdefault(fmt, set()).add(op)
lines = ["# autogenerated from AMD ISA XML - do not edit", "# ruff: noqa: F401,F403",
"from extra.assembly.amd.dsl import *", f"from extra.assembly.amd.autogen.{arch}.enum import *", "import functools", ""]
# collect only the XxxOp enums that are actually referenced in this arch's instruction definitions
enum_names = sorted(f"{k}Op" for k in enums if enums[k] and k not in ("HWREG", "MSG"))
# also re-export HWREG/MSG enums (plain enums, not instruction format ops)
enum_names += sorted(k for k in enums if k in ("HWREG", "MSG") and enums[k])
# collect DSL field types actually used by scanning generated field definitions
all_field_defs = " ".join(field_def(fn, hi, lo, enc, eb) for enc, (flds, eb) in encodings.items() for fn, hi, lo in flds)
_ALL_DSL = ["BitField", "EnumBitField", "FixedBitField", "NULL", "SBaseField", "SGPRField", "SRsrcField",
"SSrcField", "SrcField", "VDSTYField", "VGPRField"]
dsl_names = ["Inst"] + [n for n in _ALL_DSL if n in all_field_defs]
# also re-export register names so `from ins import *` still provides them to downstream users
_DSL_REGS = ["s", "v", "src", "VCC_LO", "VCC_HI", "VCC", "EXEC_LO", "EXEC_HI", "EXEC", "NULL", "OFF", "M0",
"SCC", "VCCZ", "EXECZ", "ttmp", "INV_2PI", "SDWA", "DPP", "DPP16", "LIT", "SRC_LDS_DIRECT"]
dsl_reexport = sorted(set(dsl_names + _DSL_REGS))
lines: list[str] = ["# autogenerated from AMD ISA XML - do not edit", "# ruff: noqa: E501,F401",
f"from extra.assembly.amd.dsl import {', '.join(dsl_reexport)}",
f"from extra.assembly.amd.autogen.{arch}.enum import {', '.join(enum_names)}", "import functools", ""]
def fmt_allowed(op_enum: str, ops: set[int]) -> str:
"""Format allowed ops as {EnumName.MEMBER, ...}."""
@ -323,7 +364,9 @@ def write_ins(encodings, enums, suffix_only_ops, types, arch, path):
has_seg_field = any(fn == "seg" for fn, _, _ in fields)
if enc_name in ("FLAT", "VFLAT") and has_seg_field:
prefix = "V" if enc_name == "VFLAT" else ""
for cls, seg, op_enum in [(f"{prefix}FLAT", 0, f"{prefix}FLATOp"), (f"{prefix}GLOBAL", 2, f"{prefix}GLOBALOp"), (f"{prefix}SCRATCH", 1, f"{prefix}SCRATCHOp")]:
flat_variants = [(f"{prefix}FLAT", 0, f"{prefix}FLATOp"), (f"{prefix}GLOBAL", 2, f"{prefix}GLOBALOp"),
(f"{prefix}SCRATCH", 1, f"{prefix}SCRATCHOp")]
for cls, seg, op_enum in flat_variants:
cls_ops = set(enums.get(cls, {}).keys())
lines.append(f"class {cls}(Inst):")
for fn, hi, lo in sort_fields(fields):
@ -396,6 +439,8 @@ def write_ins(encodings, enums, suffix_only_ops, types, arch, path):
op_to_suffix = {op:suffix for suffix,ops in suffix_only_ops.items() for op in ops.get(fmt, set())}
fmt_sdst_ops = sdst_opcodes.get(fmt, set())
for op, name in sorted(ops.items()):
# ADDTID ops are in both FLAT and GLOBAL enums (for pcode); only generate helper for GLOBAL/VGLOBAL
if "ADDTID" in name and fmt in ("FLAT", "VFLAT"): continue
msuf = suffix if fmt != "VOP3" or op < 512 else ""
# Determine class: SDST variants, suffix-specific variants (e.g., _MFMA, _LIT), or base
if fmt == "VOP1" and op in fmt_sdst_ops: cls = "VOP1_SDST"
@ -405,11 +450,14 @@ def write_ins(encodings, enums, suffix_only_ops, types, arch, path):
lines.append(f"{name.lower()}{msuf.lower()} = functools.partial({cls}, {fmt}Op.{name}{msuf})")
with open(path, "w") as f: f.write("\n".join(lines))
def write_operands(types, enums, arch, path):
def write_operands(types: dict, enums: dict, arch: str, path: pathlib.Path) -> None:
valid = {(name, fmt) for fmt, ops in enums.items() for name in ops.values()}
lines = ["# autogenerated from AMD ISA XML - do not edit",
"from extra.assembly.amd.autogen.common import Fmt, OpType",
f"from extra.assembly.amd.autogen.{arch}.enum import *", ""]
# only import enums that are actually used as keys in OPERANDS
used_bases = {eb for (nm, eb) in types if (nm, eb) in valid}
enum_names = sorted(f"{k}Op" for k in used_bases)
lines: list[str] = ["# autogenerated from AMD ISA XML - do not edit",
"from extra.assembly.amd.autogen.common import Fmt, OpType",
f"from extra.assembly.amd.autogen.{arch}.enum import {', '.join(enum_names)}", ""]
lines.append("# instruction operand info: {Op: {field: (Fmt, size_bits, OpType)}}")
lines.append("OPERANDS = {")
def fmt_val(v):
@ -422,7 +470,7 @@ def write_operands(types, enums, arch, path):
lines.append("}")
with open(path, "w") as f: f.write("\n".join(lines))
def write_pcode(pcode: dict[tuple[str, int], str], enums: dict[str, dict[int, str]], arch: str, path: str):
def write_pcode(pcode: dict[tuple[str, int], str], enums: dict[str, dict[int, str]], arch: str, path: pathlib.Path) -> None:
"""Write str_pcode.py file from extracted pseudocode."""
entries: list[tuple[str, str, int, str]] = []
for fmt_name, ops in enums.items():
@ -444,8 +492,9 @@ def write_pcode(pcode: dict[tuple[str, int], str], enums: dict[str, dict[int, st
# ═══════════════════════════════════════════════════════════════════════════════
if __name__ == "__main__":
import pathlib
all_fmts, all_op_types, arch_data = {}, set(), {}
all_fmts: dict[str, int] = {}
all_op_types: set[str] = set()
arch_data: dict[str, dict] = {}
# First pass: parse XML for all architectures
for arch, cfg in ARCHS.items():
print(f"Parsing XML: {cfg['xml']} -> {arch}")

View file

@ -52,7 +52,9 @@ def _val_to_bits(val):
if val.dtype == dtypes.float64: return val.bitcast(dtypes.uint64)
return val if val.dtype == dtypes.uint32 else val.cast(dtypes.uint32)
def _floor(x): t = UOp(Ops.TRUNC, x.dtype, (x,)); return ((x < _const(x.dtype, 0)) & x.ne(t)).where(t - _const(x.dtype, 1), t)
def _floor(x):
t = UOp(Ops.TRUNC, x.dtype, (x,))
return ((x < _const(x.dtype, 0)) & x.ne(t)).where(t - _const(x.dtype, 1), t)
def _f16_extract(v): return (v & _u32(0xFFFF)).cast(dtypes.uint16).bitcast(dtypes.half) if v.dtype == dtypes.uint32 else v
def _check_nan(v: UOp, quiet: bool) -> UOp:
@ -118,7 +120,8 @@ def _f_to_u(f, dt): return UOp(Ops.TRUNC, f.dtype, ((f < _const(f.dtype, 0.0)).w
def _cvt_quiet(val: UOp) -> UOp:
bits, _, _, qb, _ = _float_info(val)
bt, ft = (dtypes.uint64, dtypes.float64) if val.dtype == dtypes.float64 else (dtypes.uint16, dtypes.half) if val.dtype == dtypes.half else (dtypes.uint32, dtypes.float32)
bt, ft = (dtypes.uint64, dtypes.float64) if val.dtype == dtypes.float64 else \
(dtypes.uint16, dtypes.half) if val.dtype == dtypes.half else (dtypes.uint32, dtypes.float32)
return (val.bitcast(bt) | qb).bitcast(ft)
def _is_denorm(val: UOp) -> UOp:
@ -163,14 +166,18 @@ def _ldexp(val: UOp, exp: UOp) -> UOp:
def _frexp_mant(val: UOp) -> UOp:
val = val.bitcast(dtypes.float32) if val.dtype == dtypes.uint32 else val.bitcast(dtypes.float64) if val.dtype == dtypes.uint64 else val
if val.dtype == dtypes.float32: return ((val.bitcast(dtypes.uint32) & _u32(0x807FFFFF)) | _u32(0x3f000000)).bitcast(dtypes.float32)
return ((val.bitcast(dtypes.uint64) & _const(dtypes.uint64, 0x800FFFFFFFFFFFFF)) | _const(dtypes.uint64, 0x3fe0000000000000)).bitcast(dtypes.float64)
return ((val.bitcast(dtypes.uint64) & _const(dtypes.uint64, 0x800FFFFFFFFFFFFF)) |
_const(dtypes.uint64, 0x3fe0000000000000)).bitcast(dtypes.float64)
def _frexp_exp(val: UOp) -> UOp:
val = val.bitcast(dtypes.float32) if val.dtype == dtypes.uint32 else val.bitcast(dtypes.float64) if val.dtype == dtypes.uint64 else val
if val.dtype == dtypes.float32: return ((val.bitcast(dtypes.uint32) >> _u32(23)) & _u32(0xFF)).cast(dtypes.int) - _const(dtypes.int, 126)
return ((val.bitcast(dtypes.uint64) >> _const(dtypes.uint64, 52)) & _const(dtypes.uint64, 0x7FF)).cast(dtypes.int) - _const(dtypes.int, 1022)
TWO_OVER_PI = 0x0145f306dc9c882a53f84eafa3ea69bb81b6c52b3278872083fca2c757bd778ac36e48dc74849ba5c00c925dd413a32439fc3bd63962534e7dd1046bea5d768909d338e04d68befc827323ac7306a673e93908bf177bf250763ff12fffbc0b301fde5e2316b414da3eda6cfd9e4f96136e9e8c7ecd3cbfd45aea4f758fd7cbe2f67a0e73ef14a525d4d7f6bf623f1aba10ac06608df8f6
TWO_OVER_PI = int(
"0145f306dc9c882a53f84eafa3ea69bb81b6c52b3278872083fca2c757bd778ac36e48dc74849ba5c00c925dd413a32439fc3bd"
"63962534e7dd1046bea5d768909d338e04d68befc827323ac7306a673e93908bf177bf250763ff12fffbc0b301fde5e2316b414"
"da3eda6cfd9e4f96136e9e8c7ecd3cbfd45aea4f758fd7cbe2f67a0e73ef14a525d4d7f6bf623f1aba10ac06608df8f6", 16)
# TWO_OVER_PI as 19 u64 words for trig_preop_result (word[0] = bits 0-63, word[18] = bits 1152-1200)
_PREOP_WORDS = tuple((TWO_OVER_PI >> (64 * i)) & 0xFFFFFFFFFFFFFFFF for i in range(19))
def _trig_preop(val: UOp) -> UOp:
@ -247,10 +254,14 @@ _FUNCS: dict[str, Callable[..., UOp]] = {
# Normalization conversions: map [-1,1] or [0,1] to integer range
# Use floor(x + 0.5) for round-to-nearest
# SNORM: round(value * 32767), range is [-32767, 32767] (hardware behavior)
'f16_to_snorm': lambda a: _floor(_f16_extract(a).cast(dtypes.float32) * _const(dtypes.float32, 32767) + _const(dtypes.float32, 0.5)).cast(dtypes.int).cast(dtypes.int16),
'f16_to_unorm': lambda a: _floor(_f16_extract(a).cast(dtypes.float32) * _const(dtypes.float32, 65535) + _const(dtypes.float32, 0.5)).cast(dtypes.uint16),
'f32_to_snorm': lambda a: _floor(a.bitcast(dtypes.float32) * _const(dtypes.float32, 32767) + _const(dtypes.float32, 0.5)).cast(dtypes.int).cast(dtypes.int16),
'f32_to_unorm': lambda a: _floor(a.bitcast(dtypes.float32) * _const(dtypes.float32, 65535) + _const(dtypes.float32, 0.5)).cast(dtypes.uint16),
'f16_to_snorm': lambda a: _floor(
_f16_extract(a).cast(dtypes.float32) * _const(dtypes.float32, 32767) + _const(dtypes.float32, 0.5)).cast(dtypes.int).cast(dtypes.int16),
'f16_to_unorm': lambda a: _floor(
_f16_extract(a).cast(dtypes.float32) * _const(dtypes.float32, 65535) + _const(dtypes.float32, 0.5)).cast(dtypes.uint16),
'f32_to_snorm': lambda a: _floor(
a.bitcast(dtypes.float32) * _const(dtypes.float32, 32767) + _const(dtypes.float32, 0.5)).cast(dtypes.int).cast(dtypes.int16),
'f32_to_unorm': lambda a: _floor(
a.bitcast(dtypes.float32) * _const(dtypes.float32, 65535) + _const(dtypes.float32, 0.5)).cast(dtypes.uint16),
'f32_to_u8': lambda a: _f_to_u(a.bitcast(dtypes.float32), dtypes.uint8),
# Integer truncation conversions
'i32_to_i16': lambda a: a.cast(dtypes.int).cast(dtypes.int16),
@ -310,21 +321,35 @@ _SINGLE_CHAR = {'(': 'LPAREN', ')': 'RPAREN', '[': 'LBRACKET', ']': 'RBRACKET',
class Token:
__slots__ = ('type', 'val')
def __init__(self, type: str, val: str): self.type, self.val = type, val
def __init__(self, kind: str, val: str): self.type, self.val = kind, val
def __repr__(self): return f'{self.type}:{self.val}'
def tokenize(s: str) -> list[Token]:
tokens, i, n = [], 0, len(s)
while i < n:
c = s[i]
if c.isspace(): i += 1; continue
if c.isspace():
i += 1
continue
if i + 1 < n and s[i:i+2] in ('+=', '-='):
tokens.append(Token('ASSIGN_OP', s[i:i+2])); i += 2; continue
tokens.append(Token('ASSIGN_OP', s[i:i+2]))
i += 2
continue
if i + 1 < n and s[i:i+2] in ('||', '&&', '>=', '<=', '==', '!=', '<>', '>>', '<<', '**', '+:', '-:'):
tokens.append(Token('OP', s[i:i+2])); i += 2; continue
if c in '|^&><+-*/~!%': tokens.append(Token('OP', c)); i += 1; continue
if (t := _SINGLE_CHAR.get(c)): tokens.append(Token(t, c)); i += 1; continue
if c == ';': i += 1; continue
tokens.append(Token('OP', s[i:i+2]))
i += 2
continue
if c in '|^&><+-*/~!%':
tokens.append(Token('OP', c))
i += 1
continue
if (t := _SINGLE_CHAR.get(c)):
tokens.append(Token(t, c))
i += 1
continue
if c == ';':
i += 1
continue
if c.isdigit() or (c == '-' and i + 1 < n and s[i+1].isdigit()):
start = i
if c == '-': i += 1
@ -337,31 +362,38 @@ def tokenize(s: str) -> list[Token]:
i += 1
while i < n and s[i].isdigit(): i += 1
for sfx in ('ULL', 'LL', 'UL', 'U', 'L', 'F', 'f'):
if s[i:i+len(sfx)] == sfx: i += len(sfx); break
tokens.append(Token('NUM', s[start:i])); continue
if s[i:i+len(sfx)] == sfx:
i += len(sfx)
break
tokens.append(Token('NUM', s[start:i]))
continue
if c.isalpha() or c == '_':
start = i
while i < n and (s[i].isalnum() or s[i] == '_'): i += 1
tokens.append(Token('IDENT', s[start:i])); continue
tokens.append(Token('IDENT', s[start:i]))
continue
raise RuntimeError(f"unexpected char '{c}' at pos {i} in: {s}")
tokens.append(Token('EOF', ''))
return tokens
class Parser:
def __init__(self, tokens: list[Token], vars: dict, funcs: dict | None = None):
self.tokens, self.vars, self.funcs, self.pos = tokens, vars, funcs if funcs is not None else _FUNCS, 0
def __init__(self, tokens: list[Token], env: dict, funcs: dict | None = None):
self.tokens, self.vars, self.funcs, self.pos = tokens, env, funcs if funcs is not None else _FUNCS, 0
def peek(self, offset=0) -> Token: return self.tokens[min(self.pos + offset, len(self.tokens) - 1)]
def at(self, *types) -> bool: return self.peek().type in types
def _advance(self) -> Token: tok = self.tokens[self.pos]; self.pos += 1; return tok
def eat(self, type: str) -> Token:
if self.peek().type != type: raise RuntimeError(f"expected {type}, got {self.peek()}")
def _advance(self) -> Token:
tok = self.tokens[self.pos]
self.pos += 1
return tok
def eat(self, kind: str) -> Token:
if self.peek().type != kind: raise RuntimeError(f"expected {kind}, got {self.peek()}")
return self._advance()
def try_eat(self, type: str) -> Token | None: return self._advance() if self.peek().type == type else None
def try_eat_val(self, val: str, type: str) -> Token | None:
return self._advance() if self.peek().type == type and self.peek().val == val else None
def eat_val(self, val: str, type: str) -> Token:
if self.peek().type != type or self.peek().val != val: raise RuntimeError(f"expected {type}:{val}, got {self.peek()}")
def try_eat(self, kind: str) -> Token | None: return self._advance() if self.peek().type == kind else None
def try_eat_val(self, val: str, kind: str) -> Token | None:
return self._advance() if self.peek().type == kind and self.peek().val == val else None
def eat_val(self, val: str, kind: str) -> Token:
if self.peek().type != kind or self.peek().val != val: raise RuntimeError(f"expected {kind}:{val}, got {self.peek()}")
return self._advance()
def parse(self) -> UOp:
@ -381,8 +413,10 @@ class Parser:
case '&&' | '&': return left & right
case '^': return left ^ right
case '==' | '<>': return left.eq(right) if op == '==' else left.ne(right)
case '!=' : return left.ne(right)
case '>=' | '<=' | '>' | '<': return self._cmp_nan(left, right, {'>=':(lambda a,b:a>=b),'<=':(lambda a,b:a<=b),'>':(lambda a,b:a>b),'<':(lambda a,b:a<b)}[op])
case '!=': return left.ne(right)
case '>=' | '<=' | '>' | '<':
ops = {'>=':(lambda a,b:a>=b),'<=':(lambda a,b:a<=b),'>':(lambda a,b:a>b),'<':(lambda a,b:a<b)}
return self._cmp_nan(left, right, ops[op])
case '>>' | '<<': return (left >> right) if op == '>>' else (left << right)
case '+' | '-':
if op == '-' and left.op == Ops.CONST and right.op == Ops.CONST: return _const(left.dtype, left.arg - right.arg)
@ -529,7 +563,8 @@ class Parser:
if dt is None: return base
if dt == base.dtype: return base
if dt.itemsize == 2 and base.dtype.itemsize == 4:
return (base & _const(base.dtype, 0xFFFF)).cast(dtypes.uint16) if dt == dtypes.uint16 else (base & _const(base.dtype, 0xFFFF)).cast(dtypes.uint16).bitcast(dt)
if dt == dtypes.uint16: return (base & _const(base.dtype, 0xFFFF)).cast(dtypes.uint16)
return (base & _const(base.dtype, 0xFFFF)).cast(dtypes.uint16).bitcast(dt)
if field == 'i4': return _signext_4bit(base)
return _cast_to(base, dt)
@ -539,7 +574,7 @@ class Parser:
def _handle_bracket_rest(self, first: UOp, base: UOp, var_name: str | None = None) -> UOp:
if self.at('OP') and self.peek().val in ('+:', '-:'):
op = self.eat('OP').val
self.eat('OP')
width = self.parse()
self.eat('RBRACKET')
if width.op == Ops.CONST:
@ -625,7 +660,8 @@ class Parser:
inner = self.parse()
self.eat('RPAREN')
dt = {('U',32): dtypes.uint32, ('U',64): dtypes.uint64, ('I',32): dtypes.int, ('I',64): dtypes.int64,
('F',16): dtypes.half, ('F',32): dtypes.float32, ('F',64): dtypes.float64, ('B',32): dtypes.uint32, ('B',64): dtypes.uint64}.get((type_char, bits), dtypes.uint64 if bits > 32 else dtypes.uint32)
('F',16): dtypes.half, ('F',32): dtypes.float32, ('F',64): dtypes.float64,
('B',32): dtypes.uint32, ('B',64): dtypes.uint64}.get((type_char, bits), dtypes.uint64 if bits > 32 else dtypes.uint32)
if type_char == 'F' and inner.dtype in (dtypes.uint32, dtypes.uint64, dtypes.ulong, dtypes.int, dtypes.int64):
if inner.dtype.itemsize != dt.itemsize: inner = inner.cast(dtypes.uint32 if dt.itemsize == 4 else dtypes.uint64)
return inner.bitcast(dt)
@ -686,7 +722,7 @@ class Parser:
def _call_func(self, name: str, args: list[UOp]) -> UOp:
if name in self.vars and isinstance(self.vars[name], tuple) and self.vars[name][0] == 'lambda':
_, params, body = self.vars[name]
lv = {**self.vars, **{p: a for p, a in zip(params, args)}}
lv = {**self.vars, **dict(zip(params, args))}
if ';' in body or '\n' in body or 'return' in body.lower():
lines = [l.strip() for l in body.replace(';', '\n').split('\n') if l.strip() and not l.strip().startswith('//')]
_, _, result = parse_block(lines, 0, lv, self.funcs)
@ -712,7 +748,9 @@ class Parser:
elif dt in (dtypes.uint8, dtypes.int8):
val = mem.index(idx, *gate, ptr=True).load().cast(dt)
elif dt in (dtypes.uint16, dtypes.int16, dtypes.short):
val = (mem.index(idx, *gate, ptr=True).load().cast(dtypes.uint32) | (mem.index(idx + _const(dtypes.int, 1), *gate, ptr=True).load().cast(dtypes.uint32) << _u32(8))).cast(dt)
lo = mem.index(idx, *gate, ptr=True).load().cast(dtypes.uint32)
hi = mem.index(idx + _const(dtypes.int, 1), *gate, ptr=True).load().cast(dtypes.uint32)
val = (lo | (hi << _u32(8))).cast(dt)
else:
val = _u32(0)
for i in range(4): val = val | (mem.index(idx + _const(dtypes.int, i), *gate, ptr=True).load().cast(dtypes.uint32) << _u32(i * 8))
@ -723,7 +761,8 @@ class Parser:
idx2 = ((addr + _const(adt, 4)) >> _const(adt, 2)).cast(dtypes.int)
val = val.cast(dtypes.uint64) | (mem.index(idx2, *gate).cast(dtypes.uint64) << _u64(32))
elif dt in (dtypes.uint8, dtypes.int8): val = (val >> ((addr & _const(adt, 3)).cast(dtypes.uint32) * _u32(8))) & _u32(0xFF)
elif dt in (dtypes.uint16, dtypes.int16): val = (val >> (((addr >> _const(adt, 1)) & _const(adt, 1)).cast(dtypes.uint32) * _u32(16))) & _u32(0xFFFF)
elif dt in (dtypes.uint16, dtypes.int16):
val = (val >> (((addr >> _const(adt, 1)) & _const(adt, 1)).cast(dtypes.uint32) * _u32(16))) & _u32(0xFFFF)
return val
def _coerce_cmp(self, l: UOp, r: UOp) -> tuple[UOp, UOp]:
@ -756,8 +795,8 @@ def _match_bracket(toks: list[Token], start: int) -> tuple[int, list[Token]]:
return j, [t for t in toks[start+1:j-1] if t.type != 'EOF']
def _tok_str(toks: list[Token]) -> str: return ' '.join(t.val for t in toks if t.type != 'EOF')
def parse_tokens(toks: list[Token], vars: dict[str, VarVal], funcs: dict | None = None) -> UOp:
return Parser(toks, vars, funcs).parse()
def parse_tokens(toks: list[Token], env: dict[str, VarVal], funcs: dict | None = None) -> UOp:
return Parser(toks, env, funcs).parse()
# Unified block parser for pcode
def _subst_loop_var(line: str, loop_var: str, val: int) -> str:
@ -781,7 +820,7 @@ def _find_paren_end(s: str, start: int = 0, open_ch: str = '(', close_ch: str =
if depth == 0: return j
return len(s)
def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: dict | None = None,
def parse_block(lines: list[str], start: int, env: dict[str, VarVal], funcs: dict | None = None,
assigns: list | None = None) -> tuple[int, dict[str, VarVal], UOp | None]:
"""Parse a block of pcode. Returns (next_line, block_assigns, return_value).
If assigns list is provided, side effects (MEM/VGPR writes) are appended to it."""
@ -792,7 +831,9 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
while i < len(lines):
line = lines[i]
toks = tokenize(line)
if toks[0].type != 'IDENT' and toks[0].type != 'LBRACE': i += 1; continue
if toks[0].type != 'IDENT' and toks[0].type != 'LBRACE':
i += 1
continue
first = toks[0].val.lower() if toks[0].type == 'IDENT' else '{'
# Block terminators
@ -801,17 +842,19 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
# return expr (lambda bodies)
if first == 'return':
rest = line[line.lower().find('return') + 6:].strip()
return i + 1, block_assigns, parse_expr(rest, vars, funcs)
return i + 1, block_assigns, parse_expr(rest, env, funcs)
# for loop
if first == 'for':
# Parse: for VAR in [SIZE']START : [SIZE']END do
p = Parser(toks, vars, funcs)
p = Parser(toks, env, funcs)
p.eat_val('for', 'IDENT')
loop_var = p.eat('IDENT').val
p.eat_val('in', 'IDENT')
def parse_bound():
if p.at('NUM') and p.peek(1).type == 'QUOTE': p.eat('NUM'); p.eat('QUOTE')
if p.at('NUM') and p.peek(1).type == 'QUOTE':
p.eat('NUM')
p.eat('QUOTE')
if p.at('NUM'): return int(p.eat('NUM').val.rstrip('UuLl'))
expr = p.parse().simplify()
assert expr.op == Ops.CONST, f"loop bound must be constant, got {expr}"
@ -833,38 +876,41 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
# Execute loop with break support
has_break = any('break' in bl.lower() for bl in body_lines)
found_var = f'_found_{id(body_lines)}' if has_break else None
if found_var: vars[found_var] = block_assigns[found_var] = _const(dtypes.bool, False)
if found_var: env[found_var] = block_assigns[found_var] = _const(dtypes.bool, False)
for loop_i in range(start_val, end_val + 1):
subst_lines = [_subst_loop_var(bl, loop_var, loop_i) for bl in body_lines if not (has_break and bl.strip().lower() == 'break')]
_, iter_assigns, _ = parse_block(subst_lines, 0, {**vars, **block_assigns}, funcs, assigns)
_, iter_assigns, _ = parse_block(subst_lines, 0, {**env, **block_assigns}, funcs, assigns)
if has_break:
assert found_var is not None
found = block_assigns.get(found_var, vars.get(found_var))
found = block_assigns.get(found_var, env.get(found_var))
assert isinstance(found, UOp)
not_found = found.eq(_const(dtypes.bool, False))
for var, val in iter_assigns.items():
if var != found_var and isinstance(val, UOp):
old = block_assigns.get(var, vars.get(var, _u32(0)))
old = block_assigns.get(var, env.get(var, _u32(0)))
if isinstance(old, UOp):
block_assigns[var] = vars[var] = not_found.where(val, old.cast(val.dtype) if val.dtype != old.dtype and val.dtype.itemsize == old.dtype.itemsize else old)
block_assigns[var] = env[var] = not_found.where(
val, old.cast(val.dtype) if val.dtype != old.dtype and val.dtype.itemsize == old.dtype.itemsize else old)
for j, bl in enumerate(body_lines):
bl_l = bl.strip().lower()
if bl_l.startswith('if ') and bl_l.endswith(' then'):
if any(body_lines[k].strip().lower() == 'break' for k in range(j+1, len(body_lines))):
cond_str = _subst_loop_var(bl.strip()[3:-5].strip(), loop_var, loop_i)
cond = _to_bool(parse_expr(cond_str, vars, funcs))
block_assigns[found_var] = vars[found_var] = not_found.where(cond, found)
cond = _to_bool(parse_expr(cond_str, env, funcs))
block_assigns[found_var] = env[found_var] = not_found.where(cond, found)
break
else:
block_assigns.update(iter_assigns); vars.update(iter_assigns)
block_assigns.update(iter_assigns)
env.update(iter_assigns)
continue
# declare
if first == 'declare':
# Initialize scalar declarations (skip arrays and vars already passed as srcs)
# Initialize scalar declarations (skip arrays and env already passed as srcs)
if '[' not in line and len(toks) >= 2 and toks[1].type == 'IDENT':
vars.setdefault(toks[1].val, _u32(0))
i += 1; continue
env.setdefault(toks[1].val, _u32(0))
i += 1
continue
# lambda definition
if first != '{' and '=' in line and 'lambda' in line and any(t.type == 'IDENT' and t.val == 'lambda' for t in toks):
@ -886,26 +932,30 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
if ch == '(': depth += 1
elif ch == ')':
depth -= 1
if depth == 0: body_lines_lst.append(lines[i][:j]); break
if depth == 0:
body_lines_lst.append(lines[i][:j])
break
else: body_lines_lst.append(lines[i])
i += 1
body = '\n'.join(body_lines_lst).strip()
vars[name] = ('lambda', params, body)
env[name] = ('lambda', params, body)
continue
# MEM assignment: MEM[addr].type (+|-)?= value
if first == 'mem' and toks[1].type == 'LBRACKET':
j, addr_toks = _match_bracket(toks, 1)
addr = parse_tokens(addr_toks, vars, funcs)
addr = parse_tokens(addr_toks, env, funcs)
if j < len(toks) and toks[j].type == 'DOT': j += 1
dt_name = toks[j].val if j < len(toks) and toks[j].type == 'IDENT' else 'u32'
dt, j = DTYPES.get(dt_name, dtypes.uint32), j + 1
compound_op = None
if j < len(toks) and toks[j].type == 'ASSIGN_OP': compound_op = toks[j].val; j += 1
if j < len(toks) and toks[j].type == 'ASSIGN_OP':
compound_op = toks[j].val
j += 1
elif j < len(toks) and toks[j].type == 'EQUALS': j += 1
rhs = parse_tokens(toks[j:], vars, funcs)
rhs = parse_tokens(toks[j:], env, funcs)
if compound_op:
mem = vars.get('_vmem') if '_vmem' in vars else vars.get('_lds')
mem = env.get('_vmem') if '_vmem' in env else env.get('_lds')
if isinstance(mem, UOp):
adt = dtypes.uint64 if addr.dtype == dtypes.uint64 else dtypes.uint32
idx = (addr >> _const(adt, 2)).cast(dtypes.int)
@ -914,7 +964,8 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
old = old.cast(dtypes.uint64) | (mem.index(((addr + _const(adt, 4)) >> _const(adt, 2)).cast(dtypes.int)).cast(dtypes.uint64) << _u64(32))
rhs = (old + rhs) if compound_op == '+=' else (old - rhs)
if assigns is not None: assigns.append((f'MEM[{_tok_str(addr_toks)}].{dt_name}', (addr, rhs)))
i += 1; continue
i += 1
continue
# VGPR assignment: VGPR[lane][reg] = value
if first == 'vgpr' and toks[1].type == 'LBRACKET':
@ -923,9 +974,12 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
j, reg_toks = _match_bracket(toks, j)
if j < len(toks) and toks[j].type == 'DOT': j += 2 # skip .type suffix
if j < len(toks) and toks[j].type == 'EQUALS': j += 1
ln, rg, val = parse_tokens(lane_toks, vars, funcs), parse_tokens(reg_toks, vars, funcs), parse_tokens(toks[j:], vars, funcs)
if assigns is not None: assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}]', (_to_u32(rg) * _u32(32) + _to_u32(ln), val)))
i += 1; continue
ln = parse_tokens(lane_toks, env, funcs)
rg, val = parse_tokens(reg_toks, env, funcs), parse_tokens(toks[j:], env, funcs)
if assigns is not None:
assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}]', (_to_u32(rg) * _u32(32) + _to_u32(ln), val)))
i += 1
continue
# Compound destination: {hi.type, lo.type} = value
if first == '{':
@ -939,18 +993,20 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
j += 3
if j < len(toks) and toks[j].type == 'RBRACE': j += 1
if j < len(toks) and toks[j].type == 'EQUALS': j += 1
val = parse_tokens(toks[j:], vars, funcs)
val = parse_tokens(toks[j:], env, funcs)
lo_dt, hi_dt = DTYPES.get(lo_type, dtypes.uint64), DTYPES.get(hi_type, dtypes.uint32)
lo_bits = 64 if lo_dt in (dtypes.uint64, dtypes.int64) else 32
lo_val = val.cast(lo_dt) if val.dtype.itemsize * 8 <= lo_bits else (val & _const(val.dtype, (1 << lo_bits) - 1)).cast(lo_dt)
hi_val = (val >> _const(val.dtype, lo_bits)).cast(hi_dt)
block_assigns[lo_var] = vars[lo_var] = lo_val
block_assigns[hi_var] = vars[hi_var] = hi_val
block_assigns[lo_var] = env[lo_var] = lo_val
block_assigns[hi_var] = env[hi_var] = hi_val
if assigns is not None: assigns.extend([(f'{lo_var}.{lo_type}', lo_val), (f'{hi_var}.{hi_type}', hi_val)])
i += 1; continue
i += 1
continue
# Bit slice/index: var[hi:lo] = value, var.type[hi:lo] = value, or var[expr] = value
if len(toks) >= 5 and toks[0].type == 'IDENT' and (toks[1].type == 'LBRACKET' or (toks[1].type == 'DOT' and toks[3].type == 'LBRACKET')):
if len(toks) >= 5 and toks[0].type == 'IDENT' and \
(toks[1].type == 'LBRACKET' or (toks[1].type == 'DOT' and toks[3].type == 'LBRACKET')):
bracket_start = 2 if toks[1].type == 'LBRACKET' else 4
j = bracket_start
colon_pos = None
@ -967,23 +1023,28 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
j += 1
if j < len(toks) and toks[j].type == 'DOT': j += 2
if j < len(toks) and toks[j].type == 'EQUALS': j += 1
val = parse_tokens(toks[j:], vars, funcs)
val = parse_tokens(toks[j:], env, funcs)
dt_suffix = toks[2].val if toks[1].type == 'DOT' else None
if assigns is not None: assigns.append((f'{var}[{hi}:{lo}]' + (f'.{dt_suffix}' if dt_suffix else ''), val))
if var not in vars: vars[var] = _const(dtypes.uint64 if hi >= 32 else dtypes.uint32, 0)
old = block_assigns.get(var, vars.get(var))
block_assigns[var] = vars[var] = _set_bits(old, _val_to_bits(val), hi - lo + 1, lo)
i += 1; continue
except: pass
if var not in env: env[var] = _const(dtypes.uint64 if hi >= 32 else dtypes.uint32, 0)
old = block_assigns.get(var, env.get(var))
assert isinstance(old, UOp)
block_assigns[var] = env[var] = _set_bits(old, _val_to_bits(val), hi - lo + 1, lo)
i += 1
continue
except Exception: pass
elif toks[1].type == 'LBRACKET': # bit index: var[expr] (only for var[...], not var.type[...])
existing = block_assigns.get(var, vars.get(var))
if existing is not None and isinstance(existing, UOp) and not any(f'{var}{k}' in vars or f'{var}{k}' in block_assigns for k in range(8)):
existing = block_assigns.get(var, env.get(var))
if existing is not None and isinstance(existing, UOp) and \
not any(f'{var}{k}' in env or f'{var}{k}' in block_assigns for k in range(8)):
bit_toks = toks[2:j]
j += 1
while j < len(toks) and toks[j].type != 'EQUALS': j += 1
if j < len(toks):
block_assigns[var] = vars[var] = _set_bit(existing, _to_u32(parse_tokens(bit_toks, vars, funcs)), parse_tokens(toks[j+1:], vars, funcs))
i += 1; continue
block_assigns[var] = env[var] = _set_bit(
existing, _to_u32(parse_tokens(bit_toks, env, funcs)), parse_tokens(toks[j+1:], env, funcs))
i += 1
continue
# Array element: var[idx] = value (static index) or var[expr] = value (dynamic)
if len(toks) >= 4 and toks[0].type == 'IDENT' and toks[1].type == 'LBRACKET':
@ -993,80 +1054,90 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
# Static index: var[NUM] = value
if len(idx_toks) == 1 and idx_toks[0].type == 'NUM':
idx = int(idx_toks[0].val.rstrip('UuLl'))
val = parse_tokens(toks[j+1:], vars, funcs)
existing = block_assigns.get(var, vars.get(var))
val = parse_tokens(toks[j+1:], env, funcs)
existing = block_assigns.get(var, env.get(var))
if existing is not None and isinstance(existing, UOp):
block_assigns[var] = vars[var] = _set_bit(existing, _u32(idx), val)
block_assigns[var] = env[var] = _set_bit(existing, _u32(idx), val)
else:
block_assigns[f'{var}@{idx}'] = vars[f'{var}@{idx}'] = val
i += 1; continue
block_assigns[f'{var}@{idx}'] = env[f'{var}@{idx}'] = val
i += 1
continue
# Dynamic index: var[expr] = value where var has @-elements
elems = [(k.split('@')[1], v) for k, v in {**vars, **block_assigns}.items() if k.startswith(f'{var}@') and isinstance(v, UOp)]
elems = [(k.split('@')[1], v) for k, v in {**env, **block_assigns}.items() if k.startswith(f'{var}@') and isinstance(v, UOp)]
if elems:
idx_expr = parse_tokens(idx_toks, vars, funcs)
val = parse_tokens(toks[j+1:], vars, funcs)
idx_expr = parse_tokens(idx_toks, env, funcs)
val = parse_tokens(toks[j+1:], env, funcs)
for elem_idx_str, old_elem in elems:
elem_idx = int(elem_idx_str)
cond = _to_u32(idx_expr).eq(_u32(elem_idx))
new_val = cond.where(val.cast(old_elem.dtype) if val.dtype != old_elem.dtype else val, old_elem)
block_assigns[f'{var}@{elem_idx}'] = vars[f'{var}@{elem_idx}'] = new_val
i += 1; continue
block_assigns[f'{var}@{elem_idx}'] = env[f'{var}@{elem_idx}'] = new_val
i += 1
continue
# Compound assignment: var += or var -=
assign_op = next((j for j, t in enumerate(toks) if t.type == 'ASSIGN_OP'), None)
if assign_op is not None:
var = toks[0].val
old = block_assigns.get(var, vars.get(var, _u32(0)))
rhs = parse_tokens(toks[assign_op+1:], vars, funcs)
old = block_assigns.get(var, env.get(var, _u32(0)))
rhs = parse_tokens(toks[assign_op+1:], env, funcs)
if rhs.dtype != old.dtype: rhs = rhs.cast(old.dtype)
block_assigns[var] = vars[var] = (old + rhs) if toks[assign_op].val == '+=' else (old - rhs)
i += 1; continue
block_assigns[var] = env[var] = (old + rhs) if toks[assign_op].val == '+=' else (old - rhs)
i += 1
continue
# Typed element: var.type[idx] = value
if len(toks) >= 7 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET' and toks[4].type == 'NUM':
if len(toks) >= 7 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and \
toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET' and toks[4].type == 'NUM':
var, dt_name, idx = toks[0].val, toks[2].val, int(toks[4].val)
dt = DTYPES.get(dt_name, dtypes.uint32)
j = 6
while j < len(toks) and toks[j].type != 'EQUALS': j += 1
if j < len(toks):
val, old = parse_tokens(toks[j+1:], vars, funcs), block_assigns.get(var, vars.get(var, _u32(0)))
val, old = parse_tokens(toks[j+1:], env, funcs), block_assigns.get(var, env.get(var, _u32(0)))
bw = dt.itemsize * 8
block_assigns[var] = vars[var] = _set_bits(old, val, bw, idx * bw)
block_assigns[var] = env[var] = _set_bits(old, val, bw, idx * bw)
if assigns is not None: assigns.append((f'{var}.{dt_name}[{idx}]', val))
i += 1; continue
i += 1
continue
# Dynamic bit: var.type[expr_with_brackets] = value
if len(toks) >= 5 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET':
if len(toks) >= 5 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and \
toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET':
j, depth, has_inner = 4, 1, False
while j < len(toks) and depth > 0:
if toks[j].type == 'LBRACKET': depth += 1; has_inner = True
if toks[j].type == 'LBRACKET':
depth += 1
has_inner = True
elif toks[j].type == 'RBRACKET': depth -= 1
j += 1
if has_inner:
var = toks[0].val
bit_pos = _to_u32(parse_tokens(toks[4:j-1], vars, funcs))
bit_pos = _to_u32(parse_tokens(toks[4:j-1], env, funcs))
while j < len(toks) and toks[j].type != 'EQUALS': j += 1
if j < len(toks):
val = parse_tokens(toks[j+1:], vars, funcs)
old = block_assigns.get(var, vars.get(var, _u32(0)))
block_assigns[var] = vars[var] = _set_bit(old, bit_pos, val)
i += 1; continue
val = parse_tokens(toks[j+1:], env, funcs)
old = block_assigns.get(var, env.get(var, _u32(0)))
block_assigns[var] = env[var] = _set_bit(old, bit_pos, val)
i += 1
continue
# If/elsif/else - skip branches with statically false conditions (WAVE32/WAVE64)
if first == 'if':
def parse_cond(s, kw):
ll = s.lower()
return _to_bool(parse_expr(s[ll.find(kw) + len(kw):ll.rfind('then')].strip(), vars, funcs))
return _to_bool(parse_expr(s[ll.find(kw) + len(kw):ll.rfind('then')].strip(), env, funcs))
def is_const(c, v): return c.op == Ops.CONST and c.arg is v
cond = parse_cond(line, 'if')
conditions: list[tuple[UOp, UOp | dict[str, VarVal] | None]] = [(cond, None)] if not is_const(cond, False) else []
else_branch: tuple[UOp | None, dict[str, VarVal]] = (None, {})
vars_snap = dict(vars)
env_snap = dict(env)
static_true = is_const(cond, True) # track if any condition is statically true
i += 1
i, branch, ret = parse_block(lines, i, vars, funcs, assigns if not is_const(cond, False) else None)
i, branch, ret = parse_block(lines, i, env, funcs, assigns if not is_const(cond, False) else None)
if conditions: conditions[0] = (cond, ret if ret is not None else branch)
vars.clear(); vars.update(vars_snap)
env.clear()
env.update(env_snap)
while i < len(lines):
ltoks = tokenize(lines[i])
if ltoks[0].type != 'IDENT': break
@ -1074,17 +1145,22 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
if lf == 'elsif':
c = parse_cond(lines[i], 'elsif')
take = not static_true and not is_const(c, False)
i += 1; i, branch, ret = parse_block(lines, i, vars, funcs, assigns if take else None)
i += 1
i, branch, ret = parse_block(lines, i, env, funcs, assigns if take else None)
if take:
conditions.append((c, ret if ret is not None else branch))
if is_const(c, True): static_true = True
vars.clear(); vars.update(vars_snap)
env.clear()
env.update(env_snap)
elif lf == 'else':
i += 1
i, branch, ret = parse_block(lines, i, vars, funcs, assigns if not static_true else None)
i, branch, ret = parse_block(lines, i, env, funcs, assigns if not static_true else None)
if not static_true: else_branch = (ret, branch)
vars.clear(); vars.update(vars_snap)
elif lf == 'endif': i += 1; break
env.clear()
env.update(env_snap)
elif lf == 'endif':
i += 1
break
else: break
# Check if any branch returned a value (lambda-style)
if any(isinstance(br, UOp) for _, br in conditions):
@ -1097,18 +1173,19 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
# If statically true, use that branch directly; otherwise merge with WHERE
if static_true:
ba = next((b for c, b in conditions if is_const(c, True) and isinstance(b, dict)), {})
block_assigns.update(ba); vars.update(ba)
block_assigns.update(ba)
env.update(ba)
else:
else_assigns = else_branch[1]
all_vars = set().union(*[ba.keys() for _, ba in conditions if isinstance(ba, dict)], else_assigns.keys())
for var in all_vars:
res: Any = else_assigns.get(var, block_assigns.get(var, vars.get(var, _u32(0))))
for cond, ba in reversed(conditions):
res: Any = else_assigns.get(var, block_assigns.get(var, env.get(var, _u32(0))))
for cond, ba in reversed(conditions): # type: ignore[assignment]
if isinstance(ba, dict) and var in ba:
tv = ba[var]
if isinstance(tv, UOp) and isinstance(res, UOp):
res = cond.where(tv, res.cast(tv.dtype) if tv.dtype != res.dtype and tv.dtype.itemsize == res.dtype.itemsize else res)
block_assigns[var] = vars[var] = res
block_assigns[var] = env[var] = res
continue
# Regular assignment: var = value
@ -1116,11 +1193,12 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
if t.type == 'EQUALS':
if any(toks[k].type == 'OP' and toks[k].val in ('<', '>', '!', '=') for k in range(j)): break
base_var = toks[0].val
block_assigns[base_var] = vars[base_var] = parse_tokens(toks[j+1:], vars, funcs)
i += 1; break
block_assigns[base_var] = env[base_var] = parse_tokens(toks[j+1:], env, funcs)
i += 1
break
else: i += 1
return i, block_assigns, None
def parse_expr(expr: str, vars: dict[str, VarVal], funcs: dict | None = None) -> UOp:
return parse_tokens(tokenize(expr.strip().rstrip(';')), vars, funcs)
def parse_expr(expr: str, env: dict[str, VarVal], funcs: dict | None = None) -> UOp:
return parse_tokens(tokenize(expr.strip().rstrip(';')), env, funcs)

View file

@ -125,8 +125,8 @@ class PacketType:
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls._fields = {k: v for k, v in cls.__dict__.items() if isinstance(v, BitField)}
cls._size_nibbles = ((max((f.hi for f in cls._fields.values()), default=0) + 4) // 4)
cls._fields = {k: v for k, v in cls.__dict__.items() if isinstance(v, BitField)} # type: ignore[attr-defined]
cls._size_nibbles = ((max((f.hi for f in cls._fields.values()), default=0) + 4) // 4) # type: ignore[attr-defined]
@classmethod
def from_raw(cls, raw: int, time: int = 0):
@ -135,7 +135,7 @@ class PacketType:
return inst
def __repr__(self) -> str:
fields_str = ", ".join(f"{k}={getattr(self, k)}" for k in self._fields if not k.startswith('_') and k != 'encoding')
fields_str = ", ".join(f"{k}={getattr(self, k)}" for k in self._fields if not k.startswith('_') and k != 'encoding') # type: ignore[attr-defined]
return f"{self.__class__.__name__}({fields_str})"
# ═══════════════════════════════════════════════════════════════════════════════
@ -514,7 +514,7 @@ def _build_decode_tables(packet_types: dict[int, type[PacketType]]) -> tuple[dic
for opcode, pkt_cls in packet_types.items():
delta_field = getattr(pkt_cls, 'delta', None)
special = _special.get(pkt_cls, 0)
decode_info[opcode] = (pkt_cls, pkt_cls._size_nibbles, delta_field.lo if delta_field else 0, delta_field.mask if delta_field else 0, special)
decode_info[opcode] = (pkt_cls, pkt_cls._size_nibbles, delta_field.lo if delta_field else 0, delta_field.mask if delta_field else 0, special) # type: ignore[attr-defined]
return decode_info, state_table
_DECODE_INFO_RDNA3, _STATE_TABLE_RDNA3 = _build_decode_tables(PACKET_TYPES_RDNA3)

View file

@ -158,7 +158,7 @@ def get_tinygrad_kernel(op_name: str) -> tuple[bytes, tuple, tuple, list[int], d
for i, buf in enumerate(lowered.bufs):
if hasattr(buf, 'base') and buf.base is not None and hasattr(buf.base, '_buf'):
try: buf_data[i] = bytes(buf.base._buf)
except: pass
except Exception: pass
# Extract rsrc2 from ELF (same as ops_amd.py)
group_segment_size = image[rodata_entry:rodata_entry+4].cast("I")[0]
lds_size = ((group_segment_size + 511) // 512) & 0x1FF
@ -232,7 +232,8 @@ def main():
total_work = n_insts * n_workgroups * n_threads
print(f"{n_insts} insts ({n_compiled} unique) × {n_workgroups} WGs × {n_threads} threads = {total_work:,} ops")
rust_time = benchmark_emulator("Rust", rust_remu.run_asm, kernel, global_size, local_size, args_ptr, rsrc2, args.iterations) if rust_remu else None
rust_time = benchmark_emulator("Rust", rust_remu.run_asm, kernel, global_size, local_size,
args_ptr, rsrc2, args.iterations) if rust_remu else None
if py_compile is not None:
py_exec_rate = total_work / py_exec / 1e6

View file

@ -1,14 +1,16 @@
# RDNA3/RDNA4/CDNA disassembler
from __future__ import annotations
import re, struct
import re
from typing import Callable
from extra.assembly.amd.dsl import Inst, Reg
# Special register mappings for disassembly
SPECIAL_GPRS = {106: 'vcc_lo', 107: 'vcc_hi', 124: 'null', 125: 'm0', 126: 'exec_lo', 127: 'exec_hi',
128: '0', 240: '0.5', 241: '-0.5', 242: '1.0', 243: '-1.0', 244: '2.0', 245: '-2.0', 246: '4.0', 247: '-4.0', 248: '0x3e22f983', 253: 'scc'}
128: '0', 240: '0.5', 241: '-0.5', 242: '1.0', 243: '-1.0', 244: '2.0', 245: '-2.0',
246: '4.0', 247: '-4.0', 248: '0x3e22f983', 253: 'scc'}
SPECIAL_GPRS_CDNA = {106: 'vcc_lo', 107: 'vcc_hi', 124: 'm0', 126: 'exec_lo', 127: 'exec_hi',
128: '0', 240: '0.5', 241: '-0.5', 242: '1.0', 243: '-1.0', 244: '2.0', 245: '-2.0', 246: '4.0', 247: '-4.0', 248: '0x3e22f983', 253: 'scc',
128: '0', 240: '0.5', 241: '-0.5', 242: '1.0', 243: '-1.0', 244: '2.0', 245: '-2.0',
246: '4.0', 247: '-4.0', 248: '0x3e22f983', 253: 'scc',
102: 'flat_scratch_lo', 103: 'flat_scratch_hi', 104: 'xnack_mask_lo', 105: 'xnack_mask_hi',
251: 'src_vccz', 252: 'src_execz'}
SPECIAL_PAIRS = {106: 'vcc', 126: 'exec'}
@ -70,7 +72,9 @@ def _num_srcs(inst) -> int:
if any(x in n for x in ('FMA', 'MAD', 'CNDMASK', 'BFE', 'BFI', 'LERP', 'MED3', 'SAD', 'DIV_FMAS', 'DIV_FIXUP', 'DIV_SCALE', 'CUBE')): return 3
# PERMLANE_VAR ops are 2-source, but PERMLANE (non-VAR) are 3-source
if 'PERMLANE' in n and '_VAR' not in n: return 3
if any(x in n for x in ('_ADD3', '_LSHL_ADD', '_ADD_LSHL', '_LSHL_OR', '_AND_OR', 'OR3_B32', 'AND_OR_B32', 'ALIGNBIT', 'ALIGNBYTE', 'V_PERM_', 'XOR3', 'XAD', 'MULLIT', 'MINMAX', 'MAXMIN', 'MINIMUMMAXIMUM', 'MAXIMUMMINIMUM', 'MINIMUM3', 'MAXIMUM3', 'MIN3', 'MAX3', 'DOT2', 'CVT_PK_U8_F32', 'DOT4', 'DOT8', 'WMMA', 'SWMMAC')): return 3
if any(x in n for x in ('_ADD3', '_LSHL_ADD', '_ADD_LSHL', '_LSHL_OR', '_AND_OR', 'OR3_B32', 'AND_OR_B32', 'ALIGNBIT',
'ALIGNBYTE', 'V_PERM_', 'XOR3', 'XAD', 'MULLIT', 'MINMAX', 'MAXMIN', 'MINIMUMMAXIMUM', 'MAXIMUMMINIMUM',
'MINIMUM3', 'MAXIMUM3', 'MIN3', 'MAX3', 'DOT2', 'CVT_PK_U8_F32', 'DOT4', 'DOT8', 'WMMA', 'SWMMAC')): return 3
return 2
# ═══════════════════════════════════════════════════════════════════════════════
@ -80,13 +84,14 @@ def _num_srcs(inst) -> int:
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP1_SDST, VOP1_SDST_LIT, VOP1_LIT, VOP2, VOP2_LIT, VOP3, VOP3_SDST, VOP3_SDST_LIT,
VOP3_LIT, VOP3SD, VOP3SD_LIT, VOP3P, VOP3P_LIT, VOPC, VOPC_LIT, VOPD, VOPD_LIT, VINTERP, SOP1, SOP1_LIT, SOP2, SOP2_LIT, SOPC, SOPC_LIT,
SOPK, SOPK_LIT, SOPP, SMEM, DS, FLAT, GLOBAL, SCRATCH, VOP2Op, VOPDOp, SOPPOp, HWREG, MSG)
from extra.assembly.amd.autogen.rdna4.ins import (VOP1 as R4_VOP1, VOP1_SDST as R4_VOP1_SDST, VOP1_SDST_LIT as R4_VOP1_SDST_LIT, VOP1_LIT as R4_VOP1_LIT,
from extra.assembly.amd.autogen.rdna4.ins import (VOP1 as R4_VOP1, VOP1_SDST as R4_VOP1_SDST,
VOP1_SDST_LIT as R4_VOP1_SDST_LIT, VOP1_LIT as R4_VOP1_LIT,
VOP2 as R4_VOP2, VOP2_LIT as R4_VOP2_LIT, VOP3 as R4_VOP3, VOP3_SDST as R4_VOP3_SDST, VOP3_SDST_LIT as R4_VOP3_SDST_LIT, VOP3_LIT as R4_VOP3_LIT,
VOP3SD as R4_VOP3SD, VOP3SD_LIT as R4_VOP3SD_LIT, VOP3P as R4_VOP3P, VOP3P_LIT as R4_VOP3P_LIT, VOPC as R4_VOPC, VOPC_LIT as R4_VOPC_LIT,
VOPD as R4_VOPD, VOPD_LIT as R4_VOPD_LIT, VINTERP as R4_VINTERP, SOP1 as R4_SOP1, SOP1_LIT as R4_SOP1_LIT, SOP2 as R4_SOP2, SOP2_LIT as R4_SOP2_LIT,
SOPC as R4_SOPC, SOPC_LIT as R4_SOPC_LIT, SOPK as R4_SOPK, SOPK_LIT as R4_SOPK_LIT, SOPP as R4_SOPP, SMEM as R4_SMEM, DS as R4_DS,
VOPDOp as R4_VOPDOp, HWREG as HWREG_RDNA4, VFLAT as R4_FLAT, VGLOBAL as R4_GLOBAL, VSCRATCH as R4_SCRATCH)
from extra.assembly.amd.autogen.cdna.ins import FLAT as C_FLAT, HWREG as HWREG_CDNA
from extra.assembly.amd.autogen.cdna.ins import HWREG as HWREG_CDNA
def _is_cdna(inst: Inst) -> bool: return 'cdna' in inst.__class__.__module__
def _is_r4(inst: Inst) -> bool: return 'rdna4' in inst.__class__.__module__
@ -100,9 +105,15 @@ _CDNA_DISASM_ALIASES = {'v_fmac_f64': 'v_mul_legacy_f32', 'v_dot2c_f32_bf16': 'v
def _reg(p: str, b: int, n: int = 1) -> str: return f"{p}{_unwrap(b)}" if n == 1 else f"{p}[{_unwrap(b)}:{_unwrap(b)+n-1}]"
def _sreg(b: int, n: int = 1) -> str: return _reg("s", _unwrap(b), n)
def _vreg(b: int, n: int = 1) -> str: b = _unwrap(b); return _reg("v", b - 256 if b >= 256 else b, n)
def _areg(b: int, n: int = 1) -> str: b = _unwrap(b); return _reg("a", b - 256 if b >= 256 else b, n) # accumulator registers for GFX90a
def _ttmp(b, n: int = 1) -> str | None: b = _unwrap(b); return _reg("ttmp", b - 108, n) if 108 <= b <= 123 else None
def _vreg(b: int, n: int = 1) -> str:
b = _unwrap(b)
return _reg("v", b - 256 if b >= 256 else b, n)
def _areg(b: int, n: int = 1) -> str:
b = _unwrap(b)
return _reg("a", b - 256 if b >= 256 else b, n) # accumulator registers for GFX90a
def _ttmp(b, n: int = 1) -> str | None:
b = _unwrap(b)
return _reg("ttmp", b - 108, n) if 108 <= b <= 123 else None
def _fmt_sdst(v, n: int = 1, cdna: bool = False) -> str:
v = _unwrap(v)
@ -130,7 +141,9 @@ def _fmt_v16(v, base: int = 256, hi_thresh: int = 384) -> str:
def _has(op: str, *subs) -> bool: return any(s in op for s in subs)
def _omod(v: int) -> str: return {1: " mul:2", 2: " mul:4", 3: " div:2"}.get(v, "")
def _src16(inst, v: int) -> str: v = _unwrap(v); return _fmt_v16(v) if v >= 256 else _lit(inst, v) # format 16-bit src: vgpr.h/l or literal
def _src16(inst, v: int) -> str:
v = _unwrap(v)
return _fmt_v16(v) if v >= 256 else _lit(inst, v) # format 16-bit src: vgpr.h/l or literal
def _mods(*pairs) -> str: return " ".join(m for c, m in pairs if c)
def _fmt_bits(label: str, val: int, count: int) -> str: return f"{label}:[{','.join(str((val >> i) & 1) for i in range(count))}]"
@ -201,7 +214,8 @@ def _disasm_vop2(inst: VOP2) -> str:
basename = name.replace('_e32', '')
if cdna and basename in _VOP2_CARRY_OUT: return f"{name}{suf} {inst.vdst.fmt()}, {vcc}, {_lit(inst, inst.src0)}, {inst.vsrc1.fmt()}"
if cdna and basename in _VOP2_CARRY_INOUT: return f"{name}{suf} {inst.vdst.fmt()}, {vcc}, {_lit(inst, inst.src0)}, {inst.vsrc1.fmt()}, {vcc}"
if not cdna and basename in _VOP2_CARRY_INOUT_RDNA: return f"{name}{suf} {inst.vdst.fmt()}, {vcc}, {_lit(inst, inst.src0)}, {inst.vsrc1.fmt()}, {vcc}"
if not cdna and basename in _VOP2_CARRY_INOUT_RDNA:
return f"{name}{suf} {inst.vdst.fmt()}, {vcc}, {_lit(inst, inst.src0)}, {inst.vsrc1.fmt()}, {vcc}"
sn0 = inst.canonical_op_regs.get('s0', 1)
if inst.vdst.sz > 1 or sn0 > 1 or inst.vsrc1.sz > 1:
src0 = _lit(inst, inst.src0) if inst.src0.offset == 255 else _fmt_src(inst.src0, sn0, cdna)
@ -217,7 +231,10 @@ def _disasm_vopc(inst: VOPC) -> str:
return f"{name} vcc, {s0}, {inst.vsrc1.fmt()}" # CDNA VOPC always outputs vcc
# RDNA: v_cmpx_* writes to exec (no vcc), v_cmp_* writes to vcc_lo
has_vcc = 'cmpx' not in name
s0 = _lit(inst, inst.src0) if inst.src0.offset == 255 else inst.src0.fmt() if inst.src0.sz > 1 else _src16(inst, inst.src0.offset) if is16 else _lit(inst, inst.src0)
if inst.src0.offset == 255: s0 = _lit(inst, inst.src0)
elif inst.src0.sz > 1: s0 = inst.src0.fmt()
elif is16: s0 = _src16(inst, inst.src0.offset)
else: s0 = _lit(inst, inst.src0)
s1 = inst.vsrc1.fmt() if inst.vsrc1.sz > 1 else _fmt_v16(inst.vsrc1) if is16 else inst.vsrc1.fmt()
suf = "" if name.endswith('_e32') else "_e32"
return f"{name}{suf} vcc_lo, {s0}, {s1}" if has_vcc else f"{name}{suf} {s0}, {s1}"
@ -253,10 +270,11 @@ def _disasm_sopp(inst: SOPP) -> str:
p = [f"vmcnt({vm})" if vm != 0x3f else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""]
return f"s_waitcnt {' '.join(x for x in p if x) or '0'}"
if name == 's_delay_alu':
deps = ['VALU_DEP_1','VALU_DEP_2','VALU_DEP_3','VALU_DEP_4','TRANS32_DEP_1','TRANS32_DEP_2','TRANS32_DEP_3','FMA_ACCUM_CYCLE_1','SALU_CYCLE_1','SALU_CYCLE_2','SALU_CYCLE_3']
deps = ['VALU_DEP_1','VALU_DEP_2','VALU_DEP_3','VALU_DEP_4','TRANS32_DEP_1','TRANS32_DEP_2',
'TRANS32_DEP_3','FMA_ACCUM_CYCLE_1','SALU_CYCLE_1','SALU_CYCLE_2','SALU_CYCLE_3']
skips = ['SAME','NEXT','SKIP_1','SKIP_2','SKIP_3','SKIP_4']
id0, skip, id1 = inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x7, (inst.simm16 >> 7) & 0xf
dep = lambda v: deps[v-1] if 0 < v <= len(deps) else str(v)
def dep(v): return deps[v-1] if 0 < v <= len(deps) else str(v)
p = [f"instid0({dep(id0)})" if id0 else "", f"instskip({skips[skip]})" if skip else "", f"instid1({dep(id1)})" if id1 else ""]
return f"s_delay_alu {' | '.join(x for x in p if x) or '0'}"
if name.startswith(('s_cbranch', 's_branch')): return f"{name} {inst.simm16}"
@ -267,7 +285,7 @@ def _disasm_smem(inst: SMEM) -> str:
if name in ('s_gl1_inv', 's_dcache_inv', 's_dcache_inv_vol', 's_dcache_wb', 's_dcache_wb_vol', 's_icache_inv'): return name
soe, imm = getattr(inst, 'soe', 0) or getattr(inst, 'soffset_en', 0), getattr(inst, 'imm', 1)
is_rdna4 = _is_r4(inst)
offset = inst.ioffset if is_rdna4 else getattr(inst, 'offset', 0)
offset = inst.ioffset if is_rdna4 else getattr(inst, 'offset', 0) # type: ignore[attr-defined]
if cdna:
if soe and imm: off_s = f"{decode_src(inst.soffset, cdna)} offset:0x{offset:x}"
elif imm: off_s = f"0x{offset:x}"
@ -278,7 +296,9 @@ def _disasm_smem(inst: SMEM) -> str:
else: off_s = decode_src(inst.soffset, cdna)
is_buffer = 'buffer' in name or 's_atc_probe_buffer' == name
sbase_idx, sbase_count = _unwrap(inst.sbase), 4 if is_buffer else 2
sbase_str = _fmt_src(sbase_idx, sbase_count, cdna) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count)
if sbase_count == 2: sbase_str = _fmt_src(sbase_idx, sbase_count, cdna)
elif sbase_idx <= 105: sbase_str = _sreg(sbase_idx, sbase_count)
else: sbase_str = _reg("ttmp", sbase_idx - 108, sbase_count)
if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {_unwrap(inst.sdata)}, {sbase_str}, {off_s}"
if 'prefetch' in name:
off = getattr(inst, 'ioffset', getattr(inst, 'offset', 0))
@ -312,7 +332,7 @@ def _disasm_flat(inst: FLAT) -> str:
else: seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat'
instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}"
# Global/scratch uses 13-bit signed offset
offset = inst.ioffset if r4 else inst.offset
offset = inst.ioffset if r4 else inst.offset # type: ignore[attr-defined]
if seg != 'flat':
if cdna:
# CDNA: bit 12 is sign bit but not in offset field
@ -327,19 +347,20 @@ def _disasm_flat(inst: FLAT) -> str:
regs = inst.canonical_op_regs
w = regs.get('data', regs.get('d', 1)) if 'store' in name or 'atomic' in name else regs.get('d', 1)
off_s = f" offset:{off_val}" if off_val else ""
if cdna: mods = f"{off_s}{' sc0' if inst.sc0 else ''}{' nt' if inst.nt else ''}{' sc1' if getattr(inst, 'sc1', 0) else ''}"
elif r4: mods = f"{off_s}{' scope' if inst.scope else ''}{' th' if inst.th else ''}"
if cdna: mods = f"{off_s}{' sc0' if inst.sc0 else ''}{' nt' if inst.nt else ''}{' sc1' if getattr(inst, 'sc1', 0) else ''}" # type: ignore[attr-defined]
elif r4: mods = f"{off_s}{' scope' if inst.scope else ''}{' th' if inst.th else ''}" # type: ignore[attr-defined]
else: mods = f"{off_s}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
if seg == 'flat': saddr_s = ""
elif _unwrap(inst.saddr) in (0x7F, 124): saddr_s = ", off"
elif seg == 'scratch': saddr_s = f", {decode_src(inst.saddr, cdna)}"
elif _unwrap(inst.saddr) in (SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS): saddr_s = f", {(SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS)[_unwrap(inst.saddr)]}"
elif _unwrap(inst.saddr) in (SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS):
saddr_s = f", {(SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS)[_unwrap(inst.saddr)]}"
elif t := _ttmp(inst.saddr, 2): saddr_s = f", {t}"
else: saddr_s = f", {_sreg(inst.saddr, 2) if _unwrap(inst.saddr) < 106 else decode_src(_unwrap(inst.saddr), cdna)}"
if 'addtid' in name: return f"{instr} {reg_fn(inst.data if 'store' in name else inst.vdst)}{saddr_s}{mods}"
# RDNA4: vaddr instead of addr, vsrc instead of data
addr = inst.vaddr if r4 else inst.addr
data = inst.vsrc if r4 else inst.data
# RDNA4: vaddr instead of addr, vsrc instead of data
addr = inst.vaddr if r4 else inst.addr # type: ignore[attr-defined]
data = inst.vsrc if r4 else inst.data # type: ignore[attr-defined]
# load_lds_* instructions: vaddr, saddr (no vdst, data goes to LDS)
if 'load_lds' in name:
addr_w = 1 if seg == 'scratch' or (_unwrap(inst.saddr) not in (0x7F, 124)) else 2
@ -351,13 +372,14 @@ def _disasm_flat(inst: FLAT) -> str:
addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(addr, addr_w)
data_s, vdst_s = reg_fn(data, w), reg_fn(inst.vdst, w // 2 if 'cmpswap' in name else w)
if 'atomic' in name:
glc_or_sc0 = inst.sc0 if cdna else inst.glc
return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if glc_or_sc0 else f"{instr} {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}"
glc_or_sc0 = inst.sc0 if cdna else inst.glc # type: ignore[attr-defined]
sfx = f"{saddr_s if seg != 'flat' else ''}{mods}"
return f"{instr} {vdst_s}, {addr_s}, {data_s}{sfx}" if glc_or_sc0 else f"{instr} {addr_s}, {data_s}{sfx}"
if 'store' in name: return f"{instr} {addr_s}, {data_s}{saddr_s}{mods}"
return f"{instr} {reg_fn(inst.vdst, w)}, {addr_s}{saddr_s}{mods}"
def _disasm_ds(inst: DS) -> str:
op, name = inst.op, inst.op_name.lower()
name = inst.op_name.lower()
acc = getattr(inst, 'acc', 0)
reg_fn = _areg if acc else _vreg
gds = " gds" if getattr(inst, 'gds', 0) else ""
@ -386,7 +408,8 @@ def _disasm_ds(inst: DS) -> str:
if 'write2' in name: return f"{name} {addr}, {d0}, {d1}{off2}{gds}"
if 'read2' in name: return f"{name} {reg_fn(inst.vdst, regs.get('d', 1))}, {addr}{off2}{gds}"
if 'xchg2' in name: return f"{name} {reg_fn(inst.vdst, regs.get('d', 1))}, {addr}, {d0}, {d1}{off2}{gds}"
if 'load' in name or ('read' in name and 'read2' not in name): return f"{name} {reg_fn(inst.vdst)}{off}{gds}" if 'addtid' in name else f"{name} {dst}, {addr}{off}{gds}"
if 'load' in name or ('read' in name and 'read2' not in name):
return f"{name} {reg_fn(inst.vdst)}{off}{gds}" if 'addtid' in name else f"{name} {dst}, {addr}{off}{gds}"
if ('store' in name or 'write' in name) and not _has(name, 'cmp', 'xchg', 'write2'):
return f"{name} {reg_fn(inst.data0)}{off}{gds}" if 'addtid' in name else f"{name} {addr}, {d0}{off}{gds}"
if 'swizzle' in name or name == 'ds_ordered_count': return f"{name} {reg_fn(inst.vdst)}, {addr}{off}{gds}"
@ -397,13 +420,15 @@ def _disasm_ds(inst: DS) -> str:
return f"{name} {dst}, {addr}, {d0}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}{off}{gds}"
def _disasm_vop3(inst: VOP3) -> str:
op, name = inst.op, inst.op_name.lower()
n_up = name.upper()
name = inst.op_name.lower()
bits = inst.canonical_op_bits
# RDNA4 v_s_* scalar VOP3 instructions - vdst is SGPR (VGPRField adds 256)
if name.startswith('v_s_'):
src = _lit(inst, inst.src0) if _unwrap(inst.src0) == 255 else ("src_scc" if _unwrap(inst.src0) == 253 else _fmt_src(inst.src0, max(1, bits['s0'] // 32)))
s0v = _unwrap(inst.src0)
if s0v == 255: src = _lit(inst, inst.src0)
elif s0v == 253: src = "src_scc"
else: src = _fmt_src(inst.src0, max(1, bits['s0'] // 32))
if inst.neg & 1: src = f"-{src}"
if inst.abs & 1: src = f"|{src}|"
clamp = getattr(inst, 'cm', None) or getattr(inst, 'clmp', 0)
@ -412,7 +437,6 @@ def _disasm_vop3(inst: VOP3) -> str:
# Use get_field_bits for register sizes and 16-bit detection
r0, r1, r2 = max(1, bits['s0'] // 32), max(1, bits['s1'] // 32), max(1, bits['s2'] // 32)
dn = max(1, bits['d'] // 32)
is16_d, is16_s, is16_s2 = bits['d'] == 16, bits['s0'] == 16, bits['s2'] == 16
s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, r0, is16_s)
@ -428,7 +452,8 @@ def _disasm_vop3(inst: VOP3) -> str:
clamp = getattr(inst, 'cm', None) or getattr(inst, 'clmp', 0)
cl, om = " clamp" if clamp else "", _omod(inst.omod)
nonvgpr_opsel = (inst.src0.offset < 256 and (inst.opsel & 1)) or (inst.src1.offset < 256 and (inst.opsel & 2)) or (inst.src2.offset < 256 and (inst.opsel & 4))
nonvgpr_opsel = ((inst.src0.offset < 256 and (inst.opsel & 1)) or (inst.src1.offset < 256 and (inst.opsel & 2))
or (inst.src2.offset < 256 and (inst.opsel & 4)))
need_opsel = nonvgpr_opsel or (inst.opsel and not is16_s)
op_val = inst.op.value if hasattr(inst.op, 'value') else inst.op
@ -478,7 +503,7 @@ def _disasm_vopd(inst: VOPD) -> str:
def _disasm_vop3p(inst: VOP3P) -> str:
name = inst.op_name.lower()
is_wmma, is_swmmac, n, is_fma_mix = 'wmma' in name, 'swmmac' in name, inst.num_srcs() or 2, 'fma_mix' in name
is_swmmac, n, is_fma_mix = 'swmmac' in name, inst.num_srcs() or 2, 'fma_mix' in name
def get_src(reg):
return _lit(inst, reg.offset) if reg.offset == 255 else reg.fmt()
src0, src1, src2, dst = get_src(inst.src0), get_src(inst.src1), get_src(inst.src2), inst.vdst.fmt()
@ -487,18 +512,22 @@ def _disasm_vop3p(inst: VOP3P) -> str:
if is_fma_mix:
def m(s, neg, abs_): return f"-{f'|{s}|' if abs_ else s}" if neg else (f"|{s}|" if abs_ else s)
src0, src1, src2 = m(src0, inst.neg & 1, inst.neg_hi & 1), m(src1, inst.neg & 2, inst.neg_hi & 2), m(src2, inst.neg & 4, inst.neg_hi & 4)
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi else []) + (["clamp"] if clamp else [])
mods = (([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else [])
+ ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi else []) + (["clamp"] if clamp else []))
elif is_swmmac:
mods = ([f"index_key:{inst.opsel}"] if inst.opsel else []) + ([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + \
([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if clamp else [])
else:
opsel_hi_default = 7 if n == 3 else 3
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != opsel_hi_default else []) + \
([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if clamp else [])
return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}"
mods = (([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else [])
+ ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != opsel_hi_default else [])
+ ([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else [])
+ ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if clamp else []))
mod_s = ' ' + ' '.join(mods) if mods else ''
return f"{name} {dst}, {src0}, {src1}, {src2}{mod_s}" if n == 3 else f"{name} {dst}, {src0}, {src1}{mod_s}"
def _disasm_sop1(inst: SOP1) -> str:
op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst)
name, cdna = inst.op_name.lower(), _is_cdna(inst)
# Use get_field_bits for register sizes
regs = inst.canonical_op_regs
dst_regs, src_regs = regs.get('d', 1), regs.get('s0', 1)
@ -512,8 +541,8 @@ def _disasm_sop1(inst: SOP1) -> str:
try: msg_str = MSG(v).name if v != 255 else None # MSG_RTN_ILLEGAL_MSG (255) not supported by LLVM
except ValueError: msg_str = None
return f"{name} {_fmt_sdst(inst.sdst, dst_regs)}, sendmsg({msg_str})" if msg_str else f"{name} {_fmt_sdst(inst.sdst, dst_regs)}, 0x{v:x}"
sop1_src_only = ('S_ALLOC_VGPR', 'S_SLEEP_VAR', 'S_BARRIER_SIGNAL', 'S_BARRIER_SIGNAL_ISFIRST', 'S_BARRIER_INIT', 'S_BARRIER_JOIN', 'S_SET_GPR_IDX_IDX',
'S_CBRANCH_JOIN')
sop1_src_only = ('S_ALLOC_VGPR', 'S_SLEEP_VAR', 'S_BARRIER_SIGNAL', 'S_BARRIER_SIGNAL_ISFIRST',
'S_BARRIER_INIT', 'S_BARRIER_JOIN', 'S_SET_GPR_IDX_IDX', 'S_CBRANCH_JOIN')
if inst.op_name in sop1_src_only: return f"{name} {src}"
if cdna:
if 'getpc_b64' in name: return f"{name} {_fmt_sdst(inst.sdst, 2, cdna)}"
@ -551,7 +580,7 @@ _HWREG_BLACKLIST_CDNA = {'HW_REG_PC_LO', 'HW_REG_PC_HI', 'HW_REG_IB_DBG1', 'HW_R
'HW_REG_SQ_SHADER_TMA_LO', 'HW_REG_SQ_SHADER_TMA_HI', 'HW_REG_SQ_PERF_SNAPSHOT_DATA', 'HW_REG_SQ_PERF_SNAPSHOT_DATA1',
'HW_REG_SQ_PERF_SNAPSHOT_PC_LO', 'HW_REG_SQ_PERF_SNAPSHOT_PC_HI', 'HW_REG_XCC_ID'}
def _disasm_sopk(inst: SOPK) -> str:
op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst)
name, cdna = inst.op_name.lower(), _is_cdna(inst)
is_rdna4 = _is_r4(inst)
hw = HWREG_CDNA if cdna else (HWREG_RDNA4 if is_rdna4 else HWREG)
blacklist = _HWREG_BLACKLIST_CDNA if cdna else _HWREG_BLACKLIST
@ -574,12 +603,14 @@ def _disasm_sopk(inst: SOPK) -> str:
def _disasm_vinterp(inst: VINTERP) -> str:
mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp"))
return f"{inst.op_name.lower()} {inst.vdst.fmt()}, {_lit(inst, inst.src0, inst.neg & 1)}, {_lit(inst, inst.src1, inst.neg & 2)}, {_lit(inst, inst.src2, inst.neg & 4)}" + (" " + mods if mods else "")
s0, s1, s2 = _lit(inst, inst.src0, inst.neg & 1), _lit(inst, inst.src1, inst.neg & 2), _lit(inst, inst.src2, inst.neg & 4)
return f"{inst.op_name.lower()} {inst.vdst.fmt()}, {s0}, {s1}, {s2}" + (" " + mods if mods else "")
DISASM_HANDLERS: dict[type, Callable[..., str]] = {
VOP1: _disasm_vop1, VOP1_SDST: _disasm_vop1, VOP1_SDST_LIT: _disasm_vop1, VOP1_LIT: _disasm_vop1,
VOP2: _disasm_vop2, VOP2_LIT: _disasm_vop2, VOPC: _disasm_vopc, VOPC_LIT: _disasm_vopc,
VOP3: _disasm_vop3, VOP3_SDST: _disasm_vop3, VOP3_SDST_LIT: _disasm_vop3, VOP3_LIT: _disasm_vop3, VOP3SD: _disasm_vop3sd, VOP3SD_LIT: _disasm_vop3sd,
VOP3: _disasm_vop3, VOP3_SDST: _disasm_vop3, VOP3_SDST_LIT: _disasm_vop3, VOP3_LIT: _disasm_vop3,
VOP3SD: _disasm_vop3sd, VOP3SD_LIT: _disasm_vop3sd,
VOPD: _disasm_vopd, VOPD_LIT: _disasm_vopd, VOP3P: _disasm_vop3p, VOP3P_LIT: _disasm_vop3p,
VINTERP: _disasm_vinterp, SOPP: _disasm_sopp, SMEM: _disasm_smem, DS: _disasm_ds, FLAT: _disasm_flat, GLOBAL: _disasm_flat, SCRATCH: _disasm_flat,
SOP1: _disasm_sop1, SOP1_LIT: _disasm_sop1, SOP2: _disasm_sop2, SOP2_LIT: _disasm_sop2,
@ -634,7 +665,9 @@ def _disasm_vop3a(inst) -> str:
else:
regs = inst.canonical_op_regs
dregs, r0, r1, r2 = regs['d'], regs['s0'], regs['s1'], regs['s2']
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, r0), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, r1), _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, r2)
s0 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, r0)
s1 = _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, r1)
s2 = _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, r2)
dst = _vreg(inst.vdst, dregs) if dregs > 1 else _vreg(inst.vdst)
if op_val >= 512:
return f"{name} {dst}, {s0}, {s1}, {s2}{opsel}{cl}{om}" if n == 3 else f"{name} {dst}, {s0}, {s1}{opsel}{cl}{om}"
@ -658,7 +691,9 @@ def _disasm_vop3b(inst) -> str:
n = inst.num_srcs() or _num_srcs(inst)
regs = inst.canonical_op_regs
dregs, r0, r1, r2 = regs['d'], regs['s0'], regs['s1'], regs['s2']
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, n=r0), _cdna_src(inst, inst.src1, inst.neg&2, n=r1), _cdna_src(inst, inst.src2, inst.neg&4, n=r2)
s0 = _cdna_src(inst, inst.src0, inst.neg&1, n=r0)
s1 = _cdna_src(inst, inst.src1, inst.neg&2, n=r1)
s2 = _cdna_src(inst, inst.src2, inst.neg&4, n=r2)
# CDNA VOP3_SDST uses vdst field for sdst (but vdst adds 256), RDNA uses separate sdst field
sdst_val = getattr(inst, 'sdst', None)
if sdst_val is None and hasattr(inst, 'vdst'):
@ -680,7 +715,7 @@ def _disasm_cdna_vop3p(inst) -> str:
name, n = inst.op_name.lower(), inst.num_srcs() or 2
is_mfma = 'mfma' in name or 'smfmac' in name
is_accvgpr = 'accvgpr' in name
get_src = lambda v, sc: _lit(inst, v) if v == 255 else _fmt_src(v, sc, cdna=True)
def get_src(v, sc): return _lit(inst, v) if v == 255 else _fmt_src(v, sc, cdna=True)
# Handle accvgpr read/write (accumulator register operations)
if is_accvgpr:
@ -742,9 +777,12 @@ def _disasm_cdna_vop3p(inst) -> str:
src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), _vreg(inst.vdst)
opsel_hi = inst.opsel_hi # CDNA VOP3P only has 2 bits for opsel_hi (no opsel_hi2)
opsel_hi_default = 3 # CDNA default is 0b11 (2 bits), not 0b111 like RDNA
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != opsel_hi_default else []) + \
([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else [])
return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}"
mods = (([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else [])
+ ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != opsel_hi_default else [])
+ ([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else [])
+ ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else []))
mod_s = ' ' + ' '.join(mods) if mods else ''
return f"{name} {dst}, {src0}, {src1}, {src2}{mod_s}" if n == 3 else f"{name} {dst}, {src0}, {src1}{mod_s}"
def _disasm_mubuf(inst) -> str:
name = inst.op_name.lower()
@ -903,5 +941,6 @@ DISASM_HANDLERS.update({CDNA_VOP1: _disasm_vop1, CDNA_VOP1_LIT: _disasm_vop1,
CDNA_SOP1: _disasm_sop1, CDNA_SOP1_LIT: _disasm_sop1, CDNA_SOP2: _disasm_sop2, CDNA_SOP2_LIT: _disasm_sop2,
CDNA_SOPC: _disasm_sopc, CDNA_SOPC_LIT: _disasm_sopc, CDNA_SOPK: _disasm_sopk, CDNA_SOPK_LIT: _disasm_sopk, CDNA_SOPP: _disasm_sopp,
CDNA_SMEM: _disasm_smem, CDNA_DS: _disasm_ds, CDNA_FLAT: _disasm_flat, CDNA_GLOBAL: _disasm_flat, CDNA_SCRATCH: _disasm_flat,
CDNA_VOP3: _disasm_vop3a, CDNA_VOP3_SDST: _disasm_vop3b, CDNA_VOP3SD: _disasm_vop3b, CDNA_VOP3P: _disasm_cdna_vop3p, CDNA_VOP3P_MFMA: _disasm_cdna_vop3p,
CDNA_VOP3: _disasm_vop3a, CDNA_VOP3_SDST: _disasm_vop3b, CDNA_VOP3SD: _disasm_vop3b,
CDNA_VOP3P: _disasm_cdna_vop3p, CDNA_VOP3P_MFMA: _disasm_cdna_vop3p,
CDNA_MUBUF: _disasm_mubuf, CDNA_VOP3PX2: _disasm_vop3px2})

View file

@ -47,7 +47,7 @@ def get_gpu_target() -> tuple[int, int, int]:
"""Get the GPU target as (major, minor, stepping) tuple."""
if not USE_HW: return (0, 0, 0)
from tinygrad.device import Device
return Device["AMD"].target
return Device["AMD"].target # type: ignore[attr-defined]
def skip_unless_gfx(min_major: int, min_minor: int = 0, reason: str = ""):
"""Skip test if GPU target is below the minimum required version."""
@ -171,7 +171,7 @@ def run_program_hw(instructions: list, n_lanes: int = 1) -> WaveState:
from tinygrad.helpers import flat_mv
dev = Device["AMD"]
compiler = HIPCompiler(dev.arch)
compiler = HIPCompiler(dev.arch) # type: ignore[attr-defined]
prologue, epilogue = get_prologue_epilogue(n_lanes)
code = assemble(prologue + instructions + epilogue)
@ -218,7 +218,7 @@ amdhsa.kernels:
"""
lib = compiler.compile(asm_src)
prg = AMDProgram(dev, "test", lib)
prg = AMDProgram(dev, "test", lib) # type: ignore[arg-type]
out_gpu = dev.allocator.alloc(OUT_BYTES)
assert out_gpu.va_addr % 16 == 0, f"buffer not 16-byte aligned: 0x{out_gpu.va_addr:x}"
@ -276,6 +276,6 @@ def run_program(instructions: list, n_lanes: int = 1, ulp_tolerance: int = 0) ->
hw_st = run_program_hw(instructions, n_lanes)
diffs = compare_wave_states(emu_st, hw_st, n_lanes, ulp_tolerance=ulp_tolerance)
if diffs:
raise AssertionError(f"Emulator vs Hardware mismatch:\n" + "\n".join(diffs))
raise AssertionError("Emulator vs Hardware mismatch:\n" + "\n".join(diffs))
return hw_st
return emu_st

View file

@ -601,7 +601,6 @@ class TestDS2AddrStride64(unittest.TestCase):
self.assertEqual(st.vgpr[0][6], 0xAAAAAAAA, "new val 0")
self.assertEqual(st.vgpr[0][7], 0xBBBBBBBB, "new val 1")
def test_ds_storexchg_rtn_b64(self):
"""DS_STOREXCHG_RTN_B64: exchange 64-bit value and return old."""
instructions = [

View file

@ -373,7 +373,6 @@ class TestF64Conversions(unittest.TestCase):
def test_v_cvt_f64_f32_pi(self):
"""V_CVT_F64_F32 converts f32 pi to f64."""
import math
instructions = [
s_mov_b32(s[0], f2i(3.14159265)),
v_mov_b32_e32(v[0], s[0]),

View file

@ -725,7 +725,7 @@ class TestLaneOps(unittest.TestCase):
# v[5] should have the value only in lane 1
for lane in range(4):
if lane == 1:
self.assertEqual(st.vgpr[lane][5], 0x12345678, f"v[5] lane 1 should have 0x12345678")
self.assertEqual(st.vgpr[lane][5], 0x12345678, "v[5] lane 1 should have 0x12345678")
else:
self.assertEqual(st.vgpr[lane][5], 0, f"v[5] lane {lane} should be 0")
@ -1082,7 +1082,6 @@ class TestF64Ops(unittest.TestCase):
"""Full f64->i64 conversion sequence with negative value."""
import struct
val = f2i64(-8.0)
lit = 0xC1F00000 # high 32 bits of f64 -2^32
instructions = [
s_mov_b32(s[0], val & 0xffffffff),
s_mov_b32(s[1], (val >> 32) & 0xffffffff),
@ -1138,7 +1137,6 @@ class TestF64Ops(unittest.TestCase):
# v_fma_f64 v[7:8], v[17:18], v[7:8], v[15:16]
# We need to capture the exact input values and verify output matches hardware
# v[7:8] before = 0x3f80fdf3_d69db28f (0.008296875941334462)
v78 = 0x3f80fdf3d69db28f
# For the FMA to produce 0xbf457ef0_ab8c254d, we need v[17:18] and v[15:16]
# Let's test with known precision-sensitive values
a = 1.0000000001
@ -1395,7 +1393,7 @@ class TestWMMAMore(unittest.TestCase):
def test_v_wmma_f32_16x16x16_f16_basic(self):
"""V_WMMA_F32_16X16X16_F16 basic test - verify output is non-zero."""
instructions = []
instructions: list[Inst] = []
instructions.append(s_mov_b32(s[0], 0x3c003c00))
for i in range(16, 32):
instructions.append(v_mov_b32_e32(v[i], s[0]))
@ -1851,7 +1849,6 @@ class TestMed3(unittest.TestCase):
def test_v_med3_f32_with_nan(self):
"""V_MED3_F32: NaN handling - returns min of non-NaN values."""
import math
instructions = [
s_mov_b32(s[0], 0x7fc00000), # NaN
v_mov_b32_e32(v[0], s[0]),
@ -2490,7 +2487,6 @@ class TestDivScaleF64(unittest.TestCase):
independently. This catches the bug where the emulator was setting VCC
for all lanes to the same value.
"""
import math
# Use lane-varying input: lane 0 gets 2.0, lane 1 gets 3.0, etc.
# All normal values should result in VCC=0 for each lane
instructions = [
@ -2721,7 +2717,6 @@ class TestDivScaleFmasF64Integration(unittest.TestCase):
This is the exact bug scenario: tan([2.0, 3.0, 4.0]) was failing because
VCC from DIV_SCALE was being set incorrectly for all lanes.
"""
import math
# Set up values like tan() would: different values per lane
instructions = [
# Create per-lane values: 2.0, 3.0, 4.0, 5.0

View file

@ -418,7 +418,7 @@ class TestWMMAF16(unittest.TestCase):
def test_v_wmma_f16_16x16x16_f16_all_ones(self):
"""V_WMMA_F16_16X16X16_F16 with all ones produces 16.0 in f16."""
instructions = []
instructions: list[Inst] = []
instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0
# Initialize A matrix in v[16:23] (8 regs)
for i in range(16, 24):
@ -442,7 +442,7 @@ class TestWMMAF16(unittest.TestCase):
def test_v_wmma_f16_16x16x16_f16_with_accumulator(self):
"""V_WMMA_F16_16X16X16_F16 with non-zero accumulator."""
instructions = []
instructions: list[Inst] = []
instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0
instructions.append(s_mov_b32(s[1], 0x4500)) # f16 5.0 in lo bits only
# Initialize A matrix in v[16:23] (8 regs)
@ -471,7 +471,7 @@ class TestWMMAF16(unittest.TestCase):
Regression test: WMMA was using static register indices instead of dynamic.
This test uses v[64:71] for A, v[80:87] for B, v[96:103] for C/D.
"""
instructions = []
instructions: list[Inst] = []
instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0
# Initialize A matrix in v[64:71] (8 regs)
for i in range(64, 72):
@ -502,7 +502,7 @@ class TestWMMA(unittest.TestCase):
def test_v_wmma_f32_16x16x16_f16_all_ones(self):
"""V_WMMA_F32_16X16X16_F16 with all ones produces 16.0."""
instructions = []
instructions: list[Inst] = []
instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0
for i in range(16, 32):
instructions.append(v_mov_b32_e32(v[i], s[0]))
@ -518,7 +518,7 @@ class TestWMMA(unittest.TestCase):
def test_v_wmma_f32_16x16x16_f16_with_accumulator(self):
"""V_WMMA_F32_16X16X16_F16 with non-zero accumulator."""
instructions = []
instructions: list[Inst] = []
instructions.append(s_mov_b32(s[0], 0x3c003c00))
instructions.append(s_mov_b32(s[1], f2i(5.0)))
for i in range(16, 32):
@ -540,7 +540,7 @@ class TestWMMA(unittest.TestCase):
causing incorrect results when registers weren't at the default positions.
This test uses v[64:71] for A, v[80:87] for B, v[96:103] for C/D.
"""
instructions = []
instructions: list[Inst] = []
instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0
# Initialize A matrix in v[64:71]
for i in range(64, 72):
@ -569,7 +569,7 @@ class TestWMMABF16(unittest.TestCase):
def test_v_wmma_f32_16x16x16_bf16_all_ones(self):
"""V_WMMA_F32_16X16X16_BF16 with all ones produces 16.0."""
instructions = []
instructions: list[Inst] = []
# BF16 1.0 = 0x3f80, packed = 0x3f803f80
instructions.append(s_mov_b32(s[0], 0x3f803f80))
for i in range(16, 32):
@ -586,7 +586,7 @@ class TestWMMABF16(unittest.TestCase):
def test_v_wmma_f32_16x16x16_bf16_with_accumulator(self):
"""V_WMMA_F32_16X16X16_BF16 with non-zero accumulator."""
instructions = []
instructions: list[Inst] = []
# BF16 1.0 = 0x3f80, packed = 0x3f803f80
instructions.append(s_mov_b32(s[0], 0x3f803f80))
instructions.append(s_mov_b32(s[1], f2i(5.0)))

View file

@ -7,8 +7,7 @@ VOPD executes two operations simultaneously. Key behavior:
- Op Y can use ops 0-18 (includes ADD_NC_U32, LSHLREV, AND)
"""
import unittest
from extra.assembly.amd.test.hw.helpers import run_program, run_program_emu, run_program_hw, compare_wave_states, \
v, s, v_mov_b32_e32, s_mov_b32
from extra.assembly.amd.test.hw.helpers import run_program, v, v_mov_b32_e32
from extra.assembly.amd.autogen.rdna3.ins import VOPD, VOPD_LIT, VOPDOp
class TestVOPDBasic(unittest.TestCase):

View file

@ -81,7 +81,9 @@ class RustEmulator:
return snap.to_snapshot()
def free(self):
if self.ctx: self.lib.wave_free(self.ctx); self.ctx = None
if self.ctx:
self.lib.wave_free(self.ctx)
self.ctx = None
class PythonEmulator:
def __init__(self):
@ -114,8 +116,8 @@ class PythonEmulator:
if pc == 0xFFFFFFFFFFFFFFFF or pc not in self.program: return -1
name, fxn, globals_list, _runner = self.program[pc]
if fxn is None: return 1 # unsupported instruction
buf_addrs = {0: self.state.sgpr_buf._buf.va_addr, 1: self.state.vgpr_buf._buf.va_addr,
2: self.vmem_buf._buf.va_addr, 3: self.lds_buf._buf.va_addr}
buf_addrs = {0: self.state.sgpr_buf._buf.va_addr, 1: self.state.vgpr_buf._buf.va_addr, # type: ignore[union-attr]
2: self.vmem_buf._buf.va_addr, 3: self.lds_buf._buf.va_addr} # type: ignore[union-attr]
# Direct ctypes call - bypasses HCQ overhead
fxn(*[ctypes.c_uint64(buf_addrs[g]) for g in globals_list], ctypes.c_int32(0))
return -1 if self.state.pc == 0xFFFFFFFFFFFFFFFF else 0
@ -178,6 +180,7 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t
rust_before = rust.get_snapshot()
python_before = python.get_snapshot()
assert python.program is not None
inst_info = python.program.get(python.lib_addr + python_before.pc * 4) # Convert word offset to actual address
inst_hex_name = inst_info[0] if inst_info else f"unknown at PC={python_before.pc}"
# Decode the instruction to get mnemonic for sync_after checks
@ -188,7 +191,7 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t
inst_bytes = bytes.fromhex(inst_bytes_hex) if inst_bytes_hex else b''
decoded = decode_inst(inst_bytes) if inst_bytes else None
inst_mnemonic = repr(decoded).split('(')[0] if decoded else ""
except:
except Exception:
inst_mnemonic = ""
# For generic instructions, use function name for sync_after check
if not inst_mnemonic: inst_mnemonic = inst_hex_name
@ -220,16 +223,18 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t
python_diffs = pb.diff(next_pb, n_lanes, "->")
if rust_diffs: trace_lines.append(f" rust: {', '.join(rust_diffs[:5])}")
if python_diffs: trace_lines.append(f" python: {', '.join(python_diffs[:5])}")
elif rust_diffs: trace_lines.append(f" python: (no changes)")
elif rust_diffs: trace_lines.append(" python: (no changes)")
else:
# Last traced instruction - compare with current state
rust_diffs = rb.diff(rust_before, n_lanes, "->")
python_diffs = pb.diff(python_before, n_lanes, "->")
if rust_diffs: trace_lines.append(f" rust: {', '.join(rust_diffs[:5])}")
if python_diffs: trace_lines.append(f" python: {', '.join(python_diffs[:5])}")
elif rust_diffs: trace_lines.append(f" python: (no changes)")
elif rust_diffs: trace_lines.append(" python: (no changes)")
trace_str = "\n".join(trace_lines)
return False, f"K{kernel_idx} WG({gidx},{gidy},{gidz}) Step {step} before inst '{inst_str}': states differ (rust vs python):\n " + "\n ".join(diffs[:10]) + f"\n Recent instructions:\n{trace_str}", total_steps
msg = f"K{kernel_idx} WG({gidx},{gidy},{gidz}) Step {step} before inst '{inst_str}': states differ (rust vs python):\n "
msg += "\n ".join(diffs[:10]) + f"\n Recent instructions:\n{trace_str}"
return False, msg, total_steps
rust_result = rust.step()
python_result = python.step()
@ -239,7 +244,9 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t
if rust_result == 1 and python_result == 0:
raise unittest.SkipTest(f"Rust emulator doesn't support instruction: {inst_str}")
trace_str = "\n".join(f" step {s}: PC={pc:3d} {d}" for s, pc, d, _, _ in trace)
return False, f"K{kernel_idx} WG({gidx},{gidy},{gidz}) Step {step}: different return codes: rust={rust_result}, python={python_result}, inst={inst_str}\n Recent instructions:\n{trace_str}", total_steps
msg = (f"K{kernel_idx} WG({gidx},{gidy},{gidz}) Step {step}: different return codes: "
f"rust={rust_result}, python={python_result}, inst={inst_str}\n Recent instructions:\n{trace_str}")
return False, msg, total_steps
# Sync Python state to Rust after instructions with known Rust emulator differences
if sync_after:
@ -429,7 +436,8 @@ class TestTinygradKernels(unittest.TestCase):
def test_cast(self): self._test_kernel(lambda T: T.empty(32).half().float() + T.empty(32).int().float())
# Pooling - regression for VCC wave32 mode
def test_pool2d(self): self._test_kernel(lambda T: T.empty(1, 1, 8, 8).avg_pool2d(kernel_size=(4,4)) + T.empty(1, 1, 8, 8).max_pool2d(kernel_size=(4,4)))
def test_pool2d(self):
self._test_kernel(lambda T: T.empty(1, 1, 8, 8).avg_pool2d(kernel_size=(4,4)) + T.empty(1, 1, 8, 8).max_pool2d(kernel_size=(4,4)))
# Convolution
def test_conv2d(self): self._test_kernel(lambda T: T.empty(1, 2, 8, 8).conv2d(T.empty(2, 2, 3, 3)), max_steps=50000)

View file

@ -58,7 +58,7 @@ def custom_add_var(A:UOp, B:UOp, arch:str) -> UOp:
class TestCustomKernel(unittest.TestCase):
def test_simple(self):
a = Tensor.full((16, 16), 1.).contiguous().realize()
a = Tensor.custom_kernel(a, fxn=functools.partial(custom_add_one, arch=Device[Device.DEFAULT].renderer.arch))[0]
a = Tensor.custom_kernel(a, fxn=functools.partial(custom_add_one, arch=Device[Device.DEFAULT].renderer.arch))[0] # type: ignore[attr-defined]
ei = a.schedule()[-1].lower()
self.assertEqual(ei.prg.estimates.ops, a.numel())
self.assertEqual(ei.prg.estimates.mem, a.nbytes()*2)
@ -68,7 +68,7 @@ class TestCustomKernel(unittest.TestCase):
def test_variable(self):
b = Tensor.full((16, 16), 1, dtype=dtypes.uint32).contiguous().realize()
a = Tensor.zeros_like(b).contiguous().realize()
a = Tensor.custom_kernel(a, b, fxn=functools.partial(custom_add_var, arch=Device[Device.DEFAULT].renderer.arch))[0]
a = Tensor.custom_kernel(a, b, fxn=functools.partial(custom_add_var, arch=Device[Device.DEFAULT].renderer.arch))[0] # type: ignore[attr-defined]
ei = a.schedule()[-1].lower()
for i in range(4):
ei.run({"var":i})

View file

@ -7,11 +7,11 @@ from tinygrad.uop.ops import UOp, Ops
from extra.assembly.amd.emu import parse_pcode
from extra.assembly.amd.pcode import parse_expr
from extra.assembly.amd.autogen.rdna3.str_pcode import PCODE
from extra.assembly.amd.autogen.rdna3.enum import VOP1Op, VOP2Op, VOP3Op, SOP1Op, SOP2Op, DSOp
from extra.assembly.amd.autogen.rdna3.enum import VOP1Op, VOP2Op, SOP2Op, DSOp
def _srcs():
"""Create minimal source variables for pcode parsing."""
u32 = lambda v=0: UOp.const(dtypes.uint32, v)
def u32(v=0): return UOp.const(dtypes.uint32, v)
return {'S0': u32(), 'S1': u32(), 'S2': u32(), 'SCC': u32(), 'VCC': UOp.const(dtypes.uint64, 0), 'laneId': u32()}
class TestBasicParsing(unittest.TestCase):
@ -90,16 +90,16 @@ class TestParseExpr(unittest.TestCase):
def test_variable_lookup(self):
"""Test variable lookup in parse_expr."""
vars = {'x': UOp.const(dtypes.uint32, 42)}
result = parse_expr('x', vars)
vrs = {'x': UOp.const(dtypes.uint32, 42)}
result = parse_expr('x', vrs)
self.assertEqual(result.arg, 42)
def test_binary_ops(self):
"""Test parsing binary operations."""
vars = {'a': UOp.const(dtypes.uint32, 10), 'b': UOp.const(dtypes.uint32, 5)}
vrs = {'a': UOp.const(dtypes.uint32, 10), 'b': UOp.const(dtypes.uint32, 5)}
# Addition
result = parse_expr('a + b', vars)
result = parse_expr('a + b', vrs)
self.assertEqual(result.op, Ops.ADD)
# Subtraction with constant folding
@ -109,8 +109,8 @@ class TestParseExpr(unittest.TestCase):
def test_ternary(self):
"""Test parsing ternary expressions."""
vars = {'cond': UOp.const(dtypes.bool, True), 'a': UOp.const(dtypes.uint32, 1), 'b': UOp.const(dtypes.uint32, 0)}
result = parse_expr('cond ? a : b', vars)
vrs = {'cond': UOp.const(dtypes.bool, True), 'a': UOp.const(dtypes.uint32, 1), 'b': UOp.const(dtypes.uint32, 0)}
result = parse_expr('cond ? a : b', vrs)
self.assertEqual(result.op, Ops.WHERE)
class TestForLoopParsing(unittest.TestCase):
@ -120,13 +120,14 @@ class TestForLoopParsing(unittest.TestCase):
"""Verify CLZ pcode is available."""
pcode = PCODE.get(VOP1Op.V_CLZ_I32_U32_E32)
self.assertIsNotNone(pcode)
assert pcode is not None
self.assertIn('for', pcode.lower())
def test_clz_parsing(self):
"""Test CLZ pcode parsing produces correct structure."""
pcode = PCODE[VOP1Op.V_CLZ_I32_U32_E32]
S0 = UOp.const(dtypes.uint32, 0xFFFFFFFF) # All ones - CLZ should be 0
vars, assigns = parse_pcode(pcode, {'S0': S0})
_vrs, assigns = parse_pcode(pcode, {'S0': S0})
self.assertEqual(len(assigns), 1)
dest, val = assigns[0]
@ -138,7 +139,7 @@ class TestForLoopParsing(unittest.TestCase):
"""Test CLZ with input 0 - should return -1."""
pcode = PCODE[VOP1Op.V_CLZ_I32_U32_E32]
S0 = UOp.const(dtypes.uint32, 0)
vars, assigns = parse_pcode(pcode, {'S0': S0})
_vrs, assigns = parse_pcode(pcode, {'S0': S0})
# Check that the innermost value (default) is -1 (may be wrapped in CAST)
val = assigns[0][1]
@ -157,7 +158,7 @@ class TestForLoopParsing(unittest.TestCase):
self.skipTest("V_CTZ_I32_B32_E32 pcode not available")
S0 = UOp.const(dtypes.uint32, 1) # LSB set - CTZ should be 0
vars, assigns = parse_pcode(pcode, {'S0': S0})
_vrs, assigns = parse_pcode(pcode, {'S0': S0})
self.assertEqual(len(assigns), 1)
class TestDSPcodePatterns(unittest.TestCase):
@ -167,6 +168,7 @@ class TestDSPcodePatterns(unittest.TestCase):
"""Test DS_LOAD_B32 pcode is parseable."""
pcode = PCODE.get(DSOp.DS_LOAD_B32)
self.assertIsNotNone(pcode)
assert pcode is not None
self.assertIn('RETURN_DATA', pcode)
self.assertIn('MEM[', pcode)
@ -174,6 +176,7 @@ class TestDSPcodePatterns(unittest.TestCase):
"""Test DS_STORE_B32 pcode is parseable."""
pcode = PCODE.get(DSOp.DS_STORE_B32)
self.assertIsNotNone(pcode)
assert pcode is not None
self.assertIn('MEM[', pcode)
self.assertIn('DATA', pcode)
@ -182,9 +185,9 @@ class TestDSPcodePatterns(unittest.TestCase):
# Create a mock LDS buffer
lds = UOp(Ops.PARAM, dtypes.uint32.ptr(16384), arg=3)
addr = UOp.const(dtypes.uint32, 0)
vars = {'_lds': lds, 'ADDR': addr, 'OFFSET': UOp.const(dtypes.uint32, 0)}
vrs = {'_lds': lds, 'ADDR': addr, 'OFFSET': UOp.const(dtypes.uint32, 0)}
result = parse_expr('MEM[ADDR + OFFSET].b32', vars)
result = parse_expr('MEM[ADDR + OFFSET].b32', vrs)
# Should be an INDEX operation into LDS
self.assertIsNotNone(result)
@ -192,6 +195,7 @@ class TestDSPcodePatterns(unittest.TestCase):
"""Test DS_STORE_2ADDR_B32 pcode parsing produces MEM writes."""
pcode = PCODE.get(DSOp.DS_STORE_2ADDR_B32)
self.assertIsNotNone(pcode)
assert pcode is not None
srcs = {
'ADDR': UOp.const(dtypes.uint32, 0),
'OFFSET0': UOp.const(dtypes.uint32, 0),
@ -207,12 +211,13 @@ class TestDSPcodePatterns(unittest.TestCase):
self.assertTrue(dest.startswith('MEM['))
# val should be (addr, write_val) tuple
self.assertIsInstance(val, tuple)
self.assertEqual(len(val), 2)
self.assertEqual(len(val), 2) # type: ignore[arg-type]
def test_ds_load_2addr_b32_parsing(self):
"""Test DS_LOAD_2ADDR_B32 pcode parsing produces RETURN_DATA assignments."""
pcode = PCODE.get(DSOp.DS_LOAD_2ADDR_B32)
self.assertIsNotNone(pcode)
assert pcode is not None
lds = UOp(Ops.PARAM, dtypes.uint32.ptr(16384), arg=3)
srcs = {
'ADDR': UOp.const(dtypes.uint32, 0),
@ -230,6 +235,7 @@ class TestDSPcodePatterns(unittest.TestCase):
def test_ds_store_address_calculation(self):
"""Test DS_STORE_2ADDR_B32 calculates correct addresses (offset * 4)."""
pcode = PCODE.get(DSOp.DS_STORE_2ADDR_B32)
assert pcode is not None
srcs = {
'ADDR': UOp.const(dtypes.uint32, 100),
'OFFSET0': UOp.const(dtypes.uint32, 2),
@ -240,14 +246,14 @@ class TestDSPcodePatterns(unittest.TestCase):
srcs['laneId'] = UOp.const(dtypes.uint32, 0)
_, assigns = parse_pcode(pcode, srcs)
# Check addresses: 100 + 2*4 = 108, 100 + 5*4 = 120
addr0, _ = assigns[0][1]
addr1, _ = assigns[1][1]
self.assertEqual(addr0.simplify().arg, 108)
self.assertEqual(addr1.simplify().arg, 120)
# assigns[i][1] is (addr, val) tuple for MEM writes; mypy sees UOp
self.assertEqual(assigns[0][1][0].simplify().arg, 108) # type: ignore[index]
self.assertEqual(assigns[1][1][0].simplify().arg, 120) # type: ignore[index]
def test_ds_store_data_values(self):
"""Test DS_STORE_2ADDR_B32 uses correct data values."""
pcode = PCODE.get(DSOp.DS_STORE_2ADDR_B32)
assert pcode is not None
srcs = {
'ADDR': UOp.const(dtypes.uint32, 0),
'OFFSET0': UOp.const(dtypes.uint32, 0),
@ -257,11 +263,10 @@ class TestDSPcodePatterns(unittest.TestCase):
}
srcs['laneId'] = UOp.const(dtypes.uint32, 0)
_, assigns = parse_pcode(pcode, srcs)
_, val0 = assigns[0][1]
_, val1 = assigns[1][1]
# assigns[i][1] is (addr, val) tuple for MEM writes; mypy sees UOp
# DATA[31:0] should preserve the value
self.assertEqual(val0.simplify().arg, 0xAAAAAAAA)
self.assertEqual(val1.simplify().arg, 0xBBBBBBBB)
self.assertEqual(assigns[0][1][1].simplify().arg, 0xAAAAAAAA) # type: ignore[index]
self.assertEqual(assigns[1][1][1].simplify().arg, 0xBBBBBBBB) # type: ignore[index]
class TestConditionalParsing(unittest.TestCase):
"""Test conditional (if/elsif/else) pcode parsing."""
@ -273,7 +278,7 @@ class TestConditionalParsing(unittest.TestCase):
s0 = UOp.const(dtypes.uint32, 10)
s1 = UOp.const(dtypes.uint32, 20)
scc = UOp.const(dtypes.uint32, 1)
vars, assigns = parse_pcode(pcode, {'S0': s0, 'S1': s1, 'SCC': scc})
_vrs, assigns = parse_pcode(pcode, {'S0': s0, 'S1': s1, 'SCC': scc})
self.assertEqual(len(assigns), 1)
dest, val = assigns[0]
self.assertTrue(dest.startswith('D0'))
@ -294,7 +299,8 @@ class TestAllPcode(unittest.TestCase):
'ADDR': u32(), 'ADDR_BASE': u32(), 'TADDR': u32(), 'DATA': u32(), 'DATA0': u32(), 'DATA1': u32(), 'DATA2': u32(),
'VDATA': u32(), 'VDATA0': u32(), 'VDATA1': u32(), 'VDATA2': u32(), 'VDATA3': u32(),
'OPSEL': u32(), 'OPSEL_HI': u32(), 'NEG': u32(), 'NEG_HI': u32(), 'CLAMP': u32(),
'M0': u32(), 'PC': u64(), 'DENORM': u32(1), 'ROUND_MODE': u32(), 'ROUND_TOWARD_ZERO': u32(), 'ROUND_NEAREST_EVEN': u32(), 'WAVE_STATUS': u32(),
'M0': u32(), 'PC': u64(), 'DENORM': u32(1), 'ROUND_MODE': u32(), 'ROUND_TOWARD_ZERO': u32(),
'ROUND_NEAREST_EVEN': u32(), 'WAVE_STATUS': u32(),
'MAX_FLOAT_F32': u32(0x7f7fffff), 'Unsigned': u32(1), 'clampedLOD': u32(),
'_lds': lds, '_vmem': lds, '_active': UOp.const(dtypes.bool, True)}
@ -306,7 +312,9 @@ class TestAllPcode(unittest.TestCase):
try:
parse_pcode(pcode, srcs)
passed += 1
except RuntimeError as e: skipped += 1; errors[str(e)].append(op.name)
except RuntimeError as e:
skipped += 1
errors[str(e)].append(op.name)
except Exception as e: self.fail(f"[{arch}] {op.name}: {e}\nPcode: {pcode[:200]}")
total = len(pcode_dict)
pct = 100 * passed / total

View file

@ -127,15 +127,17 @@ def _make_test(f: str, arch: str, test_type: str):
self.assertEqual(skipped, 0, f"{name}: {skipped} tests skipped, expected 0")
elif test_type == "repr":
# Test that eval(repr(inst)) reproduces the instruction
if arch == "rdna3": import extra.assembly.amd.autogen.rdna3.ins as ins
elif arch == "rdna4": import extra.assembly.amd.autogen.rdna4.ins as ins
elif arch == "cdna": import extra.assembly.amd.autogen.cdna.ins as ins
if arch == "rdna3": import extra.assembly.amd.autogen.rdna3.ins as ins # type: ignore[no-redef]
elif arch == "rdna4": import extra.assembly.amd.autogen.rdna4.ins as ins # type: ignore[no-redef]
elif arch == "cdna": import extra.assembly.amd.autogen.cdna.ins as ins # type: ignore[no-redef]
ns = {k: getattr(ins, k) for k in dir(ins) if not k.startswith('_')}
passed, skipped = 0, 0
for _, data in tests:
try:
decoded = detect_format(data, arch).from_bytes(data)
if decoded.to_bytes()[:len(data)] != data: skipped += 1; continue # skip if binary roundtrip fails
if decoded.to_bytes()[:len(data)] != data:
skipped += 1
continue # skip if binary roundtrip fails
r = repr(decoded)
try:
decoded2 = eval(r, ns) # noqa: S307
@ -153,7 +155,7 @@ def _make_test(f: str, arch: str, test_type: str):
enc = decoded.to_bytes()[:len(data)]
# Skip if roundtrip fails, disasm fails, or op_name is missing (disasm starts with space)
if enc == data and (d := disasm(decoded)) and not d.startswith(' '): to_test.append((enc, d))
except: pass
except Exception: pass
skipped = len(tests) - len(to_test)
print(f"{name}: {len(to_test)} passed, {skipped} skipped")
self.assertEqual(skipped, 0, f"{name}: {skipped} tests skipped, expected 0")

View file

@ -6,6 +6,10 @@ from extra.assembly.amd.generate import extract_pdf_text, extract_pcode, parse_x
EXPECTED_PAGES = {"rdna3": 655, "rdna4": 711, "cdna": 610}
class TestPcodePDF(unittest.TestCase):
pages: dict
enums: dict
pcode: dict
@classmethod
def setUpClass(cls):
cls.pages = {arch: extract_pdf_text(cfg["pdf"]) for arch, cfg in ARCHS.items()}
@ -33,7 +37,8 @@ class TestPcodePDF(unittest.TestCase):
'tmp = MEM[ADDR].u64;\nsrc = DATA.u64;\nMEM[ADDR].u64 = src >= tmp ? src : tmp;\nRETURN_DATA.u64 = tmp')
# GLOBAL_STORE_B128: should have 4 MEM stores (not truncated)
self.assertEqual(pcode[('GLOBAL_STORE_B128', 29)],
'MEM[ADDR].b32 = VDATA[31 : 0];\nMEM[ADDR + 4U].b32 = VDATA[63 : 32];\nMEM[ADDR + 8U].b32 = VDATA[95 : 64];\nMEM[ADDR + 12U].b32 = VDATA[127 : 96]')
'MEM[ADDR].b32 = VDATA[31 : 0];\nMEM[ADDR + 4U].b32 = VDATA[63 : 32];\n'
'MEM[ADDR + 8U].b32 = VDATA[95 : 64];\nMEM[ADDR + 12U].b32 = VDATA[127 : 96]')
# S_CMOVK_I32: should have full if/endif block
self.assertEqual(pcode[('S_CMOVK_I32', 2)],
"if SCC then\nD0.i32 = 32'I(signext(SIMM16.i16))\nendif")

View file

@ -6,7 +6,7 @@ from tinygrad.device import Buffer, BufferSpec
from tinygrad.dtype import dtypes
class TestRDNA4Emu(unittest.TestCase):
def _run(self, insts: list, sgprs: dict[int, int] = None, vgprs: dict[tuple[int, int], int] = None) -> WaveState:
def _run(self, insts: list, sgprs: dict[int, int] | None = None, vgprs: dict[tuple[int, int], int] | None = None) -> WaveState:
"""Run instructions and return final WaveState."""
# Add S_ENDPGM if not present
if not any(isinstance(i, ir4.SOPP) and i.op == ir4.SOPPOp.S_ENDPGM for i in insts):
@ -22,10 +22,8 @@ class TestRDNA4Emu(unittest.TestCase):
# Setup wave state
st = WaveState(n_lanes=1)
st.pc = code_addr
if sgprs:
for idx, val in sgprs.items(): st._write_sgpr(idx, val)
if vgprs:
for (reg, lane), val in vgprs.items(): st._write_vgpr(reg, lane, val)
for idx, val in (sgprs or {}).items(): st._write_sgpr(idx, val)
for (reg, lane), val in (vgprs or {}).items(): st._write_vgpr(reg, lane, val)
# Setup vmem buffer with external_ptr=0 (maps to address 0, allows any pointer access)
vmem_buf = Buffer('CPU', 1 << 40, dtypes.uint32, options=BufferSpec(external_ptr=0)).ensure_allocated()

View file

@ -1,8 +1,7 @@
#!/usr/bin/env python3
"""Roundtrip tests: generate tinygrad kernels, decode instructions, re-encode, verify match."""
import unittest, io, sys, re, subprocess, os
from extra.assembly.amd.dsl import Inst
from extra.assembly.amd import decode_inst, detect_format
from extra.assembly.amd import detect_format
from extra.assembly.amd.test.helpers import get_llvm_mc, get_llvm_objdump, get_target, get_mattr
from extra.assembly.amd.test.disasm import disasm
@ -100,11 +99,6 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
while offset < len(code):
remaining = code[offset:]
fmt = detect_format(remaining, arch)
if fmt is None:
decoded_instrs.append((ki, offset, None, None, None, False, "no format"))
offset += 4
continue
base_size = fmt._size()
if len(remaining) < base_size:
break
@ -190,8 +184,8 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
print(f"[{arch}] decode roundtrip: {decode_passed} passed, {decode_failed} failed, {decode_skipped} skipped")
print(f"[{arch}] asm via llvm: {asm_passed} passed, {asm_failed} failed, {asm_skipped} skipped")
print(f"[{arch}] disasm vs llvm: {disasm_passed} passed, {disasm_failed} failed, {disasm_skipped} skipped")
self.assertEqual(decode_failed, 0, f"Decode failures:\n" + "\n".join(decode_failures[:20]))
self.assertEqual(asm_failed, 0, f"Asm failures:\n" + "\n".join(asm_failures[:20]))
self.assertEqual(decode_failed, 0, "Decode failures:\n" + "\n".join(decode_failures[:20]))
self.assertEqual(asm_failed, 0, "Asm failures:\n" + "\n".join(asm_failures[:20]))
# Note: disasm string comparison is informational only - formatting differences between LLVM versions are expected
# Basic unary ops

View file

@ -8,8 +8,9 @@ from tinygrad.runtime.support.elf import elf_loader
from extra.assembly.amd import decode_inst
from extra.assembly.amd.autogen.rdna3.ins import SOPP
from extra.assembly.amd.autogen.rdna3.enum import SOPPOp
from extra.assembly.amd.sqtt import (decode, LAYOUT_HEADER, WAVESTART, WAVESTART_RDNA4, WAVEEND, INST, INST_RDNA4, VALUINST, IMMEDIATE, IMMEDIATE_MASK,
ALUEXEC, VMEMEXEC, PACKET_TYPES_RDNA3, PACKET_TYPES_RDNA4, InstOp, InstOpRDNA4, print_packets)
from extra.assembly.amd.sqtt import (decode, LAYOUT_HEADER, WAVESTART, WAVESTART_RDNA4, WAVEEND, INST, INST_RDNA4, VALUINST,
IMMEDIATE, IMMEDIATE_MASK, PACKET_TYPES_RDNA3, PACKET_TYPES_RDNA4,
InstOp, InstOpRDNA4, print_packets)
from extra.assembly.amd.test.helpers import TARGET_TO_ARCH
EXAMPLES_DIR = Path(__file__).parent.parent.parent.parent / "sqtt/examples"
@ -32,18 +33,18 @@ def run_rocprof_decoder(blobs: list[bytes], lib: bytes, base: int, target: str):
assert text is not None, "no .text section found"
text_off, text_size = text.header.sh_addr, text.header.sh_size
blob_iter, current_blob = iter(blobs), [None]
blob_iter, current_blob = iter(blobs), [None] # type: ignore[var-annotated]
occupancy_records: list[tuple[int, int, int, int, bool]] = [] # (wave_id, simd, cu, time, is_start)
wave_insts: list[list[tuple[int, int]]] = [] # per-wave list of (time, stall)
@rocprof.rocprof_trace_decoder_se_data_callback_t
def copy_cb(buf, buf_size, _):
def copy_cb(buf, buf_size, _): # type: ignore[no-untyped-def]
blob = next(blob_iter, None)
if blob is None: return 0
current_blob[0] = (ctypes.c_ubyte * len(blob)).from_buffer_copy(blob)
buf[0] = ctypes.cast(current_blob[0], ctypes.POINTER(ctypes.c_ubyte))
buf_size[0] = len(current_blob[0])
return len(current_blob[0])
current_blob[0] = (ctypes.c_ubyte * len(blob)).from_buffer_copy(blob) # type: ignore[call-overload]
buf[0] = ctypes.cast(current_blob[0], ctypes.POINTER(ctypes.c_ubyte)) # type: ignore[arg-type]
buf_size[0] = len(current_blob[0]) # type: ignore[arg-type]
return len(current_blob[0]) # type: ignore[arg-type]
@rocprof.rocprof_trace_decoder_trace_callback_t
def trace_cb(record_type, events_ptr, n, _):
@ -94,6 +95,7 @@ def run_rocprof_decoder(blobs: list[bytes], lib: bytes, base: int, target: str):
class SQTTExamplesTestBase(unittest.TestCase):
target: str
examples: dict
@classmethod
def setUpClass(cls):
@ -115,7 +117,9 @@ class SQTTExamplesTestBase(unittest.TestCase):
for i, event in enumerate(events):
with self.subTest(example=name, event=i):
packets = list(decode(event.blob))
if DEBUG >= 2: print(f"\n=== {name} event {i} ==="); print_packets(packets)
if DEBUG >= 2:
print(f"\n=== {name} event {i} ===")
print_packets(packets)
self.assertGreater(len(packets), 0, f"no packets decoded from {name} event {i}")
self.assertIsInstance(packets[0], LAYOUT_HEADER, f"first packet should be LAYOUT_HEADER in {name}")

View file

@ -94,12 +94,13 @@ def extract_cdna_packet_sizes():
rw_base, rw_offset = _find_segment('rw-p')
if not (head := ctypes.c_void_p.from_address(rw_base + (0x2d4f0 - rw_offset)).value if rw_base else None): return None
pkt_sizes, node, seen = {}, head, set()
pkt_sizes: dict[int, int] = {}
node, seen = head, set()
while node and node not in seen and len(pkt_sizes) < 20:
seen.add(node)
key, val = ctypes.c_uint32.from_address(node + 8).value, ctypes.c_uint32.from_address(node + 12).value
if key < 16 and val in (0x10, 0x20, 0x30, 0x40): pkt_sizes[key] = {0x10: 2, 0x20: 4, 0x30: 6, 0x40: 8}[val]
node = ctypes.c_void_p.from_address(node).value
node = ctypes.c_void_p.from_address(node).value # type: ignore[assignment]
return pkt_sizes if len(pkt_sizes) == 16 else None
# ═══════════════════════════════════════════════════════════════════════════════
@ -127,14 +128,14 @@ class TestSQTTMatchesBinary(unittest.TestCase):
for pkt_fmt, pkt_cls in PACKET_TYPES_CDNA.items():
with self.subTest(packet=pkt_cls.__name__):
self.assertEqual(pkt_cls.encoding.default, pkt_fmt)
self.assertEqual(CDNA_PKT_SIZES[pkt_fmt] * 2, pkt_cls._size_nibbles)
self.assertEqual(CDNA_PKT_SIZES[pkt_fmt] * 2, pkt_cls._size_nibbles) # type: ignore[attr-defined]
def _test_bit_counts(self, layout: int):
if not (tables := extract_bit_tables()): self.skipTest("rocprof-trace-decoder not installed")
from extra.assembly.amd.sqtt import PACKET_TYPES_RDNA3, PACKET_TYPES_RDNA4
for type_id, pkt_cls in {3: PACKET_TYPES_RDNA3, 4: PACKET_TYPES_RDNA4}[layout].items():
with self.subTest(packet=pkt_cls.__name__):
self.assertEqual(pkt_cls._size_nibbles * 4, tables[layout - 2][type_id])
self.assertEqual(pkt_cls._size_nibbles * 4, tables[layout - 2][type_id]) # type: ignore[attr-defined]
def _test_encodings(self, layout: int):
if not (encodings := extract_packet_encodings()): self.skipTest("rocprof-trace-decoder not installed")
@ -164,14 +165,16 @@ if __name__ == "__main__":
print("L2:", tables[0], "\nL3:", tables[1], "\nL4:", tables[2])
if encodings and tables:
print(f"\n{'TypeID':>6} {'Name':>18} {'L2 enc':>12} {'L3 enc':>12} {'L4 enc':>12} {'L2':>4} {'L3':>4} {'L4':>4} {'L2 delta':>12} {'L3 delta':>12} {'L4 delta':>12}")
print(f"\n{'TypeID':>6} {'Name':>18} {'L2 enc':>12} {'L3 enc':>12} {'L4 enc':>12}"
f" {'L2':>4} {'L3':>4} {'L4':>4} {'L2 delta':>12} {'L3 delta':>12} {'L4 delta':>12}")
print("-" * 140)
for type_id in sorted(set(encodings[0]) | set(encodings[1]) | set(encodings[2])):
name = TYPE_NAMES.get(type_id, f'UNK_{type_id}')
bits = [tables[i][type_id] if type_id < len(tables[i]) else 0 for i in range(3)]
enc_strs = [f"0x{encodings[i][type_id][0]:02x}/0x{encodings[i][type_id][1]:02x}" if type_id in encodings[i] else "-" for i in range(3)]
delta_strs = [f"[{d[1]-1}:{d[0]}]" if (d := deltas[i].get(type_id, (0, 0)))[1] > d[0] else "-" for i in range(3)]
print(f"{type_id:6d} {name:>18} {enc_strs[0]:>12} {enc_strs[1]:>12} {enc_strs[2]:>12} {bits[0]:4d} {bits[1]:4d} {bits[2]:4d} {delta_strs[0]:>12} {delta_strs[1]:>12} {delta_strs[2]:>12}")
print(f"{type_id:6d} {name:>18} {enc_strs[0]:>12} {enc_strs[1]:>12} {enc_strs[2]:>12}"
f" {bits[0]:4d} {bits[1]:4d} {bits[2]:4d} {delta_strs[0]:>12} {delta_strs[1]:>12} {delta_strs[2]:>12}")
cdna = extract_cdna_packet_sizes()
if cdna: print(f"\nCDNA packet sizes: {cdna}")

View file

@ -46,6 +46,7 @@ def rocprof_inst_traces_match(sqtt, prg, target):
class TestSQTTMapBase(unittest.TestCase):
target: str
examples: dict
@classmethod
def setUpClass(cls):

View file

@ -128,7 +128,7 @@ debug = true
[tool.mypy]
warn_unused_configs = true
files = ["tinygrad"]
files = ["tinygrad", "extra/assembly/amd"]
ignore_missing_imports = true
check_untyped_defs = true
explicit_package_bases = true
@ -142,6 +142,10 @@ strict_equality = true
module = "extra.*"
follow_imports = "skip"
[[tool.mypy.overrides]]
module = "extra.assembly.amd.*"
follow_imports = "normal"
[tool.pytest.ini_options]
norecursedirs = [
"extra",
@ -180,6 +184,7 @@ exclude = [
".git/",
"docs/",
"extra/",
"!extra/assembly/amd/",
"test/external/mlperf_resnet",
"test/external/mlperf_unet3d",
]
@ -245,6 +250,8 @@ select = [
"F841",
]
"tinygrad/runtime/autogen/**/*.py" = ["E501", "F401", "E722", "E731", "F821", "A006", "A002", "F811"]
"extra/assembly/amd/autogen/**/*.py" = ["E501"]
"extra/assembly/amd/test/**/*.py" = ["F403", "F405"]
[tool.ruff.format]
exclude = ["*"]

View file

@ -325,7 +325,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:int) -> list[ProfileEvent]:
name, width = (op_name, 10 if "BARRIER" in op_name else 1)
add(name, p, width=width, idx=int("OTHER" in name), info=info)
if isinstance(p, (VALUINST, IMMEDIATE)): add(p.__class__.__name__, p, info=info)
if isinstance(p, IMMEDIATE_MASK): add("IMMEDIATE", p, wave=unwrap(info.wave), info=info)
if isinstance(p, IMMEDIATE_MASK): add("IMMEDIATE", p, wave=unwrap(info.wave), info=info) # type: ignore[union-attr]
if isinstance(p, (VMEMEXEC, ALUEXEC)):
name = str(p.src).split('.')[1]
if name == "VALU_SALU":