more handwritten

This commit is contained in:
George Hotz 2025-12-25 16:38:12 -05:00
commit d41bb12a13
4 changed files with 131 additions and 17 deletions

View file

@ -22,8 +22,11 @@ def _vreg(base: int, cnt: int = 1) -> str: return f"v{base}" if cnt == 1 else f"
def _fmt_sdst(v: int, cnt: int = 1) -> str:
  """Format SGPR destination with special register names.

  Args:
    v: encoded destination register number.
    cnt: number of consecutive registers the operand spans (1 = single register).

  Returns:
    The assembly spelling of the operand: ``null`` for 124, ``ttmpN`` / ``ttmp[a:b]`` for
    108..123, ``exec`` / ``vcc`` for the two-register spans starting at 126 / 106, the
    named halves (``exec_lo``, ``vcc_hi``, ``m0``, ...) for single registers, and the
    generic ``_sreg`` spelling otherwise.
  """
  if v == 124: return "null"
  # trap temporaries: every multi-register span is printed as a range, not only pairs
  if 108 <= v <= 123:
    first = v - 108
    return f"ttmp[{first}:{first+cnt-1}]" if cnt > 1 else f"ttmp{first}"
  if cnt > 1:
    # only the exact two-register spans have a short alias
    if cnt == 2 and v == 126: return "exec"
    if cnt == 2 and v == 106: return "vcc"
    return _sreg(v, cnt)
  return {126: "exec_lo", 127: "exec_hi", 106: "vcc_lo", 107: "vcc_hi", 125: "m0"}.get(v, f"s{v}")
def _fmt_ssrc(v: int, cnt: int = 1) -> str:
@ -152,10 +155,47 @@ def disasm(inst: Inst) -> str:
# SMEM
if cls_name == 'SMEM':
# No-operand instructions
if op_name in ('s_gl1_inv', 's_dcache_inv'): return op_name
sdata, sbase, soffset, offset = unwrap(inst._values['sdata']), unwrap(inst._values['sbase']), unwrap(inst._values['soffset']), unwrap(inst._values['offset'])
glc, dlc = unwrap(inst._values.get('glc', 0)), unwrap(inst._values.get('dlc', 0))
# s_atc_probe/s_atc_probe_buffer: sdata is the probe mode (0-7), not a register
if op_name in ('s_atc_probe', 's_atc_probe_buffer'):
sbase_idx = sbase * 2
sbase_cnt = 4 if op_name == 's_atc_probe_buffer' else 2
sbase_str = _sreg(sbase_idx, sbase_cnt)
if offset and soffset != 124:
off_str = f"{decode_src(soffset)} offset:0x{offset:x}"
elif offset:
off_str = f"0x{offset:x}"
else:
off_str = decode_src(soffset)
return f"{op_name} {sdata}, {sbase_str}, {off_str}"
width = {0:1, 1:2, 2:4, 3:8, 4:16, 8:1, 9:2, 10:4, 11:8, 12:16}.get(op_val, 1)
off_str = f"0x{offset:x}" if offset else "null" if soffset == 124 else decode_src(soffset)
return f"{op_name} {_sreg(sdata, width)}, {_sreg(sbase, 2)}, {off_str}"
# Offset handling: if offset is set, we need "soffset offset:X" format, otherwise just soffset or imm
if offset and soffset != 124: # both soffset register and offset immediate
off_str = f"{decode_src(soffset)} offset:0x{offset:x}"
elif offset: # only offset immediate (soffset=null)
off_str = f"0x{offset:x}"
elif soffset == 124: # null
off_str = "null"
else: # only soffset register
off_str = decode_src(soffset)
# sbase is stored as register pair index, multiply by 2 for actual register number
# s_buffer_load_* (op 8-12) use 4-reg sbase (buffer descriptor), s_load_* (op 0-4) use 2-reg sbase
sbase_idx = sbase * 2
sbase_cnt = 4 if 8 <= op_val <= 12 else 2
# Format sbase with special register names
if sbase_idx == 106 and sbase_cnt == 2: sbase_str = "vcc"
elif sbase_idx == 126 and sbase_cnt == 2: sbase_str = "exec"
elif 108 <= sbase_idx <= 123: sbase_str = f"ttmp[{sbase_idx-108}:{sbase_idx-108+sbase_cnt-1}]"
else: sbase_str = _sreg(sbase_idx, sbase_cnt)
# Build modifiers
mods = []
if glc: mods.append("glc")
if dlc: mods.append("dlc")
mod_str = " " + " ".join(mods) if mods else ""
return f"{op_name} {_fmt_sdst(sdata, width)}, {sbase_str}, {off_str}{mod_str}"
# FLAT
if cls_name == 'FLAT':
@ -312,19 +352,37 @@ def disasm(inst: Inst) -> str:
return f"{op_name} {dst_str}, {src0_str}, {src1_str}, {src2_str}" + fmt_opsel(3) + clamp_str + omod_str
return f"{op_name} {dst_str}, {src0_str}, {src1_str}" + fmt_opsel(2) + clamp_str + omod_str
# VOP3SD: 3-source with scalar destination (v_div_scale_*)
# VOP3SD: 3-source with scalar destination (v_div_scale_*, v_add_co_u32, v_mad_*64_*32, etc.)
if cls_name == 'VOP3SD':
vdst, sdst = unwrap(inst._values.get('vdst', 0)), unwrap(inst._values.get('sdst', 0))
src0, src1, src2 = [unwrap(inst._values.get(f, 0)) for f in ('src0', 'src1', 'src2')]
neg = unwrap(inst._values.get('neg', 0))
def fmt_vop3_src(v, neg_bit):
omod = unwrap(inst._values.get('omod', 0))
clmp = unwrap(inst._values.get('clmp', 0))
is_f64 = 'f64' in op_name
is_mad64 = 'mad_i64_i32' in op_name or 'mad_u64_u32' in op_name
def fmt_sd_src(v, neg_bit, is_64bit=False):
s = fmt_src(v)
if is_64bit or is_f64:
if v >= 256: s = _vreg(v - 256, 2)
elif v <= 105: s = _sreg(v, 2)
elif v == 106: s = "vcc"
elif v == 126: s = "exec"
elif 108 <= v <= 123: s = f"ttmp[{v-108}:{v-108+1}]"
if neg_bit: s = f"-{s}"
return s
src0_str = fmt_vop3_src(src0, neg & 1)
src1_str = fmt_vop3_src(src1, neg & 2)
src2_str = fmt_vop3_src(src2, neg & 4)
return f"{op_name} v{vdst}, vcc_lo, {src0_str}, {src1_str}, {src2_str}"
src0_str = fmt_sd_src(src0, neg & 1, False)
src1_str = fmt_sd_src(src1, neg & 2, False)
src2_str = fmt_sd_src(src2, neg & 4, is_mad64)
dst_str = _vreg(vdst, 2) if (is_f64 or is_mad64) else f"v{vdst}"
sdst_str = _fmt_sdst(sdst, 1)
clamp_str = " clamp" if clmp else ""
omod_str = {1: " mul:2", 2: " mul:4", 3: " div:2"}.get(omod, "")
# v_add_co_u32, v_sub_co_u32, v_subrev_co_u32, v_add_co_ci_u32, etc. only use 2 sources
if op_name in ('v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32', 'v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32'):
return f"{op_name} {dst_str}, {sdst_str}, {src0_str}, {src1_str}" + clamp_str
# v_div_scale, v_mad_*64_*32 use 3 sources
return f"{op_name} {dst_str}, {sdst_str}, {src0_str}, {src1_str}, {src2_str}" + clamp_str + omod_str
# VOPD: dual-issue instructions
if cls_name == 'VOPD':

View file

@ -39,6 +39,8 @@ class Imm: pass
# Empty marker class: it defines no attributes or methods of its own.
# NOTE(review): presumably used as a type tag for signed-immediate operands — confirm against its users.
class SImm: pass
class RawImm:
  """Wrapper around a raw integer value, stored in ``val``.

  Two instances compare equal when they wrap the same value; a ``RawImm`` never compares
  equal to a bare integer. Instances are hashable consistently with that equality.
  """
  def __init__(self, val: int): self.val = val
  def __repr__(self): return f"RawImm({self.val})"
  def __eq__(self, other): return isinstance(other, RawImm) and self.val == other.val
  # defining __eq__ sets the implicit __hash__ to None; restore a hash that agrees with __eq__
  def __hash__(self): return hash(self.val)
def unwrap(val) -> int:
  """Return the plain value carried by ``val``.

  A ``RawImm`` yields its ``val``; otherwise the first of the attributes ``value`` and
  ``idx`` that exists on the object is returned; anything else is passed through unchanged.
  """
  if isinstance(val, RawImm): return val.val
  # attribute order matters: 'value' takes precedence over 'idx'
  for attr in ('value', 'idx'):
    if hasattr(val, attr): return getattr(val, attr)
  return val
@ -53,7 +55,9 @@ def encode_src(val) -> int:
if isinstance(val, VGPR): return 256 + val.idx + (0x80 if val.hi else 0)
if isinstance(val, TTMP): return 108 + val.idx
if hasattr(val, 'value'): return val.value
if isinstance(val, float): return FLOAT_ENC.get(val, 255)
if isinstance(val, float):
if val == 0.0: return 128 # 0.0 encodes as integer constant 0
return FLOAT_ENC.get(val, 255)
return 128 + val if isinstance(val, int) and 0 <= val <= 64 else 192 + (-val) if isinstance(val, int) and -16 <= val <= -1 else 255
# Instruction base class
@ -71,6 +75,37 @@ class Inst:
self._values, self._literal = dict(self._defaults), literal
self._values.update(zip([n for n in self._fields if n != 'encoding'], args))
self._values.update(kwargs)
# Get annotations from class hierarchy
annotations = {}
for cls in type(self).__mro__:
annotations.update(getattr(cls, '__annotations__', {}))
# Type check and encode values
for name, val in list(self._values.items()):
if name == 'encoding' or isinstance(val, RawImm): continue
ann = annotations.get(name)
# Type validation
if ann is SGPR:
if isinstance(val, VGPR): raise TypeError(f"field '{name}' requires SGPR, got VGPR")
if not isinstance(val, (SGPR, TTMP, int, RawImm)): raise TypeError(f"field '{name}' requires SGPR, got {type(val).__name__}")
if ann is VGPR:
if not isinstance(val, VGPR): raise TypeError(f"field '{name}' requires VGPR, got {type(val).__name__}")
if ann is SSrc and isinstance(val, VGPR): raise TypeError(f"field '{name}' requires scalar source, got VGPR")
# Encode source fields as RawImm for consistent disassembly
if name in SRC_FIELDS:
encoded = encode_src(val)
self._values[name] = RawImm(encoded)
# Track literal value if needed (encoded as 255)
if encoded == 255 and self._literal is None and isinstance(val, int) and not isinstance(val, IntEnum):
self._literal = val
# Encode raw register fields for consistent repr
elif name in RAW_FIELDS and isinstance(val, Reg):
encoded = (108 + val.idx) if isinstance(val, TTMP) else (val.idx | (0x80 if val.hi else 0))
self._values[name] = encoded
# Encode sbase (divided by 2) and srsrc/ssamp (divided by 4)
elif name == 'sbase' and isinstance(val, Reg):
self._values[name] = val.idx // 2
elif name in {'srsrc', 'ssamp'} and isinstance(val, Reg):
self._values[name] = val.idx // 4
def _encode_field(self, name: str, val) -> int:
if isinstance(val, RawImm): return val.val
@ -120,7 +155,11 @@ class Inst:
if has_literal and len(data) >= cls._size() + 4: inst._literal = int.from_bytes(data[cls._size():cls._size()+4], 'little')
return inst
def __repr__(self):
  """Return ``ClassName(field=value, ...)`` with fields listed in ``_fields`` order.

  ``encoding`` is never shown and fields missing from ``_values`` are skipped. Integer
  fields equal to 0 are omitted (``op`` is always kept) so the repr stays consistent after
  a roundtrip.
  """
  shown = []
  for k in self._fields:
    if k == 'encoding' or k not in self._values: continue
    v = self._values[k]
    # zero/default integer fields are noise, but op 0 is a real opcode and must stay visible
    if k != 'op' and isinstance(v, int) and v == 0: continue
    shown.append(f"{k}={v}")
  return f"{self.__class__.__name__}({', '.join(shown)})"
def disasm(self) -> str:
from extra.assembly.rdna3.asm import disasm

View file

@ -8,36 +8,51 @@ from extra.assembly.rdna3.test.test_roundtrip import compile_asm
class TestIntegration(unittest.TestCase):
  """Round-trip tests: each test stores an instruction in ``self.inst`` and ``tearDown`` checks it.

  ``tearDown`` disassembles the instruction, reassembles the text, and asserts that the bytes
  match ``compile_asm`` and that the reassembled instruction has the same repr. Tests that
  only expect a ``TypeError`` on construction never set ``self.inst`` and are skipped there.
  """
  def tearDown(self):
    # construction-error tests leave no instruction behind to round-trip
    if not hasattr(self, 'inst'): return
    b = self.inst.to_bytes()
    st = self.inst.disasm()
    reasm = asm(st)
    desc = f"{st:25s} {self.inst} {b} {reasm}"
    self.assertEqual(b, compile_asm(st), desc)
    self.assertEqual(repr(self.inst), repr(reasm))
    print(desc)

  def test_load_b128(self):
    self.inst = s_load_b128(s[4:7], s[0:1], NULL, 0)

  def test_load_b128_s(self):
    self.inst = s_load_b128(s[4:7], s[0:1], s[8], 0)

  def test_load_b128_v(self):
    with self.assertRaises(TypeError):
      self.inst = s_load_b128(s[4:7], s[0:1], v[8], 0)

  def test_load_b128_off(self):
    self.inst = s_load_b128(s[4:7], s[0:1], NULL, 3)

  def test_simple_stos(self):
    self.inst = s_mov_b32(s[0], s[1])

  def test_simple_wrong(self):
    with self.assertRaises(TypeError):
      self.inst = s_mov_b32(v[0], s[1])

  def test_simple_vtov(self):
    # TODO: this is broken, it's reconstructing with s[1] and not v[1]
    self.inst = v_mov_b32_e32(v[0], v[1])

  def test_simple_stov(self):
    self.inst = v_mov_b32_e32(v[0], s[2])

  def test_simple_float_to_v(self):
    # TODO: this should be the magic float value 1.0
    self.inst = v_mov_b32_e32(v[0], 1.0)

  def test_simple_v_to_float(self):
    with self.assertRaises(TypeError):
      self.inst = v_mov_b32_e32(1, v[0])

  def test_simple_int_to_v(self):
    # TODO: this should be the constant 1, not s[0]
    self.inst = v_mov_b32_e32(v[0], 1)
if __name__ == "__main__":

View file

@ -119,6 +119,8 @@ def _make_disasm_test(name):
decoded = fmt_cls.from_bytes(data)
op_val = decoded._values.get('op', 0)
op_val = op_val.val if hasattr(op_val, 'val') else op_val
# VOP3SD test uses VOP3 file - skip non-VOP3SD instructions
if name == 'vop3sd' and op_val not in vop3sd_opcodes: continue
# VOP3 and VOP3SD share encoding - validate with appropriate enum
if fmt_cls.__name__ == 'VOP3' and op_val in vop3sd_opcodes:
VOP3SDOp(op_val) # validate as VOP3SD