mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
more handwritten
This commit is contained in:
parent
f6d68f2090
commit
d41bb12a13
4 changed files with 131 additions and 17 deletions
|
|
@ -22,8 +22,11 @@ def _vreg(base: int, cnt: int = 1) -> str: return f"v{base}" if cnt == 1 else f"
|
|||
def _fmt_sdst(v: int, cnt: int = 1) -> str:
|
||||
"""Format SGPR destination with special register names."""
|
||||
if v == 124: return "null"
|
||||
if 108 <= v <= 123: return f"ttmp[{v-108}:{v-108+cnt-1}]" if cnt == 2 else f"ttmp{v-108}"
|
||||
if cnt == 2: return "exec" if v == 126 else "vcc" if v == 106 else _sreg(v, 2)
|
||||
if 108 <= v <= 123: return f"ttmp[{v-108}:{v-108+cnt-1}]" if cnt > 1 else f"ttmp{v-108}"
|
||||
if cnt > 1:
|
||||
if v == 126 and cnt == 2: return "exec"
|
||||
if v == 106 and cnt == 2: return "vcc"
|
||||
return _sreg(v, cnt)
|
||||
return {126: "exec_lo", 127: "exec_hi", 106: "vcc_lo", 107: "vcc_hi", 125: "m0"}.get(v, f"s{v}")
|
||||
|
||||
def _fmt_ssrc(v: int, cnt: int = 1) -> str:
|
||||
|
|
@ -152,10 +155,47 @@ def disasm(inst: Inst) -> str:
|
|||
|
||||
# SMEM
|
||||
if cls_name == 'SMEM':
|
||||
# No-operand instructions
|
||||
if op_name in ('s_gl1_inv', 's_dcache_inv'): return op_name
|
||||
sdata, sbase, soffset, offset = unwrap(inst._values['sdata']), unwrap(inst._values['sbase']), unwrap(inst._values['soffset']), unwrap(inst._values['offset'])
|
||||
glc, dlc = unwrap(inst._values.get('glc', 0)), unwrap(inst._values.get('dlc', 0))
|
||||
# s_atc_probe/s_atc_probe_buffer: sdata is the probe mode (0-7), not a register
|
||||
if op_name in ('s_atc_probe', 's_atc_probe_buffer'):
|
||||
sbase_idx = sbase * 2
|
||||
sbase_cnt = 4 if op_name == 's_atc_probe_buffer' else 2
|
||||
sbase_str = _sreg(sbase_idx, sbase_cnt)
|
||||
if offset and soffset != 124:
|
||||
off_str = f"{decode_src(soffset)} offset:0x{offset:x}"
|
||||
elif offset:
|
||||
off_str = f"0x{offset:x}"
|
||||
else:
|
||||
off_str = decode_src(soffset)
|
||||
return f"{op_name} {sdata}, {sbase_str}, {off_str}"
|
||||
width = {0:1, 1:2, 2:4, 3:8, 4:16, 8:1, 9:2, 10:4, 11:8, 12:16}.get(op_val, 1)
|
||||
off_str = f"0x{offset:x}" if offset else "null" if soffset == 124 else decode_src(soffset)
|
||||
return f"{op_name} {_sreg(sdata, width)}, {_sreg(sbase, 2)}, {off_str}"
|
||||
# Offset handling: if offset is set, we need "soffset offset:X" format, otherwise just soffset or imm
|
||||
if offset and soffset != 124: # both soffset register and offset immediate
|
||||
off_str = f"{decode_src(soffset)} offset:0x{offset:x}"
|
||||
elif offset: # only offset immediate (soffset=null)
|
||||
off_str = f"0x{offset:x}"
|
||||
elif soffset == 124: # null
|
||||
off_str = "null"
|
||||
else: # only soffset register
|
||||
off_str = decode_src(soffset)
|
||||
# sbase is stored as register pair index, multiply by 2 for actual register number
|
||||
# s_buffer_load_* (op 8-12) use 4-reg sbase (buffer descriptor), s_load_* (op 0-4) use 2-reg sbase
|
||||
sbase_idx = sbase * 2
|
||||
sbase_cnt = 4 if 8 <= op_val <= 12 else 2
|
||||
# Format sbase with special register names
|
||||
if sbase_idx == 106 and sbase_cnt == 2: sbase_str = "vcc"
|
||||
elif sbase_idx == 126 and sbase_cnt == 2: sbase_str = "exec"
|
||||
elif 108 <= sbase_idx <= 123: sbase_str = f"ttmp[{sbase_idx-108}:{sbase_idx-108+sbase_cnt-1}]"
|
||||
else: sbase_str = _sreg(sbase_idx, sbase_cnt)
|
||||
# Build modifiers
|
||||
mods = []
|
||||
if glc: mods.append("glc")
|
||||
if dlc: mods.append("dlc")
|
||||
mod_str = " " + " ".join(mods) if mods else ""
|
||||
return f"{op_name} {_fmt_sdst(sdata, width)}, {sbase_str}, {off_str}{mod_str}"
|
||||
|
||||
# FLAT
|
||||
if cls_name == 'FLAT':
|
||||
|
|
@ -312,19 +352,37 @@ def disasm(inst: Inst) -> str:
|
|||
return f"{op_name} {dst_str}, {src0_str}, {src1_str}, {src2_str}" + fmt_opsel(3) + clamp_str + omod_str
|
||||
return f"{op_name} {dst_str}, {src0_str}, {src1_str}" + fmt_opsel(2) + clamp_str + omod_str
|
||||
|
||||
# VOP3SD: 3-source with scalar destination (v_div_scale_*)
|
||||
# VOP3SD: 3-source with scalar destination (v_div_scale_*, v_add_co_u32, v_mad_*64_*32, etc.)
|
||||
if cls_name == 'VOP3SD':
|
||||
vdst, sdst = unwrap(inst._values.get('vdst', 0)), unwrap(inst._values.get('sdst', 0))
|
||||
src0, src1, src2 = [unwrap(inst._values.get(f, 0)) for f in ('src0', 'src1', 'src2')]
|
||||
neg = unwrap(inst._values.get('neg', 0))
|
||||
def fmt_vop3_src(v, neg_bit):
|
||||
omod = unwrap(inst._values.get('omod', 0))
|
||||
clmp = unwrap(inst._values.get('clmp', 0))
|
||||
is_f64 = 'f64' in op_name
|
||||
is_mad64 = 'mad_i64_i32' in op_name or 'mad_u64_u32' in op_name
|
||||
def fmt_sd_src(v, neg_bit, is_64bit=False):
|
||||
s = fmt_src(v)
|
||||
if is_64bit or is_f64:
|
||||
if v >= 256: s = _vreg(v - 256, 2)
|
||||
elif v <= 105: s = _sreg(v, 2)
|
||||
elif v == 106: s = "vcc"
|
||||
elif v == 126: s = "exec"
|
||||
elif 108 <= v <= 123: s = f"ttmp[{v-108}:{v-108+1}]"
|
||||
if neg_bit: s = f"-{s}"
|
||||
return s
|
||||
src0_str = fmt_vop3_src(src0, neg & 1)
|
||||
src1_str = fmt_vop3_src(src1, neg & 2)
|
||||
src2_str = fmt_vop3_src(src2, neg & 4)
|
||||
return f"{op_name} v{vdst}, vcc_lo, {src0_str}, {src1_str}, {src2_str}"
|
||||
src0_str = fmt_sd_src(src0, neg & 1, False)
|
||||
src1_str = fmt_sd_src(src1, neg & 2, False)
|
||||
src2_str = fmt_sd_src(src2, neg & 4, is_mad64)
|
||||
dst_str = _vreg(vdst, 2) if (is_f64 or is_mad64) else f"v{vdst}"
|
||||
sdst_str = _fmt_sdst(sdst, 1)
|
||||
clamp_str = " clamp" if clmp else ""
|
||||
omod_str = {1: " mul:2", 2: " mul:4", 3: " div:2"}.get(omod, "")
|
||||
# v_add_co_u32, v_sub_co_u32, v_subrev_co_u32, v_add_co_ci_u32, etc. only use 2 sources
|
||||
if op_name in ('v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32', 'v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32'):
|
||||
return f"{op_name} {dst_str}, {sdst_str}, {src0_str}, {src1_str}" + clamp_str
|
||||
# v_div_scale, v_mad_*64_*32 use 3 sources
|
||||
return f"{op_name} {dst_str}, {sdst_str}, {src0_str}, {src1_str}, {src2_str}" + clamp_str + omod_str
|
||||
|
||||
# VOPD: dual-issue instructions
|
||||
if cls_name == 'VOPD':
|
||||
|
|
|
|||
|
|
@ -39,6 +39,8 @@ class Imm: pass
|
|||
class SImm: pass
|
||||
class RawImm:
  """Wrapper around an integer field value.

  Two RawImm instances are equal when they wrap equal values; a RawImm never
  equals a bare int. The hash follows the same key, so instances can be put in
  sets and used as dict keys.
  """
  def __init__(self, val: int): self.val = val
  def __repr__(self): return f"RawImm({self.val})"
  def __eq__(self, other): return isinstance(other, RawImm) and self.val == other.val
  # Defining __eq__ alone sets __hash__ to None (unhashable); hash what __eq__ compares.
  def __hash__(self): return hash((RawImm, self.val))
|
||||
|
||||
def unwrap(val) -> int:
  """Return the plain value carried by `val`.

  Checked in order: a RawImm yields its `val`; an object with a `value`
  attribute yields that attribute; an object with an `idx` attribute yields
  that attribute; anything else is returned unchanged.
  """
  if isinstance(val, RawImm): return val.val
  for attr in ('value', 'idx'):
    if hasattr(val, attr): return getattr(val, attr)
  return val
|
||||
|
|
@ -53,7 +55,9 @@ def encode_src(val) -> int:
|
|||
if isinstance(val, VGPR): return 256 + val.idx + (0x80 if val.hi else 0)
|
||||
if isinstance(val, TTMP): return 108 + val.idx
|
||||
if hasattr(val, 'value'): return val.value
|
||||
if isinstance(val, float): return FLOAT_ENC.get(val, 255)
|
||||
if isinstance(val, float):
|
||||
if val == 0.0: return 128 # 0.0 encodes as integer constant 0
|
||||
return FLOAT_ENC.get(val, 255)
|
||||
return 128 + val if isinstance(val, int) and 0 <= val <= 64 else 192 + (-val) if isinstance(val, int) and -16 <= val <= -1 else 255
|
||||
|
||||
# Instruction base class
|
||||
|
|
@ -71,6 +75,37 @@ class Inst:
|
|||
self._values, self._literal = dict(self._defaults), literal
|
||||
self._values.update(zip([n for n in self._fields if n != 'encoding'], args))
|
||||
self._values.update(kwargs)
|
||||
# Get annotations from class hierarchy
|
||||
annotations = {}
|
||||
for cls in type(self).__mro__:
|
||||
annotations.update(getattr(cls, '__annotations__', {}))
|
||||
# Type check and encode values
|
||||
for name, val in list(self._values.items()):
|
||||
if name == 'encoding' or isinstance(val, RawImm): continue
|
||||
ann = annotations.get(name)
|
||||
# Type validation
|
||||
if ann is SGPR:
|
||||
if isinstance(val, VGPR): raise TypeError(f"field '{name}' requires SGPR, got VGPR")
|
||||
if not isinstance(val, (SGPR, TTMP, int, RawImm)): raise TypeError(f"field '{name}' requires SGPR, got {type(val).__name__}")
|
||||
if ann is VGPR:
|
||||
if not isinstance(val, VGPR): raise TypeError(f"field '{name}' requires VGPR, got {type(val).__name__}")
|
||||
if ann is SSrc and isinstance(val, VGPR): raise TypeError(f"field '{name}' requires scalar source, got VGPR")
|
||||
# Encode source fields as RawImm for consistent disassembly
|
||||
if name in SRC_FIELDS:
|
||||
encoded = encode_src(val)
|
||||
self._values[name] = RawImm(encoded)
|
||||
# Track literal value if needed (encoded as 255)
|
||||
if encoded == 255 and self._literal is None and isinstance(val, int) and not isinstance(val, IntEnum):
|
||||
self._literal = val
|
||||
# Encode raw register fields for consistent repr
|
||||
elif name in RAW_FIELDS and isinstance(val, Reg):
|
||||
encoded = (108 + val.idx) if isinstance(val, TTMP) else (val.idx | (0x80 if val.hi else 0))
|
||||
self._values[name] = encoded
|
||||
# Encode sbase (divided by 2) and srsrc/ssamp (divided by 4)
|
||||
elif name == 'sbase' and isinstance(val, Reg):
|
||||
self._values[name] = val.idx // 2
|
||||
elif name in {'srsrc', 'ssamp'} and isinstance(val, Reg):
|
||||
self._values[name] = val.idx // 4
|
||||
|
||||
def _encode_field(self, name: str, val) -> int:
|
||||
if isinstance(val, RawImm): return val.val
|
||||
|
|
@ -120,7 +155,11 @@ class Inst:
|
|||
if has_literal and len(data) >= cls._size() + 4: inst._literal = int.from_bytes(data[cls._size():cls._size()+4], 'little')
|
||||
return inst
|
||||
|
||||
def __repr__(self): return f"{self.__class__.__name__}({', '.join(f'{k}={v}' for k, v in self._values.items())})"
|
||||
def __repr__(self):
|
||||
# Use _fields order and exclude fields that are 0/default (for consistent repr after roundtrip)
|
||||
items = [(k, self._values[k]) for k in self._fields if k in self._values and k != 'encoding'
|
||||
and not (isinstance(self._values[k], int) and self._values[k] == 0 and k not in {'op'})]
|
||||
return f"{self.__class__.__name__}({', '.join(f'{k}={v}' for k, v in items)})"
|
||||
|
||||
def disasm(self) -> str:
|
||||
from extra.assembly.rdna3.asm import disasm
|
||||
|
|
|
|||
|
|
@ -8,36 +8,51 @@ from extra.assembly.rdna3.test.test_roundtrip import compile_asm
|
|||
|
||||
class TestIntegration(unittest.TestCase):
  """Construct instructions and roundtrip each one in tearDown.

  A test either stores the instruction under test in `self.inst`, or asserts
  that constructing it raises TypeError, in which case `self.inst` is never
  set and tearDown returns without checking anything.
  """

  def tearDown(self):
    # tests that expect construction to fail never set self.inst
    if not hasattr(self, 'inst'): return
    b = self.inst.to_bytes()
    st = self.inst.disasm()
    reasm = asm(st)
    desc = f"{st:25s} {self.inst} {b} {reasm}"
    # the instruction's bytes must equal what compile_asm yields for its disassembly
    self.assertEqual(b, compile_asm(st), desc)
    # TODO: comparing self.inst == reasm directly should work for valid things; reprs are compared instead
    self.assertEqual(repr(self.inst), repr(reasm))
    print(desc)

  def test_load_b128(self):
    self.inst = s_load_b128(s[4:7], s[0:1], NULL, 0)

  def test_load_b128_s(self):
    self.inst = s_load_b128(s[4:7], s[0:1], s[8], 0)

  def test_load_b128_v(self):
    with self.assertRaises(TypeError):
      self.inst = s_load_b128(s[4:7], s[0:1], v[8], 0)

  def test_load_b128_off(self):
    self.inst = s_load_b128(s[4:7], s[0:1], NULL, 3)

  def test_simple_stos(self):
    self.inst = s_mov_b32(s[0], s[1])

  def test_simple_wrong(self):
    # TODO: this should raise an exception on construction, s[1] is not a valid type
    with self.assertRaises(TypeError):
      self.inst = s_mov_b32(v[0], s[1])

  def test_simple_vtov(self):
    # TODO: this is broken, it's reconstructing with s[1] and not v[1]
    self.inst = v_mov_b32_e32(v[0], v[1])

  def test_simple_stov(self):
    self.inst = v_mov_b32_e32(v[0], s[2])

  def test_simple_float_to_v(self):
    # TODO: this should be the magic float value 1.0
    self.inst = v_mov_b32_e32(v[0], 1.0)

  def test_simple_v_to_float(self):
    with self.assertRaises(TypeError):
      self.inst = v_mov_b32_e32(1, v[0])

  def test_simple_int_to_v(self):
    # TODO: this should be the constant 1, not s[0]
    self.inst = v_mov_b32_e32(v[0], 1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -119,6 +119,8 @@ def _make_disasm_test(name):
|
|||
decoded = fmt_cls.from_bytes(data)
|
||||
op_val = decoded._values.get('op', 0)
|
||||
op_val = op_val.val if hasattr(op_val, 'val') else op_val
|
||||
# VOP3SD test uses VOP3 file - skip non-VOP3SD instructions
|
||||
if name == 'vop3sd' and op_val not in vop3sd_opcodes: continue
|
||||
# VOP3 and VOP3SD share encoding - validate with appropriate enum
|
||||
if fmt_cls.__name__ == 'VOP3' and op_val in vop3sd_opcodes:
|
||||
VOP3SDOp(op_val) # validate as VOP3SD
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue