mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
assembly/amd: mypy+ruff passes (#14701)
* assembly/amd: mypy+ruff passes * touchups
This commit is contained in:
parent
095a064ba8
commit
d5fc3ea1ba
32 changed files with 593 additions and 377 deletions
|
|
@ -20,13 +20,13 @@ test_llvm.py tests asm/disasm on the LLVM tests, confirming it behaves the same
|
|||
|
||||
tinygrad's dtype tests should pass with and without LLVM. They run in about 12 seconds.
|
||||
|
||||
`PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=0 pytest -n=12 test/test_dtype_alu.py test/test_dtype.py`
|
||||
`PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=1 pytest -n=12 test/test_dtype_alu.py test/test_dtype.py`
|
||||
`PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=0 pytest -n=12 test/backend/test_dtype_alu.py test/backend/test_dtype.py`
|
||||
`PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=1 pytest -n=12 test/backend/test_dtype_alu.py test/backend/test_dtype.py`
|
||||
|
||||
The ops tests also pass, but they are very slow, so you should run them one at a time.
|
||||
|
||||
`SKIP_SLOW_TEST=1 PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=0 pytest -n=12 test/test_ops.py`
|
||||
`SKIP_SLOW_TEST=1 PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=1 pytest -n=12 test/test_ops.py`
|
||||
`SKIP_SLOW_TEST=1 PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=0 pytest -n=12 test/backend/test_ops.py`
|
||||
`SKIP_SLOW_TEST=1 PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=1 pytest -n=12 test/backend/test_ops.py`
|
||||
|
||||
When something is caught by main tinygrad tests, a local regression test should be added to `extra/assembly/amd/test`.
|
||||
While working with tinygrad, you can dump the assembly with `DEBUG=7`. These tests all pass on real hardware.
|
||||
|
|
@ -34,6 +34,6 @@ If a test is failing with `AMD=1 PYTHON_REMU=1 MOCKGPU=1` it's because an instru
|
|||
You can run without `MOCKGPU=1` to test on real hardware; if the test passes on real hardware, there's a bug in the emulator.
|
||||
IMPORTANT: if a test is failing in the emulator, it's an instruction bug. Use DEBUG=7, get the instructions, and debug.
|
||||
|
||||
Currently, only RDNA3 is well supported, but when finished, this will support RDNA3+RDNA4+CDNA in ~2000 lines.
|
||||
Currently, only RDNA3 is well supported, but when finished, this will support RDNA3+RDNA4+CDNA in ~3000 lines.
|
||||
Get line count with `cloc --by-file extra/assembly/amd/*.py`
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
# autogenerated from AMD ISA XML - do not edit
|
||||
# ruff: noqa: F401,F403
|
||||
from extra.assembly.amd.dsl import *
|
||||
from extra.assembly.amd.autogen.cdna.enum import *
|
||||
# ruff: noqa: E501,F401
|
||||
from extra.assembly.amd.dsl import BitField, DPP, DPP16, EXEC, EXECZ, EXEC_HI, EXEC_LO, EnumBitField, FixedBitField, INV_2PI, Inst, LIT, M0, NULL, OFF, SBaseField, SCC, SDWA, SGPRField, SRC_LDS_DIRECT, SRsrcField, SSrcField, SrcField, VCC, VCCZ, VCC_HI, VCC_LO, VGPRField, s, src, ttmp, v
|
||||
from extra.assembly.amd.autogen.cdna.enum import DSOp, FLATOp, GLOBALOp, MTBUFOp, MUBUFOp, SCRATCHOp, SMEMOp, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3POp, VOP3PX2Op, VOP3SDOp, VOPCOp, HWREG
|
||||
import functools
|
||||
|
||||
class DS(Inst):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# autogenerated from AMD ISA XML - do not edit
|
||||
from extra.assembly.amd.autogen.common import Fmt, OpType
|
||||
from extra.assembly.amd.autogen.cdna.enum import *
|
||||
from extra.assembly.amd.autogen.cdna.enum import DSOp, FLATOp, GLOBALOp, MTBUFOp, MUBUFOp, SCRATCHOp, SMEMOp, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3POp, VOP3PX2Op, VOP3SDOp, VOPCOp
|
||||
|
||||
# instruction operand info: {Op: {field: (Fmt, size_bits, OpType)}}
|
||||
OPERANDS = {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
# autogenerated from AMD ISA XML - do not edit
|
||||
# ruff: noqa: F401,F403
|
||||
from extra.assembly.amd.dsl import *
|
||||
from extra.assembly.amd.autogen.rdna3.enum import *
|
||||
# ruff: noqa: E501,F401
|
||||
from extra.assembly.amd.dsl import BitField, DPP, DPP16, EXEC, EXECZ, EXEC_HI, EXEC_LO, EnumBitField, FixedBitField, INV_2PI, Inst, LIT, M0, NULL, OFF, SBaseField, SCC, SDWA, SGPRField, SRC_LDS_DIRECT, SRsrcField, SSrcField, SrcField, VCC, VCCZ, VCC_HI, VCC_LO, VDSTYField, VGPRField, s, src, ttmp, v
|
||||
from extra.assembly.amd.autogen.rdna3.enum import DSOp, EXPOp, FLATOp, GLOBALOp, LDSDIROp, MIMGOp, MTBUFOp, MUBUFOp, SCRATCHOp, SMEMOp, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VINTERPOp, VOP1Op, VOP2Op, VOP3Op, VOP3POp, VOP3SDOp, VOPCOp, VOPDOp, HWREG, MSG
|
||||
import functools
|
||||
|
||||
class DS(Inst):
|
||||
|
|
@ -593,9 +593,6 @@ flat_load_d16_hi_i8 = functools.partial(FLAT, FLATOp.FLAT_LOAD_D16_HI_I8)
|
|||
flat_load_d16_hi_b16 = functools.partial(FLAT, FLATOp.FLAT_LOAD_D16_HI_B16)
|
||||
flat_store_d16_hi_b8 = functools.partial(FLAT, FLATOp.FLAT_STORE_D16_HI_B8)
|
||||
flat_store_d16_hi_b16 = functools.partial(FLAT, FLATOp.FLAT_STORE_D16_HI_B16)
|
||||
global_load_addtid_b32 = functools.partial(FLAT, FLATOp.GLOBAL_LOAD_ADDTID_B32)
|
||||
global_store_addtid_b32 = functools.partial(FLAT, FLATOp.GLOBAL_STORE_ADDTID_B32)
|
||||
global_load_lds_addtid_b32 = functools.partial(FLAT, FLATOp.GLOBAL_LOAD_LDS_ADDTID_B32)
|
||||
flat_atomic_swap_b32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_SWAP_B32)
|
||||
flat_atomic_cmpswap_b32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_CMPSWAP_B32)
|
||||
flat_atomic_add_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_ADD_U32)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# autogenerated from AMD ISA XML - do not edit
|
||||
from extra.assembly.amd.autogen.common import Fmt, OpType
|
||||
from extra.assembly.amd.autogen.rdna3.enum import *
|
||||
from extra.assembly.amd.autogen.rdna3.enum import DSOp, EXPOp, FLATOp, GLOBALOp, LDSDIROp, MIMGOp, MTBUFOp, MUBUFOp, SCRATCHOp, SMEMOp, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VINTERPOp, VOP1Op, VOP2Op, VOP3Op, VOP3POp, VOP3SDOp, VOPCOp, VOPDOp
|
||||
|
||||
# instruction operand info: {Op: {field: (Fmt, size_bits, OpType)}}
|
||||
OPERANDS = {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
# autogenerated from AMD ISA XML - do not edit
|
||||
# ruff: noqa: F401,F403
|
||||
from extra.assembly.amd.dsl import *
|
||||
from extra.assembly.amd.autogen.rdna4.enum import *
|
||||
# ruff: noqa: E501,F401
|
||||
from extra.assembly.amd.dsl import BitField, DPP, DPP16, EXEC, EXECZ, EXEC_HI, EXEC_LO, EnumBitField, FixedBitField, INV_2PI, Inst, LIT, M0, NULL, OFF, SBaseField, SCC, SDWA, SGPRField, SRC_LDS_DIRECT, SSrcField, SrcField, VCC, VCCZ, VCC_HI, VCC_LO, VDSTYField, VGPRField, s, src, ttmp, v
|
||||
from extra.assembly.amd.autogen.rdna4.enum import DSOp, SMEMOp, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VBUFFEROp, VDSDIROp, VEXPORTOp, VFLATOp, VGLOBALOp, VIMAGEOp, VINTERPOp, VOP1Op, VOP2Op, VOP3Op, VOP3POp, VOP3SDOp, VOPCOp, VOPDOp, VSAMPLEOp, VSCRATCHOp, HWREG, MSG
|
||||
import functools
|
||||
|
||||
class DS(Inst):
|
||||
|
|
@ -973,8 +973,6 @@ flat_load_d16_hi_i8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_HI_I8)
|
|||
flat_load_d16_hi_b16 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_HI_B16)
|
||||
flat_store_d16_hi_b8 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_D16_HI_B8)
|
||||
flat_store_d16_hi_b16 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_D16_HI_B16)
|
||||
global_load_addtid_b32 = functools.partial(VFLAT, VFLATOp.GLOBAL_LOAD_ADDTID_B32)
|
||||
global_store_addtid_b32 = functools.partial(VFLAT, VFLATOp.GLOBAL_STORE_ADDTID_B32)
|
||||
flat_atomic_swap_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_SWAP_B32)
|
||||
flat_atomic_cmpswap_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_CMPSWAP_B32)
|
||||
flat_atomic_add_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_ADD_U32)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# autogenerated from AMD ISA XML - do not edit
|
||||
from extra.assembly.amd.autogen.common import Fmt, OpType
|
||||
from extra.assembly.amd.autogen.rdna4.enum import *
|
||||
from extra.assembly.amd.autogen.rdna4.enum import DSOp, SMEMOp, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VBUFFEROp, VDSDIROp, VEXPORTOp, VFLATOp, VGLOBALOp, VIMAGEOp, VINTERPOp, VOP1Op, VOP2Op, VOP3Op, VOP3POp, VOP3SDOp, VOPCOp, VOPDOp, VSAMPLEOp, VSCRATCHOp
|
||||
|
||||
# instruction operand info: {Op: {field: (Fmt, size_bits, OpType)}}
|
||||
OPERANDS = {
|
||||
|
|
|
|||
|
|
@ -44,11 +44,15 @@ class Reg:
|
|||
def fmt(self, sz=None, parens=False, upper=False) -> str:
|
||||
o, sz = self.offset, sz or self.sz
|
||||
l, r = ("[", "]") if parens or sz > 1 else ("", "") # brackets for multi-reg or when parens=True
|
||||
if 256 <= o < 512: idx = o - 256; base = f"v{l}{idx}{r}" if sz == 1 else f"v[{idx}:{idx + sz - 1}]"
|
||||
if 256 <= o < 512:
|
||||
idx = o - 256
|
||||
base = f"v{l}{idx}{r}" if sz == 1 else f"v[{idx}:{idx + sz - 1}]"
|
||||
elif o < 106: base = f"s{l}{o}{r}" if sz == 1 else f"s[{o}:{o + sz - 1}]"
|
||||
elif sz == 2 and o in self._PAIRS: base = self._PAIRS[o] if upper else self._PAIRS[o].lower()
|
||||
elif o in self._NAMES: base = self._NAMES[o] if upper else self._NAMES[o].lower() # special regs (any sz)
|
||||
elif 108 <= o < 124: idx = o - 108; base = f"ttmp{l}{idx}{r}" if sz == 1 else f"ttmp[{idx}:{idx + sz - 1}]"
|
||||
elif 108 <= o < 124:
|
||||
idx = o - 108
|
||||
base = f"ttmp{l}{idx}{r}" if sz == 1 else f"ttmp[{idx}:{idx + sz - 1}]"
|
||||
elif 128 <= o <= 192: base = str(o - 128) # inline int constants (0-64)
|
||||
elif 193 <= o <= 208: base = str(-(o - 192)) # inline negative int constants (-1 to -16)
|
||||
else: raise RuntimeError(f"unknown register: offset={o}, sz={sz}")
|
||||
|
|
@ -151,7 +155,8 @@ class SrcField(BitField):
|
|||
expected_size = self._valid_range[1] - self._valid_range[0] + 1
|
||||
actual_size = 1 << (hi - lo + 1)
|
||||
if actual_size != expected_size:
|
||||
raise RuntimeError(f"{self.__class__.__name__}: field size {hi - lo + 1} bits ({actual_size}) doesn't match range {self._valid_range} ({expected_size})")
|
||||
raise RuntimeError(f"{self.__class__.__name__}: field size {hi - lo + 1} bits ({actual_size}) "
|
||||
f"doesn't match range {self._valid_range} ({expected_size})")
|
||||
|
||||
def encode(self, val) -> int:
|
||||
"""Encode value. Returns 255 (literal marker) for out-of-range values."""
|
||||
|
|
@ -271,7 +276,7 @@ class Inst:
|
|||
inherited = {}
|
||||
for base in reversed(cls.__mro__[1:]):
|
||||
if hasattr(base, '_fields'):
|
||||
inherited.update({name: field for name, field in base._fields})
|
||||
inherited.update(dict(base._fields))
|
||||
inherited.update({name: val for name, val in cls.__dict__.items() if isinstance(val, BitField)})
|
||||
cls._fields = list(inherited.items())
|
||||
cls._base_size = (max(f.hi for _, f in cls._fields) + 8) // 8
|
||||
|
|
|
|||
|
|
@ -70,7 +70,8 @@ def _split64(val: UOp) -> tuple[UOp, UOp]:
|
|||
v64 = val.bitcast(dtypes.uint64) if val.dtype == dtypes.float64 else val.cast(dtypes.uint64) if val.dtype != dtypes.uint64 else val
|
||||
return v64.cast(dtypes.uint32), (v64 >> UOp.const(dtypes.uint64, 32)).cast(dtypes.uint32)
|
||||
|
||||
_SRC_MOD_TYPES = {16: (dtypes.uint16, dtypes.half, 0x7FFF), 64: (dtypes.uint64, dtypes.float64, 0x7FFFFFFFFFFFFFFF), 32: (dtypes.uint32, dtypes.float32, 0x7FFFFFFF)}
|
||||
_SRC_MOD_TYPES = {16: (dtypes.uint16, dtypes.half, 0x7FFF), 32: (dtypes.uint32, dtypes.float32, 0x7FFFFFFF),
|
||||
64: (dtypes.uint64, dtypes.float64, 0x7FFFFFFFFFFFFFFF)}
|
||||
def _apply_src_mods(val: UOp, mod_bit: int, abs_bits: int, neg_bits: int, bits: int = 32) -> UOp:
|
||||
"""Apply abs/neg modifiers to source value based on bit width (16, 32, or 64)."""
|
||||
if not (abs_bits & (1 << mod_bit)) and not (neg_bits & (1 << mod_bit)): return val
|
||||
|
|
@ -163,21 +164,30 @@ def get_pcode(op) -> str:
|
|||
(f'1.0 / S1.{dt} == DENORM.{dt}', '0'), (f'S1.{dt} == DENORM.{dt}', f'isDENORM(S1.{dt})'),
|
||||
(f'D0.{dt} = NAN.{dt}', f'VCC = 0x1LL;\nD0.{dt} = NAN.{dt}'),
|
||||
(f'elsif isDENORM(S1.{dt}) then\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})', f'elsif 1 == 0 then\nD0.{dt} = S0.{dt}'),
|
||||
(f'elsif exponent(S2.{dt}) <= {exp_lim} then\n// Numerator is tiny\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})',
|
||||
f'elsif exponent(S2.{dt}) <= {exp_lim} then\nVCC = 0x1LL;\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})'),
|
||||
(f'elsif divWouldBeDenorm(S2.{dt}, S1.{dt}) then\nVCC = 0x1LL;\nif S0.{dt} == S2.{dt} then\n// Only scale the numerator\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nendif',
|
||||
f'elsif divWouldBeDenorm(S2.{dt}, S1.{dt}) then\nVCC = 0x1LL;\nD0.{dt} = S0.{dt}'),
|
||||
(f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nendif\nelsif', f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nelse\nD0.{dt} = S0.{dt}\nendif\nelsif')]:
|
||||
(f'elsif exponent(S2.{dt}) <= {exp_lim} then\n// Numerator is tiny\n'
|
||||
f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})',
|
||||
f'elsif exponent(S2.{dt}) <= {exp_lim} then\nVCC = 0x1LL;\n'
|
||||
f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})'),
|
||||
(f'elsif divWouldBeDenorm(S2.{dt}, S1.{dt}) then\nVCC = 0x1LL;\n'
|
||||
f'if S0.{dt} == S2.{dt} then\n// Only scale the numerator\n'
|
||||
f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nendif',
|
||||
f'elsif divWouldBeDenorm(S2.{dt}, S1.{dt}) then\n'
|
||||
f'VCC = 0x1LL;\nD0.{dt} = S0.{dt}'),
|
||||
(f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nendif\nelsif',
|
||||
f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nelse\n'
|
||||
f'D0.{dt} = S0.{dt}\nendif\nelsif')]:
|
||||
pcode = pcode.replace(old, new)
|
||||
lines = pcode.rstrip().split('\n')
|
||||
for i in range(len(lines) - 1, -1, -1):
|
||||
if lines[i].strip() == 'endif': lines.insert(i, f'else\nD0.{dt} = S0.{dt}'); break
|
||||
if lines[i].strip() == 'endif':
|
||||
lines.insert(i, f'else\nD0.{dt} = S0.{dt}')
|
||||
break
|
||||
pcode = '\n'.join(lines) + f';\nif isDENORM(S1.{dt}) then\nD0.{dt} = NAN.{dt}\nendif'
|
||||
pcode = pcode.replace('VCC = 0x0LL', 'VCC.u64[laneId] = 0').replace('VCC = 0x1LL', 'VCC.u64[laneId] = 1')
|
||||
return pcode
|
||||
|
||||
def parse_pcode(pcode: str, srcs: dict[str, UOp] | None = None) -> tuple[dict, list[tuple[str, UOp]]]:
|
||||
vars: dict = srcs.copy() if srcs else {}
|
||||
env: dict = srcs.copy() if srcs else {}
|
||||
assigns: list[tuple[str, UOp]] = []
|
||||
raw_lines = [l.strip().rstrip(';') for l in pcode.split('\n') if l.strip() and not l.strip().startswith('//')]
|
||||
# TODO: pcode.py should tokenize full pcode string instead of line-by-line, then this hack can be removed
|
||||
|
|
@ -185,15 +195,17 @@ def parse_pcode(pcode: str, srcs: dict[str, UOp] | None = None) -> tuple[dict, l
|
|||
for l in raw_lines:
|
||||
if lines and lines[-1].endswith('&&'): lines[-1] = lines[-1] + ' ' + l
|
||||
else: lines.append(l)
|
||||
_, final, _ = parse_block(lines, 0, vars, assigns=assigns)
|
||||
_, final, _ = parse_block(lines, 0, env, assigns=assigns)
|
||||
sliced = set(d.split('[')[0] for d, _ in assigns if '[' in d)
|
||||
for var, val in final.items():
|
||||
if var in ['D0', 'SCC', 'VCC', 'EXEC', 'PC', 'RETURN_DATA', 'VDATA'] and isinstance(val, UOp):
|
||||
if var in sliced and not any(re.match(rf'{var}\.\w+\s*=', l) for l in lines): continue
|
||||
for l in lines:
|
||||
if (m := re.match(rf'{var}\.(\w+(?:\[\w+\])?)', l)): assigns.append((f'{var}.{m.group(1)}', val)); break
|
||||
if (m := re.match(rf'{var}\.(\w+(?:\[\w+\])?)', l)):
|
||||
assigns.append((f'{var}.{m.group(1)}', val))
|
||||
break
|
||||
else: assigns.append((var, val))
|
||||
return vars, assigns
|
||||
return env, assigns
|
||||
|
||||
def _write_64bit(val: UOp, wfn, reg_or_addr, is_mem: bool, *args) -> list[UOp]:
|
||||
"""Write a 64-bit value as two 32-bit writes. args passed to wfn after reg/addr and lo/hi value."""
|
||||
|
|
@ -329,7 +341,8 @@ class _Ctx:
|
|||
# Dynamic register access (takes UOp index instead of int)
|
||||
def rsgpr_dyn(self, reg: UOp, valid: UOp | None = None) -> UOp:
|
||||
"""Read SGPR with dynamic register index."""
|
||||
return self.sgpr.index(reg.cast(dtypes.int), valid, ptr=True).load() if valid is not None else self.sgpr.index(reg.cast(dtypes.int), ptr=True).load()
|
||||
if valid is not None: return self.sgpr.index(reg.cast(dtypes.int), valid, ptr=True).load()
|
||||
return self.sgpr.index(reg.cast(dtypes.int), ptr=True).load()
|
||||
|
||||
def wsgpr_dyn(self, reg: UOp, val: UOp) -> UOp:
|
||||
"""Write SGPR with dynamic register index. Writes to NULL (124) are discarded."""
|
||||
|
|
@ -475,7 +488,8 @@ class _Ctx:
|
|||
if hi_bit != 31 or lo_bit != 0:
|
||||
width, slice_mask = hi_bit - lo_bit + 1, (1 << (hi_bit - lo_bit + 1)) - 1
|
||||
val_bits = val.bitcast(dtypes.uint16).cast(dtypes.uint32) if val.dtype == dtypes.half else \
|
||||
val.cast(dtypes.uint32) if val.dtype in (dtypes.uint16, dtypes.int16) else val.cast(dtypes.uint32) & UOp.const(dtypes.uint32, slice_mask)
|
||||
val.cast(dtypes.uint32) if val.dtype in (dtypes.uint16, dtypes.int16) else \
|
||||
val.cast(dtypes.uint32) & UOp.const(dtypes.uint32, slice_mask)
|
||||
raw_stores.append(('vgpr_slice', (lo_bit, width, val_bits)))
|
||||
continue
|
||||
# For integer ops with clamp, use pre-computed saturated value; for floats, clamp to [0,1]
|
||||
|
|
@ -484,7 +498,8 @@ class _Ctx:
|
|||
val = val.maximum(UOp.const(val.dtype, 0.0)).minimum(UOp.const(val.dtype, 1.0))
|
||||
if val.dtype in (dtypes.uint64, dtypes.int64, dtypes.float64):
|
||||
lo, hi = _split64(val)
|
||||
raw_stores.extend([('vgpr', self.wvgpr_dyn(vdst_reg, lane, lo, exec_mask)), ('vgpr', self.wvgpr_dyn(vdst_reg + _c(1), lane, hi, exec_mask))])
|
||||
raw_stores.extend([('vgpr', self.wvgpr_dyn(vdst_reg, lane, lo, exec_mask)),
|
||||
('vgpr', self.wvgpr_dyn(vdst_reg + _c(1), lane, hi, exec_mask))])
|
||||
elif val.dtype in (dtypes.half, dtypes.uint16, dtypes.int16):
|
||||
result, old_val = _val_to_u32(val), self.rvgpr_dyn(vdst_reg, lane)
|
||||
hi_result = (old_val & UOp.const(dtypes.uint32, 0xFFFF)) | (result << UOp.const(dtypes.uint32, 16))
|
||||
|
|
@ -507,7 +522,7 @@ class _Ctx:
|
|||
if lane_stores: stores.append(UOp.sink(*lane_stores).end(lane))
|
||||
for mask_val, reg in [(vcc_val, vcc_reg), (exec_val, EXEC_LO.offset)]:
|
||||
if mask_val is None: continue
|
||||
get_bit = lambda l, v=mask_val: (_to_u32(v.substitute({lane: l})) & _c(1)).cast(dtypes.uint32)
|
||||
def get_bit(l, v=mask_val): return (_to_u32(v.substitute({lane: l})) & _c(1)).cast(dtypes.uint32)
|
||||
stores.append(self.wsgpr_dyn(_c(reg), self.unroll_lanes(get_bit, exec_mask, apply_exec=False)))
|
||||
stores.extend(scalar_stores)
|
||||
return UOp.sink(*stores, *self.inc_pc())
|
||||
|
|
@ -546,7 +561,7 @@ def _compile_smem(inst: ir3.SMEM | ir4.SMEM, ctx: _Ctx) -> UOp:
|
|||
# Dynamic sdata field (bits 12:6) - destination SGPR
|
||||
sdata_reg = ctx.inst_field(type(inst).sdata)
|
||||
# RDNA4 uses 'ioffset', RDNA3 uses 'offset' - use type(inst) to get correct field
|
||||
offset_field = type(inst).ioffset if hasattr(type(inst), 'ioffset') else type(inst).offset
|
||||
offset_field = type(inst).ioffset if hasattr(type(inst), 'ioffset') else type(inst).offset # type: ignore[union-attr]
|
||||
offset = ctx.inst_field_signed(offset_field) # signed immediate
|
||||
# Dynamic soffset field - SGPR for additional offset (NULL=124 reads as 0)
|
||||
soffset = ctx.inst_field(type(inst).soffset)
|
||||
|
|
@ -561,7 +576,7 @@ def _compile_smem(inst: ir3.SMEM | ir4.SMEM, ctx: _Ctx) -> UOp:
|
|||
|
||||
def _compile_sop(inst: ir3.SOP1 | ir3.SOP2 | ir3.SOPC | ir3.SOPK | ir4.SOP1 | ir4.SOP2 | ir4.SOPC | ir4.SOPK, ctx: _Ctx) -> UOp:
|
||||
bits = inst.canonical_op_bits
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None # type: ignore[union-attr]
|
||||
|
||||
if isinstance(inst, (ir3.SOPK, ir4.SOPK)):
|
||||
sdst_off = ctx.inst_field(type(inst).sdst)
|
||||
|
|
@ -598,7 +613,7 @@ def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VO
|
|||
op_name = _op_name(inst)
|
||||
if op_name in ('V_READFIRSTLANE_B32_E32', 'V_PERMLANE64_B32_E32'): return ctx.compile_lane_pcode(inst.op, inst)
|
||||
lane, exec_mask, bits = ctx.range(), ctx.rsgpr_dyn(_c(EXEC_LO.offset)), inst.canonical_op_bits
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None # type: ignore[union-attr]
|
||||
vdst_reg = ctx.inst_field(type(inst).vdst)
|
||||
write_hi_half = bits['d'] == 16 and (vdst_reg >= _c(128))
|
||||
if isinstance(write_hi_half, UOp): vdst_reg = write_hi_half.where(vdst_reg - _c(128), vdst_reg)
|
||||
|
|
@ -641,7 +656,7 @@ def _compile_vopc(inst: ir3.VOPC | ir3.VOP3 | ir4.VOPC | ir4.VOP3, ctx: _Ctx, op
|
|||
# Handle both VOPC (vsrc1) and VOP3 (src1) instruction formats - read operands dynamically
|
||||
if is_vopc:
|
||||
src0_off = ctx.inst_field(type(inst).src0)
|
||||
vsrc1_off = ctx.inst_field(type(inst).vsrc1)
|
||||
vsrc1_off = ctx.inst_field(type(inst).vsrc1) # type: ignore[union-attr]
|
||||
# For 16-bit ops, vsrc1 >= 128 means hi-half of v[vsrc1-128]
|
||||
if bits['s0'] == 16:
|
||||
vsrc1_hi = vsrc1_off >= _c(128)
|
||||
|
|
@ -651,16 +666,17 @@ def _compile_vopc(inst: ir3.VOPC | ir3.VOP3 | ir4.VOPC | ir4.VOP3, ctx: _Ctx, op
|
|||
src1_off = _c(256) + vsrc1_off
|
||||
else:
|
||||
src0_off = ctx.inst_field(type(inst).src0)
|
||||
src1_off = ctx.inst_field(type(inst).src1)
|
||||
dst_off = ctx.inst_field(type(inst).vdst)
|
||||
src1_off = ctx.inst_field(type(inst).src1) # type: ignore[union-attr]
|
||||
dst_off = ctx.inst_field(type(inst).vdst) # type: ignore[union-attr]
|
||||
vsrc1_hi = False
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None # type: ignore[union-attr]
|
||||
|
||||
is_float, is_f64, pcode = any(x in op_name for x in ('_F32', '_F64', '_F16')), '_F64' in op_name, get_pcode(inst.op)
|
||||
def get_cmp_bit(lane) -> UOp:
|
||||
lc = lane.cast(dtypes.int) if isinstance(lane, UOp) else _c(lane, dtypes.int)
|
||||
s0 = ctx.rsrc_dyn(src0_off, lc, bits['s0'], literal, is_f64)
|
||||
s1 = _cond_hi16(vsrc1_hi, ctx.rsrc_dyn(src1_off, lc, bits['s1'], literal, is_f64)) if bits['s0'] == 16 else ctx.rsrc_dyn(src1_off, lc, bits['s1'], literal, is_f64)
|
||||
s1 = _cond_hi16(vsrc1_hi, ctx.rsrc_dyn(src1_off, lc, bits['s1'], literal, is_f64)) if bits['s0'] == 16 \
|
||||
else ctx.rsrc_dyn(src1_off, lc, bits['s1'], literal, is_f64)
|
||||
if bits['s0'] == 16 and opsel: s0, s1 = _apply_opsel(s0, 0, opsel), _apply_opsel(s1, 1, opsel)
|
||||
if is_float:
|
||||
s0 = _apply_src_mods(s0, 0, abs_bits, neg_bits, bits['s0'])
|
||||
|
|
@ -701,7 +717,7 @@ def _compile_vop3(inst: ir3.VOP3 | ir4.VOP3, ctx: _Ctx) -> UOp:
|
|||
# Regular VOP3 - read operands dynamically
|
||||
lane = ctx.range()
|
||||
vdst_reg = ctx.inst_field(type(inst).vdst)
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None # type: ignore[union-attr]
|
||||
ops = inst.canonical_operands
|
||||
src0 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src0), lane, bits['s0'], literal, 's0' in ops and ops['s0'][0] == Fmt.FMT_NUM_F64)
|
||||
src1 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src1), lane, bits['s1'], literal, 's1' in ops and ops['s1'][0] == Fmt.FMT_NUM_F64)
|
||||
|
|
@ -728,7 +744,7 @@ def _compile_vop3sd(inst: ir3.VOP3SD | ir4.VOP3SD, ctx: _Ctx) -> UOp:
|
|||
# Read operands dynamically from instruction encoding
|
||||
vdst_reg, sdst_off = ctx.inst_field(type(inst).vdst), ctx.inst_field(type(inst).sdst)
|
||||
src0_off, src1_off, src2_off = ctx.inst_field(type(inst).src0), ctx.inst_field(type(inst).src1), ctx.inst_field(type(inst).src2)
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None # type: ignore[union-attr]
|
||||
|
||||
has_carry_in = 's2' in ops and ops['s2'][2] == OpType.OPR_SREG
|
||||
vcc_in_off = src2_off if has_carry_in else sdst_off
|
||||
|
|
@ -853,7 +869,9 @@ def _compile_vop3p(inst: ir3.VOP3P | ir4.VOP3P, ctx: _Ctx) -> UOp:
|
|||
if apply_neg: bits = bits.cast(dtypes.uint16).bitcast(dtypes.half).neg().bitcast(dtypes.uint16).cast(dtypes.uint32)
|
||||
return bits
|
||||
def build_remapped_src(src: UOp, opsel_lo_bit: int, opsel_hi_bit: int, neg_lo_bit: int, neg_hi_bit: int) -> UOp:
|
||||
return get_half_bits(src, bool(opsel_lo_bit), bool(neg_lo_bit)) | (get_half_bits(src, bool(opsel_hi_bit), bool(neg_hi_bit)) << UOp.const(dtypes.uint32, 16))
|
||||
lo = get_half_bits(src, bool(opsel_lo_bit), bool(neg_lo_bit))
|
||||
hi = get_half_bits(src, bool(opsel_hi_bit), bool(neg_hi_bit))
|
||||
return lo | (hi << UOp.const(dtypes.uint32, 16))
|
||||
# DOT IU instructions use NEG bits for signed/unsigned selection, not fp16 negation
|
||||
is_dot_iu = 'DOT' in op_name and 'IU' in op_name
|
||||
n0, n1, n2, nh0, nh1, nh2 = (0, 0, 0, 0, 0, 0) if is_dot_iu else (neg & 1, neg & 2, neg & 4, neg_hi & 1, neg_hi & 2, neg_hi & 4)
|
||||
|
|
@ -908,11 +926,11 @@ def _compile_mem_op(inst: ir3.DS | ir3.FLAT | ir3.GLOBAL | ir3.SCRATCH | ir4.DS
|
|||
|
||||
# Extract register info - all dynamic for deduplication
|
||||
if is_lds:
|
||||
addr_reg = ctx.inst_field(type(inst).addr)
|
||||
vdata_reg = ctx.inst_field(type(inst).data0)
|
||||
addr_reg = ctx.inst_field(type(inst).addr) # type: ignore[union-attr]
|
||||
vdata_reg = ctx.inst_field(type(inst).data0) # type: ignore[union-attr]
|
||||
vdst_reg = ctx.inst_field(type(inst).vdst)
|
||||
offset0 = ctx.inst_field(type(inst).offset0)
|
||||
offset1 = ctx.inst_field(type(inst).offset1)
|
||||
offset0 = ctx.inst_field(type(inst).offset0) # type: ignore[union-attr]
|
||||
offset1 = ctx.inst_field(type(inst).offset1) # type: ignore[union-attr]
|
||||
offset = offset0 # DS uses offset0 as primary offset
|
||||
saddr_reg = None
|
||||
elif isinstance(inst, (ir4.VGLOBAL, ir4.VSCRATCH, ir4.VFLAT)): # RDNA4: vaddr, vsrc, ioffset
|
||||
|
|
@ -923,18 +941,18 @@ def _compile_mem_op(inst: ir3.DS | ir3.FLAT | ir3.GLOBAL | ir3.SCRATCH | ir4.DS
|
|||
offset0, offset1 = _c(0), _c(0)
|
||||
saddr_reg = ctx.inst_field(type(inst).saddr) if hasattr(type(inst), 'saddr') else None
|
||||
else: # RDNA3: addr, data, offset
|
||||
addr_reg = ctx.inst_field(type(inst).addr)
|
||||
vdata_reg = ctx.inst_field(type(inst).data)
|
||||
addr_reg = ctx.inst_field(type(inst).addr) # type: ignore[union-attr]
|
||||
vdata_reg = ctx.inst_field(type(inst).data) # type: ignore[union-attr]
|
||||
vdst_reg = ctx.inst_field(type(inst).vdst)
|
||||
offset = ctx.inst_field_signed(type(inst).offset)
|
||||
offset = ctx.inst_field_signed(type(inst).offset) # type: ignore[union-attr]
|
||||
offset0, offset1 = _c(0), _c(0)
|
||||
saddr_reg = ctx.inst_field(type(inst).saddr) if hasattr(type(inst), 'saddr') else None
|
||||
saddr_reg = ctx.inst_field(type(inst).saddr) if hasattr(type(inst), 'saddr') else None # type: ignore[union-attr]
|
||||
|
||||
# Data width from canonical_op_bits (32/64/96/128), default to 32 for untyped ops
|
||||
data_bits_mem = inst.canonical_op_bits.get('data', 32)
|
||||
is_atomic, glc = 'ATOMIC' in op_name, getattr(inst, 'glc', 0)
|
||||
has_data1 = is_lds and hasattr(inst, 'data1') and inst.data1 is not None
|
||||
data1_reg = ctx.inst_field(type(inst).data1) if is_lds else _c(0)
|
||||
data1_reg = ctx.inst_field(type(inst).data1) if is_lds else _c(0) # type: ignore[union-attr]
|
||||
|
||||
# DS_PERMUTE/DS_BPERMUTE: cross-lane VGPR access via pcode
|
||||
if is_lds and 'PERMUTE' in op_name:
|
||||
|
|
@ -958,7 +976,8 @@ def _compile_mem_op(inst: ir3.DS | ir3.FLAT | ir3.GLOBAL | ir3.SCRATCH | ir4.DS
|
|||
vaddr = ctx.rvgpr_dyn(addr_reg, lane).cast(dtypes.uint64)
|
||||
addr_offset = vaddr if sve == 1 else UOp.const(dtypes.uint64, 0)
|
||||
# Add saddr value only if use_saddr is true (saddr < 124)
|
||||
saddr_contrib = use_saddr.where(ctx.rsgpr_dyn(saddr_reg).cast(dtypes.uint64), UOp.const(dtypes.uint64, 0)) if saddr_reg is not None else UOp.const(dtypes.uint64, 0)
|
||||
saddr_contrib = use_saddr.where(ctx.rsgpr_dyn(saddr_reg).cast(dtypes.uint64), UOp.const(dtypes.uint64, 0)) \
|
||||
if saddr_reg is not None else UOp.const(dtypes.uint64, 0)
|
||||
return base + addr_offset + saddr_contrib + offset64
|
||||
# FLAT/GLOBAL: choose between SGPR base (saddr) or VGPR pair (addr) based on saddr validity
|
||||
saddr_base = _u64(ctx.rsgpr_dyn(saddr_reg), ctx.rsgpr_dyn(saddr_reg + _c(1))) if saddr_reg is not None else UOp.const(dtypes.uint64, 0)
|
||||
|
|
@ -1000,12 +1019,18 @@ def _compile_mem_op(inst: ir3.DS | ir3.FLAT | ir3.GLOBAL | ir3.SCRATCH | ir4.DS
|
|||
vaddr_lo = ctx.rvgpr_dyn(addr_reg, lane).cast(dtypes.uint64)
|
||||
vaddr_base = use_saddr.where(vaddr_lo + ioffset64, vaddr_full + ioffset64)
|
||||
if is_atomic:
|
||||
return {'ADDR': addr, 'DATA': _u64(ctx.rvgpr_dyn(vdata_reg, lane), ctx.rvgpr_dyn(vdata_reg + _c(1), lane)) if data_bits_mem == 64 else ctx.rvgpr_dyn(vdata_reg, lane),
|
||||
'_vmem': mem, '_active': active, 'laneId': lane, 'v_addr': vaddr_base, 's_saddr': saddr_base}
|
||||
vdata = ctx.rvgpr_dyn(vdata_reg, lane).cast(dtypes.uint64) if 'STORE' in op_name else ctx.rvgpr_dyn(vdst_reg, lane) if 'D16' in op_name else UOp.const(dtypes.uint32, 0)
|
||||
if 'STORE' in op_name and data_bits_mem >= 64: vdata = vdata | (ctx.rvgpr_dyn(vdata_reg + _c(1), lane).cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32))
|
||||
srcs = {'ADDR': addr, 'VDATA': vdata, '_vmem': mem, '_active': active, 'laneId': lane, 'v_addr': vaddr_base, 's_saddr': saddr_base}
|
||||
for i in range(data_bits_mem // 32): srcs[f'VDATA{i}'] = ctx.rvgpr_dyn(vdata_reg + _c(i), lane) if 'STORE' in op_name else UOp.const(dtypes.uint32, 0)
|
||||
atomic_data = _u64(ctx.rvgpr_dyn(vdata_reg, lane), ctx.rvgpr_dyn(vdata_reg + _c(1), lane)) \
|
||||
if data_bits_mem == 64 else ctx.rvgpr_dyn(vdata_reg, lane)
|
||||
return {'ADDR': addr, 'DATA': atomic_data, '_vmem': mem, '_active': active,
|
||||
'laneId': lane, 'v_addr': vaddr_base, 's_saddr': saddr_base}
|
||||
vdata = ctx.rvgpr_dyn(vdata_reg, lane).cast(dtypes.uint64) if 'STORE' in op_name \
|
||||
else ctx.rvgpr_dyn(vdst_reg, lane) if 'D16' in op_name else UOp.const(dtypes.uint32, 0)
|
||||
if 'STORE' in op_name and data_bits_mem >= 64:
|
||||
vdata = vdata | (ctx.rvgpr_dyn(vdata_reg + _c(1), lane).cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32))
|
||||
srcs = {'ADDR': addr, 'VDATA': vdata, '_vmem': mem, '_active': active,
|
||||
'laneId': lane, 'v_addr': vaddr_base, 's_saddr': saddr_base}
|
||||
for i in range(data_bits_mem // 32):
|
||||
srcs[f'VDATA{i}'] = ctx.rvgpr_dyn(vdata_reg + _c(i), lane) if 'STORE' in op_name else UOp.const(dtypes.uint32, 0)
|
||||
return srcs
|
||||
|
||||
def make_stores(dest: str, val: UOp, lane: UOp, active: UOp, writes_return_data: bool) -> list[UOp]:
|
||||
|
|
@ -1199,7 +1224,9 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int,
|
|||
for enabled, gid in [(hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_X, gidx),
|
||||
(hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Y, gidy),
|
||||
(hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Z, gidz)]:
|
||||
if rsrc2 & enabled: st._write_sgpr(sgpr_idx, gid); sgpr_idx += 1
|
||||
if rsrc2 & enabled:
|
||||
st._write_sgpr(sgpr_idx, gid)
|
||||
sgpr_idx += 1
|
||||
|
||||
# RDNA4 uses TTMP registers for workgroup IDs: ttmp[9]=gidx, ttmp[10]=gidy, ttmp[11]=gidz
|
||||
if arch == "rdna4":
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
# AMD ISA code generator - generates enum.py, ins.py, operands.py, str_pcode.py
|
||||
# Sources: XML from https://gpuopen.com/download/machine-readable-isa/latest/
|
||||
# PDF manuals from AMD documentation
|
||||
import re, zlib, xml.etree.ElementTree as ET, zipfile
|
||||
import re, zlib, xml.etree.ElementTree as ET, zipfile, pathlib
|
||||
from tinygrad.helpers import fetch
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
|
@ -77,8 +77,13 @@ def parse_xml(filename: str):
|
|||
for ot in root.findall(".//OperandTypes/OperandType"):
|
||||
ot_name = ot.findtext("OperandTypeName")
|
||||
for field in ot.findall(".//Field"):
|
||||
if (enum_name := op_enum_map.get((ot_name, field.findtext("FieldName")))):
|
||||
enums[enum_name] = {int(pv.findtext("Value")): pv.findtext("Name").upper() for pv in field.findall(".//PredefinedValue")}
|
||||
key = (ot_name, field.findtext("FieldName"))
|
||||
if (enum_name := op_enum_map.get(key)): # type: ignore[arg-type]
|
||||
def _pv_val(pv: ET.Element) -> tuple[int, str]:
  """Return ``(int(Value), Name.upper())`` read from the child elements of one PredefinedValue node.

  Both the ``Value`` and ``Name`` children must be present; a missing one trips the assert."""
  raw_value = pv.findtext("Value")
  raw_name = pv.findtext("Name")
  # findtext may return None; the assert narrows both to str before use
  assert raw_value is not None and raw_name is not None
  return int(raw_value), raw_name.upper()
|
||||
enums[enum_name] = dict(_pv_val(pv) for pv in field.findall(".//PredefinedValue"))
|
||||
# Extract DataFormats with BitCount
|
||||
for df in root.findall("ISA/DataFormats/DataFormat"):
|
||||
name, bits = df.findtext("DataFormatName"), df.findtext("BitCount")
|
||||
|
|
@ -86,17 +91,26 @@ def parse_xml(filename: str):
|
|||
# Extract encoding definitions
|
||||
for enc in root.findall("ISA/Encodings/Encoding"):
|
||||
name = enc.findtext("EncodingName")
|
||||
assert name is not None
|
||||
is_base = name.startswith("ENC_") or name in ("VOP3_SDST_ENC", "VOPDXY")
|
||||
is_variant = any(sfx in name for sfx in _ENC_SUFFIX_MAP)
|
||||
if not is_base and not is_variant: continue
|
||||
if any(s in name for s in _SKIP_ENCODINGS): continue
|
||||
fields = [(_norm_field(f.findtext("FieldName").lower()), int(f.find("BitLayout/Range").findtext("BitOffset") or 0) + int(f.find("BitLayout/Range").findtext("BitCount") or 0) - 1,
|
||||
int(f.find("BitLayout/Range").findtext("BitOffset") or 0))
|
||||
for f in enc.findall(".//MicrocodeFormat/BitMap/Field") if f.find("BitLayout/Range") is not None]
|
||||
ident = (enc.findall("EncodingIdentifiers/EncodingIdentifier") or [None])[0]
|
||||
fields: list[tuple[str, int, int]] = []
|
||||
for f in enc.findall(".//MicrocodeFormat/BitMap/Field"):
|
||||
br = f.find("BitLayout/Range")
|
||||
if br is None: continue
|
||||
fn = f.findtext("FieldName")
|
||||
assert fn is not None
|
||||
fields.append((_norm_field(fn.lower()),
|
||||
int(br.findtext("BitOffset") or 0) + int(br.findtext("BitCount") or 0) - 1, int(br.findtext("BitOffset") or 0)))
|
||||
ident_list = enc.findall("EncodingIdentifiers/EncodingIdentifier")
|
||||
ident = ident_list[0] if ident_list else None
|
||||
enc_field = next((f for f in fields if f[0] == "encoding"), None)
|
||||
# For multi-dword formats, encoding field may be in higher dword but identifier pattern is always in dword0; use % 32
|
||||
enc_bits = "".join(ident.text[len(ident.text)-1-b] for b in range(enc_field[1] % 32, (enc_field[2] % 32)-1, -1)) if ident is not None and enc_field else None
|
||||
# For multi-dword formats, encoding field may be in higher dword but identifier is always in dword0; use % 32
|
||||
enc_bits: str | None = None
|
||||
if ident is not None and ident.text is not None and enc_field:
|
||||
enc_bits = "".join(ident.text[len(ident.text)-1-b] for b in range(enc_field[1] % 32, (enc_field[2] % 32)-1, -1))
|
||||
base_name = _strip_enc(name)
|
||||
encodings[NAME_MAP.get(base_name, base_name)] = (fields, enc_bits)
|
||||
# Extract instruction opcodes and operand info
|
||||
|
|
@ -104,9 +118,12 @@ def parse_xml(filename: str):
|
|||
opcode_encs: dict[str, dict[int, set[str]]] = {} # {base_fmt: {opcode: {enc_names}}}
|
||||
for instr in root.findall("ISA/Instructions/Instruction"):
|
||||
name = instr.findtext("InstructionName")
|
||||
assert name is not None
|
||||
for enc in instr.findall("InstructionEncodings/InstructionEncoding"):
|
||||
if enc.findtext("EncodingCondition") != "default": continue
|
||||
base, opcode = _map_flat(_strip_enc(enc.findtext("EncodingName")), name), int(enc.findtext("Opcode") or 0)
|
||||
enc_enc_name = enc.findtext("EncodingName")
|
||||
assert enc_enc_name is not None
|
||||
base, opcode = _map_flat(_strip_enc(enc_enc_name), name), int(enc.findtext("Opcode") or 0)
|
||||
enc_name = NAME_MAP.get(base, base)
|
||||
# Encoding variants use the same Op enum as the base format
|
||||
base_enum = enc_name
|
||||
|
|
@ -120,8 +137,10 @@ def parse_xml(filename: str):
|
|||
elif base == "VGLOBAL": enums.setdefault("VFLAT", {})[opcode] = name
|
||||
enums.setdefault(base_enum, {})[opcode] = name
|
||||
# Extract operand info
|
||||
op_info = {op.findtext("FieldName").lower(): (op.findtext("DataFormatName"), int(op.findtext("OperandSize") or 0), op.findtext("OperandType"))
|
||||
for op in enc.findall("Operands/Operand") if op.findtext("FieldName")}
|
||||
op_info: dict[str, tuple[str | None, int, str | None]] = {}
|
||||
for op in enc.findall("Operands/Operand"):
|
||||
fn = op.findtext("FieldName")
|
||||
if fn: op_info[fn.lower()] = (op.findtext("DataFormatName"), int(op.findtext("OperandSize") or 0), op.findtext("OperandType"))
|
||||
for fmt, _, otype in op_info.values():
|
||||
if fmt and fmt not in fmts: fmts[fmt] = 0
|
||||
if otype: op_types_set.add(otype)
|
||||
|
|
@ -143,7 +162,9 @@ def extract_pdf_text(url: str) -> list[list[tuple[float, float, str, str]]]:
|
|||
data = fetch(url).read_bytes()
|
||||
# Parse xref table to locate objects
|
||||
xref: dict[int, int] = {}
|
||||
pos = int(re.search(rb'startxref\s+(\d+)', data).group(1)) + 4
|
||||
xref_match = re.search(rb'startxref\s+(\d+)', data)
|
||||
assert xref_match is not None
|
||||
pos = int(xref_match.group(1)) + 4
|
||||
while data[pos:pos+7] != b'trailer':
|
||||
while data[pos:pos+1] in b' \r\n': pos += 1
|
||||
line_end = data.find(b'\n', pos)
|
||||
|
|
@ -164,14 +185,19 @@ def extract_pdf_text(url: str) -> list[list[tuple[float, float, str, str]]]:
|
|||
if not (m := re.search(rb'/Contents (\d+) 0 R', data[xref[n]:xref[n]+500])): continue
|
||||
stream = get_stream(int(m.group(1))).decode('latin-1')
|
||||
elements, font = [], ''
|
||||
_RE_BT = (r'(/F[\d.]+) [\d.]+ Tf|([\d.+-]+) ([\d.+-]+) Td|[\d.+-]+ [\d.+-]+ [\d.+-]+ [\d.+-]+ ([\d.+-]+) ([\d.+-]+) Tm'
|
||||
r'|<([0-9A-Fa-f]+)>.*?Tj|\[([^\]]+)\] TJ')
|
||||
for bt in re.finditer(r'BT(.*?)ET', stream, re.S):
|
||||
x, y = 0.0, 0.0
|
||||
for m in re.finditer(r'(/F[\d.]+) [\d.]+ Tf|([\d.+-]+) ([\d.+-]+) Td|[\d.+-]+ [\d.+-]+ [\d.+-]+ [\d.+-]+ ([\d.+-]+) ([\d.+-]+) Tm|<([0-9A-Fa-f]+)>.*?Tj|\[([^\]]+)\] TJ', bt.group(1)):
|
||||
if m.group(1): font = m.group(1)
|
||||
elif m.group(2): x, y = x + float(m.group(2)), y + float(m.group(3))
|
||||
elif m.group(4): x, y = float(m.group(4)), float(m.group(5))
|
||||
elif m.group(6) and (t := bytes.fromhex(m.group(6)).decode('latin-1')).strip(): elements.append((x, y, t, font))
|
||||
elif m.group(7) and (t := ''.join(bytes.fromhex(h).decode('latin-1') for h in re.findall(r'<([0-9A-Fa-f]+)>', m.group(7)))).strip(): elements.append((x, y, t, font))
|
||||
for sm in re.finditer(_RE_BT, bt.group(1)):
|
||||
if sm.group(1): font = sm.group(1)
|
||||
elif sm.group(2): x, y = x + float(sm.group(2)), y + float(sm.group(3))
|
||||
elif sm.group(4): x, y = float(sm.group(4)), float(sm.group(5))
|
||||
elif sm.group(6) and (t := bytes.fromhex(sm.group(6)).decode('latin-1')).strip():
|
||||
elements.append((x, y, t, font))
|
||||
elif sm.group(7):
|
||||
t = ''.join(bytes.fromhex(h).decode('latin-1') for h in re.findall(r'<([0-9A-Fa-f]+)>', sm.group(7)))
|
||||
if t.strip(): elements.append((x, y, t, font))
|
||||
pages.append(sorted(elements, key=lambda e: (-e[1], e[0])))
|
||||
return pages
|
||||
|
||||
|
|
@ -197,7 +223,7 @@ def extract_pcode(pages: list[list[tuple[float, float, str, str]]], name_to_op:
|
|||
else:
|
||||
next_page, next_y = page_idx, 0
|
||||
# Collect F6 text from current position to next instruction (pseudocode is at x ≈ 69)
|
||||
lines = []
|
||||
lines: list[tuple[int, float, str]] = []
|
||||
for p in range(page_idx, next_page + 1):
|
||||
start_y = y if p == page_idx else 800
|
||||
end_y = next_y if p == next_page else 0
|
||||
|
|
@ -220,8 +246,8 @@ def extract_pcode(pages: list[list[tuple[float, float, str, str]]], name_to_op:
|
|||
# Code generation
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def write_common(all_fmts, all_op_types, path):
|
||||
lines = ["# autogenerated from AMD ISA XML - do not edit", "from enum import Enum, auto", ""]
|
||||
def write_common(all_fmts: dict[str, int], all_op_types: set[str], path: pathlib.Path) -> None:
|
||||
lines: list[str] = ["# autogenerated from AMD ISA XML - do not edit", "from enum import Enum, auto", ""]
|
||||
lines.append("class ReprEnum(Enum):")
|
||||
lines.append(' """Enum with clean repr that roundtrips with eval()."""')
|
||||
lines.append(' def __repr__(self): return f"{type(self).__name__}.{self.name}"')
|
||||
|
|
@ -238,7 +264,8 @@ def write_common(all_fmts, all_op_types, path):
|
|||
with open(path, "w") as f: f.write("\n".join(lines))
|
||||
|
||||
def write_enum(enums, path):
|
||||
lines = ["# autogenerated from AMD ISA XML - do not edit", "from extra.assembly.amd.autogen.common import ReprEnum, Fmt, FMT_BITS, OpType # noqa: F401", ""]
|
||||
lines: list[str] = ["# autogenerated from AMD ISA XML - do not edit",
|
||||
"from extra.assembly.amd.autogen.common import ReprEnum, Fmt, FMT_BITS, OpType # noqa: F401", ""]
|
||||
for name, ops in sorted(enums.items()):
|
||||
if not ops: continue
|
||||
suffix = "_E32" if name in ("VOP1", "VOP2", "VOPC") else "_E64" if name == "VOP3" else ""
|
||||
|
|
@ -286,7 +313,7 @@ def write_ins(encodings, enums, suffix_only_ops, types, arch, path):
|
|||
'dpp', 'fi', 'bc', 'row_mask', 'bank_mask', 'src0_neg', 'src0_abs', 'src1_neg', 'src1_abs',
|
||||
'cbsz', 'abid', 'acc_cd', 'acc', 'blgp', 'lane_sel_0', 'lane_sel_1', 'lane_sel_2', 'lane_sel_3',
|
||||
'lane_sel_4', 'lane_sel_5', 'lane_sel_6', 'lane_sel_7', 'dst_sel', 'dst_unused', 'src0_sel', 'src1_sel']
|
||||
sort_fields = lambda fields: sorted(fields, key=lambda f: (ORDER.index(f[0]) if f[0] in ORDER else 999, f[2]))
|
||||
def sort_fields(fields):
  """Return the field tuples sorted by where their name (item 0) sits in ORDER, then by item 2.

  Names that are not listed in ORDER get rank 999 and therefore sort after every listed name."""
  def rank(field):
    name = field[0]
    return (ORDER.index(name) if name in ORDER else 999, field[2])
  return sorted(fields, key=rank)
|
||||
|
||||
# Separate base encodings from variants
|
||||
base_encodings, variant_encodings = {}, {}
|
||||
|
|
@ -296,15 +323,29 @@ def write_ins(encodings, enums, suffix_only_ops, types, arch, path):
|
|||
else: variant_encodings[enc_name] = data
|
||||
|
||||
# Build sets of ops by their vdst type from operand metadata
|
||||
sdst_opcodes = {} # ops where vdst is OPR_SREG (writes to SGPR)
|
||||
sdst_opcodes: dict[str, set[int]] = {} # ops where vdst is OPR_SREG (writes to SGPR)
|
||||
for fmt, ops in enums.items():
|
||||
for op, name in ops.items():
|
||||
op_types = types.get((name, fmt), {})
|
||||
vdst_type = op_types.get("vdst", (None, None, None))[2]
|
||||
if vdst_type == "OPR_SREG": sdst_opcodes.setdefault(fmt, set()).add(op)
|
||||
|
||||
lines = ["# autogenerated from AMD ISA XML - do not edit", "# ruff: noqa: F401,F403",
|
||||
"from extra.assembly.amd.dsl import *", f"from extra.assembly.amd.autogen.{arch}.enum import *", "import functools", ""]
|
||||
# collect only the XxxOp enums that are actually referenced in this arch's instruction definitions
|
||||
enum_names = sorted(f"{k}Op" for k in enums if enums[k] and k not in ("HWREG", "MSG"))
|
||||
# also re-export HWREG/MSG enums (plain enums, not instruction format ops)
|
||||
enum_names += sorted(k for k in enums if k in ("HWREG", "MSG") and enums[k])
|
||||
# collect DSL field types actually used by scanning generated field definitions
|
||||
all_field_defs = " ".join(field_def(fn, hi, lo, enc, eb) for enc, (flds, eb) in encodings.items() for fn, hi, lo in flds)
|
||||
_ALL_DSL = ["BitField", "EnumBitField", "FixedBitField", "NULL", "SBaseField", "SGPRField", "SRsrcField",
|
||||
"SSrcField", "SrcField", "VDSTYField", "VGPRField"]
|
||||
dsl_names = ["Inst"] + [n for n in _ALL_DSL if n in all_field_defs]
|
||||
# also re-export register names so `from ins import *` still provides them to downstream users
|
||||
_DSL_REGS = ["s", "v", "src", "VCC_LO", "VCC_HI", "VCC", "EXEC_LO", "EXEC_HI", "EXEC", "NULL", "OFF", "M0",
|
||||
"SCC", "VCCZ", "EXECZ", "ttmp", "INV_2PI", "SDWA", "DPP", "DPP16", "LIT", "SRC_LDS_DIRECT"]
|
||||
dsl_reexport = sorted(set(dsl_names + _DSL_REGS))
|
||||
lines: list[str] = ["# autogenerated from AMD ISA XML - do not edit", "# ruff: noqa: E501,F401",
|
||||
f"from extra.assembly.amd.dsl import {', '.join(dsl_reexport)}",
|
||||
f"from extra.assembly.amd.autogen.{arch}.enum import {', '.join(enum_names)}", "import functools", ""]
|
||||
|
||||
def fmt_allowed(op_enum: str, ops: set[int]) -> str:
|
||||
"""Format allowed ops as {EnumName.MEMBER, ...}."""
|
||||
|
|
@ -323,7 +364,9 @@ def write_ins(encodings, enums, suffix_only_ops, types, arch, path):
|
|||
has_seg_field = any(fn == "seg" for fn, _, _ in fields)
|
||||
if enc_name in ("FLAT", "VFLAT") and has_seg_field:
|
||||
prefix = "V" if enc_name == "VFLAT" else ""
|
||||
for cls, seg, op_enum in [(f"{prefix}FLAT", 0, f"{prefix}FLATOp"), (f"{prefix}GLOBAL", 2, f"{prefix}GLOBALOp"), (f"{prefix}SCRATCH", 1, f"{prefix}SCRATCHOp")]:
|
||||
flat_variants = [(f"{prefix}FLAT", 0, f"{prefix}FLATOp"), (f"{prefix}GLOBAL", 2, f"{prefix}GLOBALOp"),
|
||||
(f"{prefix}SCRATCH", 1, f"{prefix}SCRATCHOp")]
|
||||
for cls, seg, op_enum in flat_variants:
|
||||
cls_ops = set(enums.get(cls, {}).keys())
|
||||
lines.append(f"class {cls}(Inst):")
|
||||
for fn, hi, lo in sort_fields(fields):
|
||||
|
|
@ -396,6 +439,8 @@ def write_ins(encodings, enums, suffix_only_ops, types, arch, path):
|
|||
op_to_suffix = {op:suffix for suffix,ops in suffix_only_ops.items() for op in ops.get(fmt, set())}
|
||||
fmt_sdst_ops = sdst_opcodes.get(fmt, set())
|
||||
for op, name in sorted(ops.items()):
|
||||
# ADDTID ops are in both FLAT and GLOBAL enums (for pcode); only generate helper for GLOBAL/VGLOBAL
|
||||
if "ADDTID" in name and fmt in ("FLAT", "VFLAT"): continue
|
||||
msuf = suffix if fmt != "VOP3" or op < 512 else ""
|
||||
# Determine class: SDST variants, suffix-specific variants (e.g., _MFMA, _LIT), or base
|
||||
if fmt == "VOP1" and op in fmt_sdst_ops: cls = "VOP1_SDST"
|
||||
|
|
@ -405,11 +450,14 @@ def write_ins(encodings, enums, suffix_only_ops, types, arch, path):
|
|||
lines.append(f"{name.lower()}{msuf.lower()} = functools.partial({cls}, {fmt}Op.{name}{msuf})")
|
||||
with open(path, "w") as f: f.write("\n".join(lines))
|
||||
|
||||
def write_operands(types, enums, arch, path):
|
||||
def write_operands(types: dict, enums: dict, arch: str, path: pathlib.Path) -> None:
|
||||
valid = {(name, fmt) for fmt, ops in enums.items() for name in ops.values()}
|
||||
lines = ["# autogenerated from AMD ISA XML - do not edit",
|
||||
"from extra.assembly.amd.autogen.common import Fmt, OpType",
|
||||
f"from extra.assembly.amd.autogen.{arch}.enum import *", ""]
|
||||
# only import enums that are actually used as keys in OPERANDS
|
||||
used_bases = {eb for (nm, eb) in types if (nm, eb) in valid}
|
||||
enum_names = sorted(f"{k}Op" for k in used_bases)
|
||||
lines: list[str] = ["# autogenerated from AMD ISA XML - do not edit",
|
||||
"from extra.assembly.amd.autogen.common import Fmt, OpType",
|
||||
f"from extra.assembly.amd.autogen.{arch}.enum import {', '.join(enum_names)}", ""]
|
||||
lines.append("# instruction operand info: {Op: {field: (Fmt, size_bits, OpType)}}")
|
||||
lines.append("OPERANDS = {")
|
||||
def fmt_val(v):
|
||||
|
|
@ -422,7 +470,7 @@ def write_operands(types, enums, arch, path):
|
|||
lines.append("}")
|
||||
with open(path, "w") as f: f.write("\n".join(lines))
|
||||
|
||||
def write_pcode(pcode: dict[tuple[str, int], str], enums: dict[str, dict[int, str]], arch: str, path: str):
|
||||
def write_pcode(pcode: dict[tuple[str, int], str], enums: dict[str, dict[int, str]], arch: str, path: pathlib.Path) -> None:
|
||||
"""Write str_pcode.py file from extracted pseudocode."""
|
||||
entries: list[tuple[str, str, int, str]] = []
|
||||
for fmt_name, ops in enums.items():
|
||||
|
|
@ -444,8 +492,9 @@ def write_pcode(pcode: dict[tuple[str, int], str], enums: dict[str, dict[int, st
|
|||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pathlib
|
||||
all_fmts, all_op_types, arch_data = {}, set(), {}
|
||||
all_fmts: dict[str, int] = {}
|
||||
all_op_types: set[str] = set()
|
||||
arch_data: dict[str, dict] = {}
|
||||
# First pass: parse XML for all architectures
|
||||
for arch, cfg in ARCHS.items():
|
||||
print(f"Parsing XML: {cfg['xml']} -> {arch}")
|
||||
|
|
|
|||
|
|
@ -52,7 +52,9 @@ def _val_to_bits(val):
|
|||
if val.dtype == dtypes.float64: return val.bitcast(dtypes.uint64)
|
||||
return val if val.dtype == dtypes.uint32 else val.cast(dtypes.uint32)
|
||||
|
||||
def _floor(x): t = UOp(Ops.TRUNC, x.dtype, (x,)); return ((x < _const(x.dtype, 0)) & x.ne(t)).where(t - _const(x.dtype, 1), t)
|
||||
def _floor(x):
  """Floor of x built from TRUNC: trunc(x) - 1 when x is negative and differs from trunc(x), otherwise trunc(x)."""
  truncated = UOp(Ops.TRUNC, x.dtype, (x,))
  is_negative = x < _const(x.dtype, 0)
  has_fraction = x.ne(truncated)
  # truncation rounds toward zero, so only negative non-integral inputs need the extra step down
  return (is_negative & has_fraction).where(truncated - _const(x.dtype, 1), truncated)
|
||||
def _f16_extract(v):
  """Reinterpret the low 16 bits of a uint32 value as a half; values of any other dtype are returned unchanged."""
  if v.dtype == dtypes.uint32:
    low_bits = v & _u32(0xFFFF)
    return low_bits.cast(dtypes.uint16).bitcast(dtypes.half)
  return v
|
||||
|
||||
def _check_nan(v: UOp, quiet: bool) -> UOp:
|
||||
|
|
@ -118,7 +120,8 @@ def _f_to_u(f, dt): return UOp(Ops.TRUNC, f.dtype, ((f < _const(f.dtype, 0.0)).w
|
|||
|
||||
def _cvt_quiet(val: UOp) -> UOp:
  """OR the `qb` mask reported by _float_info(val) into val's bit pattern and return it as the same float type.

  NOTE(review): `qb` is presumably the quiet-NaN bit for val's format -- confirm against _float_info."""
  _, _, _, qb, _ = _float_info(val)
  # choose the unsigned integer type of matching width for the bit-level OR; anything
  # that is not float64 or half is handled as float32
  if val.dtype == dtypes.float64: bt, ft = dtypes.uint64, dtypes.float64
  elif val.dtype == dtypes.half: bt, ft = dtypes.uint16, dtypes.half
  else: bt, ft = dtypes.uint32, dtypes.float32
  return (val.bitcast(bt) | qb).bitcast(ft)
|
||||
|
||||
def _is_denorm(val: UOp) -> UOp:
|
||||
|
|
@ -163,14 +166,18 @@ def _ldexp(val: UOp, exp: UOp) -> UOp:
|
|||
def _frexp_mant(val: UOp) -> UOp:
  """Return val with its exponent field overwritten by a fixed pattern, keeping the sign and mantissa bits.

  uint32 / uint64 inputs are first reinterpreted as float32 / float64. For float32 the bits are masked
  with 0x807FFFFF and OR-ed with 0x3f000000; every other dtype takes the 64-bit path
  (mask 0x800FFFFFFFFFFFFF, OR 0x3fe0000000000000)."""
  if val.dtype == dtypes.uint32: val = val.bitcast(dtypes.float32)
  elif val.dtype == dtypes.uint64: val = val.bitcast(dtypes.float64)
  if val.dtype == dtypes.float32:
    return ((val.bitcast(dtypes.uint32) & _u32(0x807FFFFF)) | _u32(0x3f000000)).bitcast(dtypes.float32)
  return ((val.bitcast(dtypes.uint64) & _const(dtypes.uint64, 0x800FFFFFFFFFFFFF)) |
          _const(dtypes.uint64, 0x3fe0000000000000)).bitcast(dtypes.float64)
|
||||
|
||||
def _frexp_exp(val: UOp) -> UOp:
  """Return the exponent field of val as an int, offset by a fixed constant.

  uint32 / uint64 inputs are first reinterpreted as float32 / float64. float32 yields
  ((bits >> 23) & 0xFF) - 126; every other dtype takes the 64-bit path, ((bits >> 52) & 0x7FF) - 1022."""
  if val.dtype == dtypes.uint32: val = val.bitcast(dtypes.float32)
  elif val.dtype == dtypes.uint64: val = val.bitcast(dtypes.float64)
  if val.dtype == dtypes.float32:
    exp_field = (val.bitcast(dtypes.uint32) >> _u32(23)) & _u32(0xFF)
    return exp_field.cast(dtypes.int) - _const(dtypes.int, 126)
  exp_field = (val.bitcast(dtypes.uint64) >> _const(dtypes.uint64, 52)) & _const(dtypes.uint64, 0x7FF)
  return exp_field.cast(dtypes.int) - _const(dtypes.int, 1022)
|
||||
|
||||
TWO_OVER_PI = 0x0145f306dc9c882a53f84eafa3ea69bb81b6c52b3278872083fca2c757bd778ac36e48dc74849ba5c00c925dd413a32439fc3bd63962534e7dd1046bea5d768909d338e04d68befc827323ac7306a673e93908bf177bf250763ff12fffbc0b301fde5e2316b414da3eda6cfd9e4f96136e9e8c7ecd3cbfd45aea4f758fd7cbe2f67a0e73ef14a525d4d7f6bf623f1aba10ac06608df8f6
|
||||
TWO_OVER_PI = int(
|
||||
"0145f306dc9c882a53f84eafa3ea69bb81b6c52b3278872083fca2c757bd778ac36e48dc74849ba5c00c925dd413a32439fc3bd"
|
||||
"63962534e7dd1046bea5d768909d338e04d68befc827323ac7306a673e93908bf177bf250763ff12fffbc0b301fde5e2316b414"
|
||||
"da3eda6cfd9e4f96136e9e8c7ecd3cbfd45aea4f758fd7cbe2f67a0e73ef14a525d4d7f6bf623f1aba10ac06608df8f6", 16)
|
||||
# TWO_OVER_PI as 19 u64 words for trig_preop_result (word[0] = bits 0-63, word[18] = bits 1152-1200)
|
||||
_PREOP_WORDS = tuple((TWO_OVER_PI >> (64 * i)) & 0xFFFFFFFFFFFFFFFF for i in range(19))
|
||||
def _trig_preop(val: UOp) -> UOp:
|
||||
|
|
@ -247,10 +254,14 @@ _FUNCS: dict[str, Callable[..., UOp]] = {
|
|||
# Normalization conversions: map [-1,1] or [0,1] to integer range
|
||||
# Use floor(x + 0.5) for round-to-nearest
|
||||
# SNORM: round(value * 32767), range is [-32767, 32767] (hardware behavior)
|
||||
'f16_to_snorm': lambda a: _floor(_f16_extract(a).cast(dtypes.float32) * _const(dtypes.float32, 32767) + _const(dtypes.float32, 0.5)).cast(dtypes.int).cast(dtypes.int16),
|
||||
'f16_to_unorm': lambda a: _floor(_f16_extract(a).cast(dtypes.float32) * _const(dtypes.float32, 65535) + _const(dtypes.float32, 0.5)).cast(dtypes.uint16),
|
||||
'f32_to_snorm': lambda a: _floor(a.bitcast(dtypes.float32) * _const(dtypes.float32, 32767) + _const(dtypes.float32, 0.5)).cast(dtypes.int).cast(dtypes.int16),
|
||||
'f32_to_unorm': lambda a: _floor(a.bitcast(dtypes.float32) * _const(dtypes.float32, 65535) + _const(dtypes.float32, 0.5)).cast(dtypes.uint16),
|
||||
'f16_to_snorm': lambda a: _floor(
|
||||
_f16_extract(a).cast(dtypes.float32) * _const(dtypes.float32, 32767) + _const(dtypes.float32, 0.5)).cast(dtypes.int).cast(dtypes.int16),
|
||||
'f16_to_unorm': lambda a: _floor(
|
||||
_f16_extract(a).cast(dtypes.float32) * _const(dtypes.float32, 65535) + _const(dtypes.float32, 0.5)).cast(dtypes.uint16),
|
||||
'f32_to_snorm': lambda a: _floor(
|
||||
a.bitcast(dtypes.float32) * _const(dtypes.float32, 32767) + _const(dtypes.float32, 0.5)).cast(dtypes.int).cast(dtypes.int16),
|
||||
'f32_to_unorm': lambda a: _floor(
|
||||
a.bitcast(dtypes.float32) * _const(dtypes.float32, 65535) + _const(dtypes.float32, 0.5)).cast(dtypes.uint16),
|
||||
'f32_to_u8': lambda a: _f_to_u(a.bitcast(dtypes.float32), dtypes.uint8),
|
||||
# Integer truncation conversions
|
||||
'i32_to_i16': lambda a: a.cast(dtypes.int).cast(dtypes.int16),
|
||||
|
|
@ -310,21 +321,35 @@ _SINGLE_CHAR = {'(': 'LPAREN', ')': 'RPAREN', '[': 'LBRACKET', ']': 'RBRACKET',
|
|||
|
||||
class Token:
  """One lexical token: `type` holds the category string and `val` the matched source text."""
  __slots__ = ('type', 'val')
  def __init__(self, kind: str, val: str):
    # the parameter is called `kind` so the builtin `type` is not shadowed; the attribute keeps the name `type`
    self.type, self.val = kind, val
  def __repr__(self): return f'{self.type}:{self.val}'
|
||||
|
||||
def tokenize(s: str) -> list[Token]:
|
||||
tokens, i, n = [], 0, len(s)
|
||||
while i < n:
|
||||
c = s[i]
|
||||
if c.isspace(): i += 1; continue
|
||||
if c.isspace():
|
||||
i += 1
|
||||
continue
|
||||
if i + 1 < n and s[i:i+2] in ('+=', '-='):
|
||||
tokens.append(Token('ASSIGN_OP', s[i:i+2])); i += 2; continue
|
||||
tokens.append(Token('ASSIGN_OP', s[i:i+2]))
|
||||
i += 2
|
||||
continue
|
||||
if i + 1 < n and s[i:i+2] in ('||', '&&', '>=', '<=', '==', '!=', '<>', '>>', '<<', '**', '+:', '-:'):
|
||||
tokens.append(Token('OP', s[i:i+2])); i += 2; continue
|
||||
if c in '|^&><+-*/~!%': tokens.append(Token('OP', c)); i += 1; continue
|
||||
if (t := _SINGLE_CHAR.get(c)): tokens.append(Token(t, c)); i += 1; continue
|
||||
if c == ';': i += 1; continue
|
||||
tokens.append(Token('OP', s[i:i+2]))
|
||||
i += 2
|
||||
continue
|
||||
if c in '|^&><+-*/~!%':
|
||||
tokens.append(Token('OP', c))
|
||||
i += 1
|
||||
continue
|
||||
if (t := _SINGLE_CHAR.get(c)):
|
||||
tokens.append(Token(t, c))
|
||||
i += 1
|
||||
continue
|
||||
if c == ';':
|
||||
i += 1
|
||||
continue
|
||||
if c.isdigit() or (c == '-' and i + 1 < n and s[i+1].isdigit()):
|
||||
start = i
|
||||
if c == '-': i += 1
|
||||
|
|
@ -337,31 +362,38 @@ def tokenize(s: str) -> list[Token]:
|
|||
i += 1
|
||||
while i < n and s[i].isdigit(): i += 1
|
||||
for sfx in ('ULL', 'LL', 'UL', 'U', 'L', 'F', 'f'):
|
||||
if s[i:i+len(sfx)] == sfx: i += len(sfx); break
|
||||
tokens.append(Token('NUM', s[start:i])); continue
|
||||
if s[i:i+len(sfx)] == sfx:
|
||||
i += len(sfx)
|
||||
break
|
||||
tokens.append(Token('NUM', s[start:i]))
|
||||
continue
|
||||
if c.isalpha() or c == '_':
|
||||
start = i
|
||||
while i < n and (s[i].isalnum() or s[i] == '_'): i += 1
|
||||
tokens.append(Token('IDENT', s[start:i])); continue
|
||||
tokens.append(Token('IDENT', s[start:i]))
|
||||
continue
|
||||
raise RuntimeError(f"unexpected char '{c}' at pos {i} in: {s}")
|
||||
tokens.append(Token('EOF', ''))
|
||||
return tokens
|
||||
|
||||
class Parser:
|
||||
def __init__(self, tokens: list[Token], vars: dict, funcs: dict | None = None):
|
||||
self.tokens, self.vars, self.funcs, self.pos = tokens, vars, funcs if funcs is not None else _FUNCS, 0
|
||||
def __init__(self, tokens: list[Token], env: dict, funcs: dict | None = None):
  """Store the token list, the variable environment (as `self.vars`) and the function table; start at position 0.

  When `funcs` is None the module-level _FUNCS table is used."""
  self.tokens = tokens
  self.vars = env
  self.funcs = _FUNCS if funcs is None else funcs
  self.pos = 0
|
||||
|
||||
def peek(self, offset=0) -> Token:
  """Return the token `offset` positions past the cursor without consuming it, clamped to the last token."""
  last_index = len(self.tokens) - 1
  wanted = self.pos + offset
  return self.tokens[min(wanted, last_index)]
|
||||
def at(self, *types) -> bool:
  """Return True when the current token's type is one of `types`."""
  current = self.peek()
  return current.type in types
|
||||
def _advance(self) -> Token: tok = self.tokens[self.pos]; self.pos += 1; return tok
|
||||
def eat(self, type: str) -> Token:
|
||||
if self.peek().type != type: raise RuntimeError(f"expected {type}, got {self.peek()}")
|
||||
def _advance(self) -> Token:
  """Consume the token under the cursor: return it and move the position forward by one."""
  index = self.pos
  current = self.tokens[index]  # indexing first means a failed lookup leaves the position untouched
  self.pos = index + 1
  return current
|
||||
def eat(self, kind: str) -> Token:
  """Consume and return the current token, which must have type `kind`.

  Raises RuntimeError when the current token has a different type."""
  if self.peek().type == kind: return self._advance()
  raise RuntimeError(f"expected {kind}, got {self.peek()}")
|
||||
def try_eat(self, type: str) -> Token | None: return self._advance() if self.peek().type == type else None
|
||||
def try_eat_val(self, val: str, type: str) -> Token | None:
|
||||
return self._advance() if self.peek().type == type and self.peek().val == val else None
|
||||
def eat_val(self, val: str, type: str) -> Token:
|
||||
if self.peek().type != type or self.peek().val != val: raise RuntimeError(f"expected {type}:{val}, got {self.peek()}")
|
||||
def try_eat(self, kind: str) -> Token | None:
  """Consume and return the current token if its type is `kind`; otherwise return None and consume nothing."""
  if self.peek().type == kind: return self._advance()
  return None
|
||||
def try_eat_val(self, val: str, kind: str) -> Token | None:
  """Consume and return the current token only if both its type is `kind` and its text is `val`; else None."""
  current = self.peek()
  if current.type == kind and current.val == val: return self._advance()
  return None
|
||||
def eat_val(self, val: str, kind: str) -> Token:
  """Consume and return the current token, which must have type `kind` and text `val`.

  Raises RuntimeError when either the type or the text does not match."""
  current = self.peek()
  if current.type == kind and current.val == val: return self._advance()
  raise RuntimeError(f"expected {kind}:{val}, got {self.peek()}")
|
||||
|
||||
def parse(self) -> UOp:
|
||||
|
|
@ -381,8 +413,10 @@ class Parser:
|
|||
case '&&' | '&': return left & right
|
||||
case '^': return left ^ right
|
||||
case '==' | '<>': return left.eq(right) if op == '==' else left.ne(right)
|
||||
case '!=' : return left.ne(right)
|
||||
case '>=' | '<=' | '>' | '<': return self._cmp_nan(left, right, {'>=':(lambda a,b:a>=b),'<=':(lambda a,b:a<=b),'>':(lambda a,b:a>b),'<':(lambda a,b:a<b)}[op])
|
||||
case '!=': return left.ne(right)
|
||||
case '>=' | '<=' | '>' | '<':
|
||||
ops = {'>=':(lambda a,b:a>=b),'<=':(lambda a,b:a<=b),'>':(lambda a,b:a>b),'<':(lambda a,b:a<b)}
|
||||
return self._cmp_nan(left, right, ops[op])
|
||||
case '>>' | '<<': return (left >> right) if op == '>>' else (left << right)
|
||||
case '+' | '-':
|
||||
if op == '-' and left.op == Ops.CONST and right.op == Ops.CONST: return _const(left.dtype, left.arg - right.arg)
|
||||
|
|
@ -529,7 +563,8 @@ class Parser:
|
|||
if dt is None: return base
|
||||
if dt == base.dtype: return base
|
||||
if dt.itemsize == 2 and base.dtype.itemsize == 4:
|
||||
return (base & _const(base.dtype, 0xFFFF)).cast(dtypes.uint16) if dt == dtypes.uint16 else (base & _const(base.dtype, 0xFFFF)).cast(dtypes.uint16).bitcast(dt)
|
||||
if dt == dtypes.uint16: return (base & _const(base.dtype, 0xFFFF)).cast(dtypes.uint16)
|
||||
return (base & _const(base.dtype, 0xFFFF)).cast(dtypes.uint16).bitcast(dt)
|
||||
if field == 'i4': return _signext_4bit(base)
|
||||
return _cast_to(base, dt)
|
||||
|
||||
|
|
@ -539,7 +574,7 @@ class Parser:
|
|||
|
||||
def _handle_bracket_rest(self, first: UOp, base: UOp, var_name: str | None = None) -> UOp:
|
||||
if self.at('OP') and self.peek().val in ('+:', '-:'):
|
||||
op = self.eat('OP').val
|
||||
self.eat('OP')
|
||||
width = self.parse()
|
||||
self.eat('RBRACKET')
|
||||
if width.op == Ops.CONST:
|
||||
|
|
@ -625,7 +660,8 @@ class Parser:
|
|||
inner = self.parse()
|
||||
self.eat('RPAREN')
|
||||
dt = {('U',32): dtypes.uint32, ('U',64): dtypes.uint64, ('I',32): dtypes.int, ('I',64): dtypes.int64,
|
||||
('F',16): dtypes.half, ('F',32): dtypes.float32, ('F',64): dtypes.float64, ('B',32): dtypes.uint32, ('B',64): dtypes.uint64}.get((type_char, bits), dtypes.uint64 if bits > 32 else dtypes.uint32)
|
||||
('F',16): dtypes.half, ('F',32): dtypes.float32, ('F',64): dtypes.float64,
|
||||
('B',32): dtypes.uint32, ('B',64): dtypes.uint64}.get((type_char, bits), dtypes.uint64 if bits > 32 else dtypes.uint32)
|
||||
if type_char == 'F' and inner.dtype in (dtypes.uint32, dtypes.uint64, dtypes.ulong, dtypes.int, dtypes.int64):
|
||||
if inner.dtype.itemsize != dt.itemsize: inner = inner.cast(dtypes.uint32 if dt.itemsize == 4 else dtypes.uint64)
|
||||
return inner.bitcast(dt)
|
||||
|
|
@ -686,7 +722,7 @@ class Parser:
|
|||
def _call_func(self, name: str, args: list[UOp]) -> UOp:
|
||||
if name in self.vars and isinstance(self.vars[name], tuple) and self.vars[name][0] == 'lambda':
|
||||
_, params, body = self.vars[name]
|
||||
lv = {**self.vars, **{p: a for p, a in zip(params, args)}}
|
||||
lv = {**self.vars, **dict(zip(params, args))}
|
||||
if ';' in body or '\n' in body or 'return' in body.lower():
|
||||
lines = [l.strip() for l in body.replace(';', '\n').split('\n') if l.strip() and not l.strip().startswith('//')]
|
||||
_, _, result = parse_block(lines, 0, lv, self.funcs)
|
||||
|
|
@ -712,7 +748,9 @@ class Parser:
|
|||
elif dt in (dtypes.uint8, dtypes.int8):
|
||||
val = mem.index(idx, *gate, ptr=True).load().cast(dt)
|
||||
elif dt in (dtypes.uint16, dtypes.int16, dtypes.short):
|
||||
val = (mem.index(idx, *gate, ptr=True).load().cast(dtypes.uint32) | (mem.index(idx + _const(dtypes.int, 1), *gate, ptr=True).load().cast(dtypes.uint32) << _u32(8))).cast(dt)
|
||||
lo = mem.index(idx, *gate, ptr=True).load().cast(dtypes.uint32)
|
||||
hi = mem.index(idx + _const(dtypes.int, 1), *gate, ptr=True).load().cast(dtypes.uint32)
|
||||
val = (lo | (hi << _u32(8))).cast(dt)
|
||||
else:
|
||||
val = _u32(0)
|
||||
for i in range(4): val = val | (mem.index(idx + _const(dtypes.int, i), *gate, ptr=True).load().cast(dtypes.uint32) << _u32(i * 8))
|
||||
|
|
@ -723,7 +761,8 @@ class Parser:
|
|||
idx2 = ((addr + _const(adt, 4)) >> _const(adt, 2)).cast(dtypes.int)
|
||||
val = val.cast(dtypes.uint64) | (mem.index(idx2, *gate).cast(dtypes.uint64) << _u64(32))
|
||||
elif dt in (dtypes.uint8, dtypes.int8): val = (val >> ((addr & _const(adt, 3)).cast(dtypes.uint32) * _u32(8))) & _u32(0xFF)
|
||||
elif dt in (dtypes.uint16, dtypes.int16): val = (val >> (((addr >> _const(adt, 1)) & _const(adt, 1)).cast(dtypes.uint32) * _u32(16))) & _u32(0xFFFF)
|
||||
elif dt in (dtypes.uint16, dtypes.int16):
|
||||
val = (val >> (((addr >> _const(adt, 1)) & _const(adt, 1)).cast(dtypes.uint32) * _u32(16))) & _u32(0xFFFF)
|
||||
return val
|
||||
|
||||
def _coerce_cmp(self, l: UOp, r: UOp) -> tuple[UOp, UOp]:
|
||||
|
|
@ -756,8 +795,8 @@ def _match_bracket(toks: list[Token], start: int) -> tuple[int, list[Token]]:
|
|||
return j, [t for t in toks[start+1:j-1] if t.type != 'EOF']
|
||||
|
||||
def _tok_str(toks: list[Token]) -> str:
  """Join the text of every non-EOF token with single spaces."""
  parts = []
  for tok in toks:
    if tok.type == 'EOF': continue
    parts.append(tok.val)
  return ' '.join(parts)
|
||||
def parse_tokens(toks: list[Token], vars: dict[str, VarVal], funcs: dict | None = None) -> UOp:
|
||||
return Parser(toks, vars, funcs).parse()
|
||||
def parse_tokens(toks: list[Token], env: dict[str, VarVal], funcs: dict | None = None) -> UOp:
|
||||
return Parser(toks, env, funcs).parse()
|
||||
|
||||
# Unified block parser for pcode
|
||||
def _subst_loop_var(line: str, loop_var: str, val: int) -> str:
|
||||
|
|
@ -781,7 +820,7 @@ def _find_paren_end(s: str, start: int = 0, open_ch: str = '(', close_ch: str =
|
|||
if depth == 0: return j
|
||||
return len(s)
|
||||
|
||||
def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: dict | None = None,
|
||||
def parse_block(lines: list[str], start: int, env: dict[str, VarVal], funcs: dict | None = None,
|
||||
assigns: list | None = None) -> tuple[int, dict[str, VarVal], UOp | None]:
|
||||
"""Parse a block of pcode. Returns (next_line, block_assigns, return_value).
|
||||
If assigns list is provided, side effects (MEM/VGPR writes) are appended to it."""
|
||||
|
|
@ -792,7 +831,9 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
|||
while i < len(lines):
|
||||
line = lines[i]
|
||||
toks = tokenize(line)
|
||||
if toks[0].type != 'IDENT' and toks[0].type != 'LBRACE': i += 1; continue
|
||||
if toks[0].type != 'IDENT' and toks[0].type != 'LBRACE':
|
||||
i += 1
|
||||
continue
|
||||
first = toks[0].val.lower() if toks[0].type == 'IDENT' else '{'
|
||||
|
||||
# Block terminators
|
||||
|
|
@ -801,17 +842,19 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
|||
# return expr (lambda bodies)
|
||||
if first == 'return':
|
||||
rest = line[line.lower().find('return') + 6:].strip()
|
||||
return i + 1, block_assigns, parse_expr(rest, vars, funcs)
|
||||
return i + 1, block_assigns, parse_expr(rest, env, funcs)
|
||||
|
||||
# for loop
|
||||
if first == 'for':
|
||||
# Parse: for VAR in [SIZE']START : [SIZE']END do
|
||||
p = Parser(toks, vars, funcs)
|
||||
p = Parser(toks, env, funcs)
|
||||
p.eat_val('for', 'IDENT')
|
||||
loop_var = p.eat('IDENT').val
|
||||
p.eat_val('in', 'IDENT')
|
||||
def parse_bound():
|
||||
if p.at('NUM') and p.peek(1).type == 'QUOTE': p.eat('NUM'); p.eat('QUOTE')
|
||||
if p.at('NUM') and p.peek(1).type == 'QUOTE':
|
||||
p.eat('NUM')
|
||||
p.eat('QUOTE')
|
||||
if p.at('NUM'): return int(p.eat('NUM').val.rstrip('UuLl'))
|
||||
expr = p.parse().simplify()
|
||||
assert expr.op == Ops.CONST, f"loop bound must be constant, got {expr}"
|
||||
|
|
@ -833,38 +876,41 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
|||
# Execute loop with break support
|
||||
has_break = any('break' in bl.lower() for bl in body_lines)
|
||||
found_var = f'_found_{id(body_lines)}' if has_break else None
|
||||
if found_var: vars[found_var] = block_assigns[found_var] = _const(dtypes.bool, False)
|
||||
if found_var: env[found_var] = block_assigns[found_var] = _const(dtypes.bool, False)
|
||||
for loop_i in range(start_val, end_val + 1):
|
||||
subst_lines = [_subst_loop_var(bl, loop_var, loop_i) for bl in body_lines if not (has_break and bl.strip().lower() == 'break')]
|
||||
_, iter_assigns, _ = parse_block(subst_lines, 0, {**vars, **block_assigns}, funcs, assigns)
|
||||
_, iter_assigns, _ = parse_block(subst_lines, 0, {**env, **block_assigns}, funcs, assigns)
|
||||
if has_break:
|
||||
assert found_var is not None
|
||||
found = block_assigns.get(found_var, vars.get(found_var))
|
||||
found = block_assigns.get(found_var, env.get(found_var))
|
||||
assert isinstance(found, UOp)
|
||||
not_found = found.eq(_const(dtypes.bool, False))
|
||||
for var, val in iter_assigns.items():
|
||||
if var != found_var and isinstance(val, UOp):
|
||||
old = block_assigns.get(var, vars.get(var, _u32(0)))
|
||||
old = block_assigns.get(var, env.get(var, _u32(0)))
|
||||
if isinstance(old, UOp):
|
||||
block_assigns[var] = vars[var] = not_found.where(val, old.cast(val.dtype) if val.dtype != old.dtype and val.dtype.itemsize == old.dtype.itemsize else old)
|
||||
block_assigns[var] = env[var] = not_found.where(
|
||||
val, old.cast(val.dtype) if val.dtype != old.dtype and val.dtype.itemsize == old.dtype.itemsize else old)
|
||||
for j, bl in enumerate(body_lines):
|
||||
bl_l = bl.strip().lower()
|
||||
if bl_l.startswith('if ') and bl_l.endswith(' then'):
|
||||
if any(body_lines[k].strip().lower() == 'break' for k in range(j+1, len(body_lines))):
|
||||
cond_str = _subst_loop_var(bl.strip()[3:-5].strip(), loop_var, loop_i)
|
||||
cond = _to_bool(parse_expr(cond_str, vars, funcs))
|
||||
block_assigns[found_var] = vars[found_var] = not_found.where(cond, found)
|
||||
cond = _to_bool(parse_expr(cond_str, env, funcs))
|
||||
block_assigns[found_var] = env[found_var] = not_found.where(cond, found)
|
||||
break
|
||||
else:
|
||||
block_assigns.update(iter_assigns); vars.update(iter_assigns)
|
||||
block_assigns.update(iter_assigns)
|
||||
env.update(iter_assigns)
|
||||
continue
|
||||
|
||||
# declare
|
||||
if first == 'declare':
|
||||
# Initialize scalar declarations (skip arrays and vars already passed as srcs)
|
||||
# Initialize scalar declarations (skip arrays and env already passed as srcs)
|
||||
if '[' not in line and len(toks) >= 2 and toks[1].type == 'IDENT':
|
||||
vars.setdefault(toks[1].val, _u32(0))
|
||||
i += 1; continue
|
||||
env.setdefault(toks[1].val, _u32(0))
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# lambda definition
|
||||
if first != '{' and '=' in line and 'lambda' in line and any(t.type == 'IDENT' and t.val == 'lambda' for t in toks):
|
||||
|
|
@ -886,26 +932,30 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
|||
if ch == '(': depth += 1
|
||||
elif ch == ')':
|
||||
depth -= 1
|
||||
if depth == 0: body_lines_lst.append(lines[i][:j]); break
|
||||
if depth == 0:
|
||||
body_lines_lst.append(lines[i][:j])
|
||||
break
|
||||
else: body_lines_lst.append(lines[i])
|
||||
i += 1
|
||||
body = '\n'.join(body_lines_lst).strip()
|
||||
vars[name] = ('lambda', params, body)
|
||||
env[name] = ('lambda', params, body)
|
||||
continue
|
||||
|
||||
# MEM assignment: MEM[addr].type (+|-)?= value
|
||||
if first == 'mem' and toks[1].type == 'LBRACKET':
|
||||
j, addr_toks = _match_bracket(toks, 1)
|
||||
addr = parse_tokens(addr_toks, vars, funcs)
|
||||
addr = parse_tokens(addr_toks, env, funcs)
|
||||
if j < len(toks) and toks[j].type == 'DOT': j += 1
|
||||
dt_name = toks[j].val if j < len(toks) and toks[j].type == 'IDENT' else 'u32'
|
||||
dt, j = DTYPES.get(dt_name, dtypes.uint32), j + 1
|
||||
compound_op = None
|
||||
if j < len(toks) and toks[j].type == 'ASSIGN_OP': compound_op = toks[j].val; j += 1
|
||||
if j < len(toks) and toks[j].type == 'ASSIGN_OP':
|
||||
compound_op = toks[j].val
|
||||
j += 1
|
||||
elif j < len(toks) and toks[j].type == 'EQUALS': j += 1
|
||||
rhs = parse_tokens(toks[j:], vars, funcs)
|
||||
rhs = parse_tokens(toks[j:], env, funcs)
|
||||
if compound_op:
|
||||
mem = vars.get('_vmem') if '_vmem' in vars else vars.get('_lds')
|
||||
mem = env.get('_vmem') if '_vmem' in env else env.get('_lds')
|
||||
if isinstance(mem, UOp):
|
||||
adt = dtypes.uint64 if addr.dtype == dtypes.uint64 else dtypes.uint32
|
||||
idx = (addr >> _const(adt, 2)).cast(dtypes.int)
|
||||
|
|
@ -914,7 +964,8 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
|||
old = old.cast(dtypes.uint64) | (mem.index(((addr + _const(adt, 4)) >> _const(adt, 2)).cast(dtypes.int)).cast(dtypes.uint64) << _u64(32))
|
||||
rhs = (old + rhs) if compound_op == '+=' else (old - rhs)
|
||||
if assigns is not None: assigns.append((f'MEM[{_tok_str(addr_toks)}].{dt_name}', (addr, rhs)))
|
||||
i += 1; continue
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# VGPR assignment: VGPR[lane][reg] = value
|
||||
if first == 'vgpr' and toks[1].type == 'LBRACKET':
|
||||
|
|
@ -923,9 +974,12 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
|||
j, reg_toks = _match_bracket(toks, j)
|
||||
if j < len(toks) and toks[j].type == 'DOT': j += 2 # skip .type suffix
|
||||
if j < len(toks) and toks[j].type == 'EQUALS': j += 1
|
||||
ln, rg, val = parse_tokens(lane_toks, vars, funcs), parse_tokens(reg_toks, vars, funcs), parse_tokens(toks[j:], vars, funcs)
|
||||
if assigns is not None: assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}]', (_to_u32(rg) * _u32(32) + _to_u32(ln), val)))
|
||||
i += 1; continue
|
||||
ln = parse_tokens(lane_toks, env, funcs)
|
||||
rg, val = parse_tokens(reg_toks, env, funcs), parse_tokens(toks[j:], env, funcs)
|
||||
if assigns is not None:
|
||||
assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}]', (_to_u32(rg) * _u32(32) + _to_u32(ln), val)))
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Compound destination: {hi.type, lo.type} = value
|
||||
if first == '{':
|
||||
|
|
@ -939,18 +993,20 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
|||
j += 3
|
||||
if j < len(toks) and toks[j].type == 'RBRACE': j += 1
|
||||
if j < len(toks) and toks[j].type == 'EQUALS': j += 1
|
||||
val = parse_tokens(toks[j:], vars, funcs)
|
||||
val = parse_tokens(toks[j:], env, funcs)
|
||||
lo_dt, hi_dt = DTYPES.get(lo_type, dtypes.uint64), DTYPES.get(hi_type, dtypes.uint32)
|
||||
lo_bits = 64 if lo_dt in (dtypes.uint64, dtypes.int64) else 32
|
||||
lo_val = val.cast(lo_dt) if val.dtype.itemsize * 8 <= lo_bits else (val & _const(val.dtype, (1 << lo_bits) - 1)).cast(lo_dt)
|
||||
hi_val = (val >> _const(val.dtype, lo_bits)).cast(hi_dt)
|
||||
block_assigns[lo_var] = vars[lo_var] = lo_val
|
||||
block_assigns[hi_var] = vars[hi_var] = hi_val
|
||||
block_assigns[lo_var] = env[lo_var] = lo_val
|
||||
block_assigns[hi_var] = env[hi_var] = hi_val
|
||||
if assigns is not None: assigns.extend([(f'{lo_var}.{lo_type}', lo_val), (f'{hi_var}.{hi_type}', hi_val)])
|
||||
i += 1; continue
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Bit slice/index: var[hi:lo] = value, var.type[hi:lo] = value, or var[expr] = value
|
||||
if len(toks) >= 5 and toks[0].type == 'IDENT' and (toks[1].type == 'LBRACKET' or (toks[1].type == 'DOT' and toks[3].type == 'LBRACKET')):
|
||||
if len(toks) >= 5 and toks[0].type == 'IDENT' and \
|
||||
(toks[1].type == 'LBRACKET' or (toks[1].type == 'DOT' and toks[3].type == 'LBRACKET')):
|
||||
bracket_start = 2 if toks[1].type == 'LBRACKET' else 4
|
||||
j = bracket_start
|
||||
colon_pos = None
|
||||
|
|
@ -967,23 +1023,28 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
|||
j += 1
|
||||
if j < len(toks) and toks[j].type == 'DOT': j += 2
|
||||
if j < len(toks) and toks[j].type == 'EQUALS': j += 1
|
||||
val = parse_tokens(toks[j:], vars, funcs)
|
||||
val = parse_tokens(toks[j:], env, funcs)
|
||||
dt_suffix = toks[2].val if toks[1].type == 'DOT' else None
|
||||
if assigns is not None: assigns.append((f'{var}[{hi}:{lo}]' + (f'.{dt_suffix}' if dt_suffix else ''), val))
|
||||
if var not in vars: vars[var] = _const(dtypes.uint64 if hi >= 32 else dtypes.uint32, 0)
|
||||
old = block_assigns.get(var, vars.get(var))
|
||||
block_assigns[var] = vars[var] = _set_bits(old, _val_to_bits(val), hi - lo + 1, lo)
|
||||
i += 1; continue
|
||||
except: pass
|
||||
if var not in env: env[var] = _const(dtypes.uint64 if hi >= 32 else dtypes.uint32, 0)
|
||||
old = block_assigns.get(var, env.get(var))
|
||||
assert isinstance(old, UOp)
|
||||
block_assigns[var] = env[var] = _set_bits(old, _val_to_bits(val), hi - lo + 1, lo)
|
||||
i += 1
|
||||
continue
|
||||
except Exception: pass
|
||||
elif toks[1].type == 'LBRACKET': # bit index: var[expr] (only for var[...], not var.type[...])
|
||||
existing = block_assigns.get(var, vars.get(var))
|
||||
if existing is not None and isinstance(existing, UOp) and not any(f'{var}{k}' in vars or f'{var}{k}' in block_assigns for k in range(8)):
|
||||
existing = block_assigns.get(var, env.get(var))
|
||||
if existing is not None and isinstance(existing, UOp) and \
|
||||
not any(f'{var}{k}' in env or f'{var}{k}' in block_assigns for k in range(8)):
|
||||
bit_toks = toks[2:j]
|
||||
j += 1
|
||||
while j < len(toks) and toks[j].type != 'EQUALS': j += 1
|
||||
if j < len(toks):
|
||||
block_assigns[var] = vars[var] = _set_bit(existing, _to_u32(parse_tokens(bit_toks, vars, funcs)), parse_tokens(toks[j+1:], vars, funcs))
|
||||
i += 1; continue
|
||||
block_assigns[var] = env[var] = _set_bit(
|
||||
existing, _to_u32(parse_tokens(bit_toks, env, funcs)), parse_tokens(toks[j+1:], env, funcs))
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Array element: var[idx] = value (static index) or var[expr] = value (dynamic)
|
||||
if len(toks) >= 4 and toks[0].type == 'IDENT' and toks[1].type == 'LBRACKET':
|
||||
|
|
@ -993,80 +1054,90 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
|||
# Static index: var[NUM] = value
|
||||
if len(idx_toks) == 1 and idx_toks[0].type == 'NUM':
|
||||
idx = int(idx_toks[0].val.rstrip('UuLl'))
|
||||
val = parse_tokens(toks[j+1:], vars, funcs)
|
||||
existing = block_assigns.get(var, vars.get(var))
|
||||
val = parse_tokens(toks[j+1:], env, funcs)
|
||||
existing = block_assigns.get(var, env.get(var))
|
||||
if existing is not None and isinstance(existing, UOp):
|
||||
block_assigns[var] = vars[var] = _set_bit(existing, _u32(idx), val)
|
||||
block_assigns[var] = env[var] = _set_bit(existing, _u32(idx), val)
|
||||
else:
|
||||
block_assigns[f'{var}@{idx}'] = vars[f'{var}@{idx}'] = val
|
||||
i += 1; continue
|
||||
block_assigns[f'{var}@{idx}'] = env[f'{var}@{idx}'] = val
|
||||
i += 1
|
||||
continue
|
||||
# Dynamic index: var[expr] = value where var has @-elements
|
||||
elems = [(k.split('@')[1], v) for k, v in {**vars, **block_assigns}.items() if k.startswith(f'{var}@') and isinstance(v, UOp)]
|
||||
elems = [(k.split('@')[1], v) for k, v in {**env, **block_assigns}.items() if k.startswith(f'{var}@') and isinstance(v, UOp)]
|
||||
if elems:
|
||||
idx_expr = parse_tokens(idx_toks, vars, funcs)
|
||||
val = parse_tokens(toks[j+1:], vars, funcs)
|
||||
idx_expr = parse_tokens(idx_toks, env, funcs)
|
||||
val = parse_tokens(toks[j+1:], env, funcs)
|
||||
for elem_idx_str, old_elem in elems:
|
||||
elem_idx = int(elem_idx_str)
|
||||
cond = _to_u32(idx_expr).eq(_u32(elem_idx))
|
||||
new_val = cond.where(val.cast(old_elem.dtype) if val.dtype != old_elem.dtype else val, old_elem)
|
||||
block_assigns[f'{var}@{elem_idx}'] = vars[f'{var}@{elem_idx}'] = new_val
|
||||
i += 1; continue
|
||||
block_assigns[f'{var}@{elem_idx}'] = env[f'{var}@{elem_idx}'] = new_val
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Compound assignment: var += or var -=
|
||||
assign_op = next((j for j, t in enumerate(toks) if t.type == 'ASSIGN_OP'), None)
|
||||
if assign_op is not None:
|
||||
var = toks[0].val
|
||||
old = block_assigns.get(var, vars.get(var, _u32(0)))
|
||||
rhs = parse_tokens(toks[assign_op+1:], vars, funcs)
|
||||
old = block_assigns.get(var, env.get(var, _u32(0)))
|
||||
rhs = parse_tokens(toks[assign_op+1:], env, funcs)
|
||||
if rhs.dtype != old.dtype: rhs = rhs.cast(old.dtype)
|
||||
block_assigns[var] = vars[var] = (old + rhs) if toks[assign_op].val == '+=' else (old - rhs)
|
||||
i += 1; continue
|
||||
block_assigns[var] = env[var] = (old + rhs) if toks[assign_op].val == '+=' else (old - rhs)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Typed element: var.type[idx] = value
|
||||
if len(toks) >= 7 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET' and toks[4].type == 'NUM':
|
||||
if len(toks) >= 7 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and \
|
||||
toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET' and toks[4].type == 'NUM':
|
||||
var, dt_name, idx = toks[0].val, toks[2].val, int(toks[4].val)
|
||||
dt = DTYPES.get(dt_name, dtypes.uint32)
|
||||
j = 6
|
||||
while j < len(toks) and toks[j].type != 'EQUALS': j += 1
|
||||
if j < len(toks):
|
||||
val, old = parse_tokens(toks[j+1:], vars, funcs), block_assigns.get(var, vars.get(var, _u32(0)))
|
||||
val, old = parse_tokens(toks[j+1:], env, funcs), block_assigns.get(var, env.get(var, _u32(0)))
|
||||
bw = dt.itemsize * 8
|
||||
block_assigns[var] = vars[var] = _set_bits(old, val, bw, idx * bw)
|
||||
block_assigns[var] = env[var] = _set_bits(old, val, bw, idx * bw)
|
||||
if assigns is not None: assigns.append((f'{var}.{dt_name}[{idx}]', val))
|
||||
i += 1; continue
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Dynamic bit: var.type[expr_with_brackets] = value
|
||||
if len(toks) >= 5 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET':
|
||||
if len(toks) >= 5 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and \
|
||||
toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET':
|
||||
j, depth, has_inner = 4, 1, False
|
||||
while j < len(toks) and depth > 0:
|
||||
if toks[j].type == 'LBRACKET': depth += 1; has_inner = True
|
||||
if toks[j].type == 'LBRACKET':
|
||||
depth += 1
|
||||
has_inner = True
|
||||
elif toks[j].type == 'RBRACKET': depth -= 1
|
||||
j += 1
|
||||
if has_inner:
|
||||
var = toks[0].val
|
||||
bit_pos = _to_u32(parse_tokens(toks[4:j-1], vars, funcs))
|
||||
bit_pos = _to_u32(parse_tokens(toks[4:j-1], env, funcs))
|
||||
while j < len(toks) and toks[j].type != 'EQUALS': j += 1
|
||||
if j < len(toks):
|
||||
val = parse_tokens(toks[j+1:], vars, funcs)
|
||||
old = block_assigns.get(var, vars.get(var, _u32(0)))
|
||||
block_assigns[var] = vars[var] = _set_bit(old, bit_pos, val)
|
||||
i += 1; continue
|
||||
val = parse_tokens(toks[j+1:], env, funcs)
|
||||
old = block_assigns.get(var, env.get(var, _u32(0)))
|
||||
block_assigns[var] = env[var] = _set_bit(old, bit_pos, val)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# If/elsif/else - skip branches with statically false conditions (WAVE32/WAVE64)
|
||||
if first == 'if':
|
||||
def parse_cond(s, kw):
|
||||
ll = s.lower()
|
||||
return _to_bool(parse_expr(s[ll.find(kw) + len(kw):ll.rfind('then')].strip(), vars, funcs))
|
||||
return _to_bool(parse_expr(s[ll.find(kw) + len(kw):ll.rfind('then')].strip(), env, funcs))
|
||||
def is_const(c, v): return c.op == Ops.CONST and c.arg is v
|
||||
cond = parse_cond(line, 'if')
|
||||
conditions: list[tuple[UOp, UOp | dict[str, VarVal] | None]] = [(cond, None)] if not is_const(cond, False) else []
|
||||
else_branch: tuple[UOp | None, dict[str, VarVal]] = (None, {})
|
||||
vars_snap = dict(vars)
|
||||
env_snap = dict(env)
|
||||
static_true = is_const(cond, True) # track if any condition is statically true
|
||||
i += 1
|
||||
i, branch, ret = parse_block(lines, i, vars, funcs, assigns if not is_const(cond, False) else None)
|
||||
i, branch, ret = parse_block(lines, i, env, funcs, assigns if not is_const(cond, False) else None)
|
||||
if conditions: conditions[0] = (cond, ret if ret is not None else branch)
|
||||
vars.clear(); vars.update(vars_snap)
|
||||
env.clear()
|
||||
env.update(env_snap)
|
||||
while i < len(lines):
|
||||
ltoks = tokenize(lines[i])
|
||||
if ltoks[0].type != 'IDENT': break
|
||||
|
|
@ -1074,17 +1145,22 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
|||
if lf == 'elsif':
|
||||
c = parse_cond(lines[i], 'elsif')
|
||||
take = not static_true and not is_const(c, False)
|
||||
i += 1; i, branch, ret = parse_block(lines, i, vars, funcs, assigns if take else None)
|
||||
i += 1
|
||||
i, branch, ret = parse_block(lines, i, env, funcs, assigns if take else None)
|
||||
if take:
|
||||
conditions.append((c, ret if ret is not None else branch))
|
||||
if is_const(c, True): static_true = True
|
||||
vars.clear(); vars.update(vars_snap)
|
||||
env.clear()
|
||||
env.update(env_snap)
|
||||
elif lf == 'else':
|
||||
i += 1
|
||||
i, branch, ret = parse_block(lines, i, vars, funcs, assigns if not static_true else None)
|
||||
i, branch, ret = parse_block(lines, i, env, funcs, assigns if not static_true else None)
|
||||
if not static_true: else_branch = (ret, branch)
|
||||
vars.clear(); vars.update(vars_snap)
|
||||
elif lf == 'endif': i += 1; break
|
||||
env.clear()
|
||||
env.update(env_snap)
|
||||
elif lf == 'endif':
|
||||
i += 1
|
||||
break
|
||||
else: break
|
||||
# Check if any branch returned a value (lambda-style)
|
||||
if any(isinstance(br, UOp) for _, br in conditions):
|
||||
|
|
@ -1097,18 +1173,19 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
|||
# If statically true, use that branch directly; otherwise merge with WHERE
|
||||
if static_true:
|
||||
ba = next((b for c, b in conditions if is_const(c, True) and isinstance(b, dict)), {})
|
||||
block_assigns.update(ba); vars.update(ba)
|
||||
block_assigns.update(ba)
|
||||
env.update(ba)
|
||||
else:
|
||||
else_assigns = else_branch[1]
|
||||
all_vars = set().union(*[ba.keys() for _, ba in conditions if isinstance(ba, dict)], else_assigns.keys())
|
||||
for var in all_vars:
|
||||
res: Any = else_assigns.get(var, block_assigns.get(var, vars.get(var, _u32(0))))
|
||||
for cond, ba in reversed(conditions):
|
||||
res: Any = else_assigns.get(var, block_assigns.get(var, env.get(var, _u32(0))))
|
||||
for cond, ba in reversed(conditions): # type: ignore[assignment]
|
||||
if isinstance(ba, dict) and var in ba:
|
||||
tv = ba[var]
|
||||
if isinstance(tv, UOp) and isinstance(res, UOp):
|
||||
res = cond.where(tv, res.cast(tv.dtype) if tv.dtype != res.dtype and tv.dtype.itemsize == res.dtype.itemsize else res)
|
||||
block_assigns[var] = vars[var] = res
|
||||
block_assigns[var] = env[var] = res
|
||||
continue
|
||||
|
||||
# Regular assignment: var = value
|
||||
|
|
@ -1116,11 +1193,12 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
|||
if t.type == 'EQUALS':
|
||||
if any(toks[k].type == 'OP' and toks[k].val in ('<', '>', '!', '=') for k in range(j)): break
|
||||
base_var = toks[0].val
|
||||
block_assigns[base_var] = vars[base_var] = parse_tokens(toks[j+1:], vars, funcs)
|
||||
i += 1; break
|
||||
block_assigns[base_var] = env[base_var] = parse_tokens(toks[j+1:], env, funcs)
|
||||
i += 1
|
||||
break
|
||||
else: i += 1
|
||||
return i, block_assigns, None
|
||||
|
||||
def parse_expr(expr: str, vars: dict[str, VarVal], funcs: dict | None = None) -> UOp:
|
||||
return parse_tokens(tokenize(expr.strip().rstrip(';')), vars, funcs)
|
||||
def parse_expr(expr: str, env: dict[str, VarVal], funcs: dict | None = None) -> UOp:
|
||||
return parse_tokens(tokenize(expr.strip().rstrip(';')), env, funcs)
|
||||
|
||||
|
|
|
|||
|
|
@ -125,8 +125,8 @@ class PacketType:
|
|||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
cls._fields = {k: v for k, v in cls.__dict__.items() if isinstance(v, BitField)}
|
||||
cls._size_nibbles = ((max((f.hi for f in cls._fields.values()), default=0) + 4) // 4)
|
||||
cls._fields = {k: v for k, v in cls.__dict__.items() if isinstance(v, BitField)} # type: ignore[attr-defined]
|
||||
cls._size_nibbles = ((max((f.hi for f in cls._fields.values()), default=0) + 4) // 4) # type: ignore[attr-defined]
|
||||
|
||||
@classmethod
|
||||
def from_raw(cls, raw: int, time: int = 0):
|
||||
|
|
@ -135,7 +135,7 @@ class PacketType:
|
|||
return inst
|
||||
|
||||
def __repr__(self) -> str:
|
||||
fields_str = ", ".join(f"{k}={getattr(self, k)}" for k in self._fields if not k.startswith('_') and k != 'encoding')
|
||||
fields_str = ", ".join(f"{k}={getattr(self, k)}" for k in self._fields if not k.startswith('_') and k != 'encoding') # type: ignore[attr-defined]
|
||||
return f"{self.__class__.__name__}({fields_str})"
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
|
@ -514,7 +514,7 @@ def _build_decode_tables(packet_types: dict[int, type[PacketType]]) -> tuple[dic
|
|||
for opcode, pkt_cls in packet_types.items():
|
||||
delta_field = getattr(pkt_cls, 'delta', None)
|
||||
special = _special.get(pkt_cls, 0)
|
||||
decode_info[opcode] = (pkt_cls, pkt_cls._size_nibbles, delta_field.lo if delta_field else 0, delta_field.mask if delta_field else 0, special)
|
||||
decode_info[opcode] = (pkt_cls, pkt_cls._size_nibbles, delta_field.lo if delta_field else 0, delta_field.mask if delta_field else 0, special) # type: ignore[attr-defined]
|
||||
return decode_info, state_table
|
||||
|
||||
_DECODE_INFO_RDNA3, _STATE_TABLE_RDNA3 = _build_decode_tables(PACKET_TYPES_RDNA3)
|
||||
|
|
|
|||
|
|
@ -158,7 +158,7 @@ def get_tinygrad_kernel(op_name: str) -> tuple[bytes, tuple, tuple, list[int], d
|
|||
for i, buf in enumerate(lowered.bufs):
|
||||
if hasattr(buf, 'base') and buf.base is not None and hasattr(buf.base, '_buf'):
|
||||
try: buf_data[i] = bytes(buf.base._buf)
|
||||
except: pass
|
||||
except Exception: pass
|
||||
# Extract rsrc2 from ELF (same as ops_amd.py)
|
||||
group_segment_size = image[rodata_entry:rodata_entry+4].cast("I")[0]
|
||||
lds_size = ((group_segment_size + 511) // 512) & 0x1FF
|
||||
|
|
@ -232,7 +232,8 @@ def main():
|
|||
total_work = n_insts * n_workgroups * n_threads
|
||||
|
||||
print(f"{n_insts} insts ({n_compiled} unique) × {n_workgroups} WGs × {n_threads} threads = {total_work:,} ops")
|
||||
rust_time = benchmark_emulator("Rust", rust_remu.run_asm, kernel, global_size, local_size, args_ptr, rsrc2, args.iterations) if rust_remu else None
|
||||
rust_time = benchmark_emulator("Rust", rust_remu.run_asm, kernel, global_size, local_size,
|
||||
args_ptr, rsrc2, args.iterations) if rust_remu else None
|
||||
|
||||
if py_compile is not None:
|
||||
py_exec_rate = total_work / py_exec / 1e6
|
||||
|
|
|
|||
|
|
@ -1,14 +1,16 @@
|
|||
# RDNA3/RDNA4/CDNA disassembler
|
||||
from __future__ import annotations
|
||||
import re, struct
|
||||
import re
|
||||
from typing import Callable
|
||||
from extra.assembly.amd.dsl import Inst, Reg
|
||||
|
||||
# Special register mappings for disassembly
|
||||
SPECIAL_GPRS = {106: 'vcc_lo', 107: 'vcc_hi', 124: 'null', 125: 'm0', 126: 'exec_lo', 127: 'exec_hi',
|
||||
128: '0', 240: '0.5', 241: '-0.5', 242: '1.0', 243: '-1.0', 244: '2.0', 245: '-2.0', 246: '4.0', 247: '-4.0', 248: '0x3e22f983', 253: 'scc'}
|
||||
128: '0', 240: '0.5', 241: '-0.5', 242: '1.0', 243: '-1.0', 244: '2.0', 245: '-2.0',
|
||||
246: '4.0', 247: '-4.0', 248: '0x3e22f983', 253: 'scc'}
|
||||
SPECIAL_GPRS_CDNA = {106: 'vcc_lo', 107: 'vcc_hi', 124: 'm0', 126: 'exec_lo', 127: 'exec_hi',
|
||||
128: '0', 240: '0.5', 241: '-0.5', 242: '1.0', 243: '-1.0', 244: '2.0', 245: '-2.0', 246: '4.0', 247: '-4.0', 248: '0x3e22f983', 253: 'scc',
|
||||
128: '0', 240: '0.5', 241: '-0.5', 242: '1.0', 243: '-1.0', 244: '2.0', 245: '-2.0',
|
||||
246: '4.0', 247: '-4.0', 248: '0x3e22f983', 253: 'scc',
|
||||
102: 'flat_scratch_lo', 103: 'flat_scratch_hi', 104: 'xnack_mask_lo', 105: 'xnack_mask_hi',
|
||||
251: 'src_vccz', 252: 'src_execz'}
|
||||
SPECIAL_PAIRS = {106: 'vcc', 126: 'exec'}
|
||||
|
|
@ -70,7 +72,9 @@ def _num_srcs(inst) -> int:
|
|||
if any(x in n for x in ('FMA', 'MAD', 'CNDMASK', 'BFE', 'BFI', 'LERP', 'MED3', 'SAD', 'DIV_FMAS', 'DIV_FIXUP', 'DIV_SCALE', 'CUBE')): return 3
|
||||
# PERMLANE_VAR ops are 2-source, but PERMLANE (non-VAR) are 3-source
|
||||
if 'PERMLANE' in n and '_VAR' not in n: return 3
|
||||
if any(x in n for x in ('_ADD3', '_LSHL_ADD', '_ADD_LSHL', '_LSHL_OR', '_AND_OR', 'OR3_B32', 'AND_OR_B32', 'ALIGNBIT', 'ALIGNBYTE', 'V_PERM_', 'XOR3', 'XAD', 'MULLIT', 'MINMAX', 'MAXMIN', 'MINIMUMMAXIMUM', 'MAXIMUMMINIMUM', 'MINIMUM3', 'MAXIMUM3', 'MIN3', 'MAX3', 'DOT2', 'CVT_PK_U8_F32', 'DOT4', 'DOT8', 'WMMA', 'SWMMAC')): return 3
|
||||
if any(x in n for x in ('_ADD3', '_LSHL_ADD', '_ADD_LSHL', '_LSHL_OR', '_AND_OR', 'OR3_B32', 'AND_OR_B32', 'ALIGNBIT',
|
||||
'ALIGNBYTE', 'V_PERM_', 'XOR3', 'XAD', 'MULLIT', 'MINMAX', 'MAXMIN', 'MINIMUMMAXIMUM', 'MAXIMUMMINIMUM',
|
||||
'MINIMUM3', 'MAXIMUM3', 'MIN3', 'MAX3', 'DOT2', 'CVT_PK_U8_F32', 'DOT4', 'DOT8', 'WMMA', 'SWMMAC')): return 3
|
||||
return 2
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
|
@ -80,13 +84,14 @@ def _num_srcs(inst) -> int:
|
|||
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP1_SDST, VOP1_SDST_LIT, VOP1_LIT, VOP2, VOP2_LIT, VOP3, VOP3_SDST, VOP3_SDST_LIT,
|
||||
VOP3_LIT, VOP3SD, VOP3SD_LIT, VOP3P, VOP3P_LIT, VOPC, VOPC_LIT, VOPD, VOPD_LIT, VINTERP, SOP1, SOP1_LIT, SOP2, SOP2_LIT, SOPC, SOPC_LIT,
|
||||
SOPK, SOPK_LIT, SOPP, SMEM, DS, FLAT, GLOBAL, SCRATCH, VOP2Op, VOPDOp, SOPPOp, HWREG, MSG)
|
||||
from extra.assembly.amd.autogen.rdna4.ins import (VOP1 as R4_VOP1, VOP1_SDST as R4_VOP1_SDST, VOP1_SDST_LIT as R4_VOP1_SDST_LIT, VOP1_LIT as R4_VOP1_LIT,
|
||||
from extra.assembly.amd.autogen.rdna4.ins import (VOP1 as R4_VOP1, VOP1_SDST as R4_VOP1_SDST,
|
||||
VOP1_SDST_LIT as R4_VOP1_SDST_LIT, VOP1_LIT as R4_VOP1_LIT,
|
||||
VOP2 as R4_VOP2, VOP2_LIT as R4_VOP2_LIT, VOP3 as R4_VOP3, VOP3_SDST as R4_VOP3_SDST, VOP3_SDST_LIT as R4_VOP3_SDST_LIT, VOP3_LIT as R4_VOP3_LIT,
|
||||
VOP3SD as R4_VOP3SD, VOP3SD_LIT as R4_VOP3SD_LIT, VOP3P as R4_VOP3P, VOP3P_LIT as R4_VOP3P_LIT, VOPC as R4_VOPC, VOPC_LIT as R4_VOPC_LIT,
|
||||
VOPD as R4_VOPD, VOPD_LIT as R4_VOPD_LIT, VINTERP as R4_VINTERP, SOP1 as R4_SOP1, SOP1_LIT as R4_SOP1_LIT, SOP2 as R4_SOP2, SOP2_LIT as R4_SOP2_LIT,
|
||||
SOPC as R4_SOPC, SOPC_LIT as R4_SOPC_LIT, SOPK as R4_SOPK, SOPK_LIT as R4_SOPK_LIT, SOPP as R4_SOPP, SMEM as R4_SMEM, DS as R4_DS,
|
||||
VOPDOp as R4_VOPDOp, HWREG as HWREG_RDNA4, VFLAT as R4_FLAT, VGLOBAL as R4_GLOBAL, VSCRATCH as R4_SCRATCH)
|
||||
from extra.assembly.amd.autogen.cdna.ins import FLAT as C_FLAT, HWREG as HWREG_CDNA
|
||||
from extra.assembly.amd.autogen.cdna.ins import HWREG as HWREG_CDNA
|
||||
|
||||
def _is_cdna(inst: Inst) -> bool: return 'cdna' in inst.__class__.__module__
|
||||
def _is_r4(inst: Inst) -> bool: return 'rdna4' in inst.__class__.__module__
|
||||
|
|
@ -100,9 +105,15 @@ _CDNA_DISASM_ALIASES = {'v_fmac_f64': 'v_mul_legacy_f32', 'v_dot2c_f32_bf16': 'v
|
|||
|
||||
def _reg(p: str, b: int, n: int = 1) -> str: return f"{p}{_unwrap(b)}" if n == 1 else f"{p}[{_unwrap(b)}:{_unwrap(b)+n-1}]"
|
||||
def _sreg(b: int, n: int = 1) -> str: return _reg("s", _unwrap(b), n)
|
||||
def _vreg(b: int, n: int = 1) -> str: b = _unwrap(b); return _reg("v", b - 256 if b >= 256 else b, n)
|
||||
def _areg(b: int, n: int = 1) -> str: b = _unwrap(b); return _reg("a", b - 256 if b >= 256 else b, n) # accumulator registers for GFX90a
|
||||
def _ttmp(b, n: int = 1) -> str | None: b = _unwrap(b); return _reg("ttmp", b - 108, n) if 108 <= b <= 123 else None
|
||||
def _vreg(b: int, n: int = 1) -> str:
|
||||
b = _unwrap(b)
|
||||
return _reg("v", b - 256 if b >= 256 else b, n)
|
||||
def _areg(b: int, n: int = 1) -> str:
|
||||
b = _unwrap(b)
|
||||
return _reg("a", b - 256 if b >= 256 else b, n) # accumulator registers for GFX90a
|
||||
def _ttmp(b, n: int = 1) -> str | None:
|
||||
b = _unwrap(b)
|
||||
return _reg("ttmp", b - 108, n) if 108 <= b <= 123 else None
|
||||
|
||||
def _fmt_sdst(v, n: int = 1, cdna: bool = False) -> str:
|
||||
v = _unwrap(v)
|
||||
|
|
@ -130,7 +141,9 @@ def _fmt_v16(v, base: int = 256, hi_thresh: int = 384) -> str:
|
|||
|
||||
def _has(op: str, *subs) -> bool: return any(s in op for s in subs)
|
||||
def _omod(v: int) -> str: return {1: " mul:2", 2: " mul:4", 3: " div:2"}.get(v, "")
|
||||
def _src16(inst, v: int) -> str: v = _unwrap(v); return _fmt_v16(v) if v >= 256 else _lit(inst, v) # format 16-bit src: vgpr.h/l or literal
|
||||
def _src16(inst, v: int) -> str:
|
||||
v = _unwrap(v)
|
||||
return _fmt_v16(v) if v >= 256 else _lit(inst, v) # format 16-bit src: vgpr.h/l or literal
|
||||
def _mods(*pairs) -> str: return " ".join(m for c, m in pairs if c)
|
||||
def _fmt_bits(label: str, val: int, count: int) -> str: return f"{label}:[{','.join(str((val >> i) & 1) for i in range(count))}]"
|
||||
|
||||
|
|
@ -201,7 +214,8 @@ def _disasm_vop2(inst: VOP2) -> str:
|
|||
basename = name.replace('_e32', '')
|
||||
if cdna and basename in _VOP2_CARRY_OUT: return f"{name}{suf} {inst.vdst.fmt()}, {vcc}, {_lit(inst, inst.src0)}, {inst.vsrc1.fmt()}"
|
||||
if cdna and basename in _VOP2_CARRY_INOUT: return f"{name}{suf} {inst.vdst.fmt()}, {vcc}, {_lit(inst, inst.src0)}, {inst.vsrc1.fmt()}, {vcc}"
|
||||
if not cdna and basename in _VOP2_CARRY_INOUT_RDNA: return f"{name}{suf} {inst.vdst.fmt()}, {vcc}, {_lit(inst, inst.src0)}, {inst.vsrc1.fmt()}, {vcc}"
|
||||
if not cdna and basename in _VOP2_CARRY_INOUT_RDNA:
|
||||
return f"{name}{suf} {inst.vdst.fmt()}, {vcc}, {_lit(inst, inst.src0)}, {inst.vsrc1.fmt()}, {vcc}"
|
||||
sn0 = inst.canonical_op_regs.get('s0', 1)
|
||||
if inst.vdst.sz > 1 or sn0 > 1 or inst.vsrc1.sz > 1:
|
||||
src0 = _lit(inst, inst.src0) if inst.src0.offset == 255 else _fmt_src(inst.src0, sn0, cdna)
|
||||
|
|
@ -217,7 +231,10 @@ def _disasm_vopc(inst: VOPC) -> str:
|
|||
return f"{name} vcc, {s0}, {inst.vsrc1.fmt()}" # CDNA VOPC always outputs vcc
|
||||
# RDNA: v_cmpx_* writes to exec (no vcc), v_cmp_* writes to vcc_lo
|
||||
has_vcc = 'cmpx' not in name
|
||||
s0 = _lit(inst, inst.src0) if inst.src0.offset == 255 else inst.src0.fmt() if inst.src0.sz > 1 else _src16(inst, inst.src0.offset) if is16 else _lit(inst, inst.src0)
|
||||
if inst.src0.offset == 255: s0 = _lit(inst, inst.src0)
|
||||
elif inst.src0.sz > 1: s0 = inst.src0.fmt()
|
||||
elif is16: s0 = _src16(inst, inst.src0.offset)
|
||||
else: s0 = _lit(inst, inst.src0)
|
||||
s1 = inst.vsrc1.fmt() if inst.vsrc1.sz > 1 else _fmt_v16(inst.vsrc1) if is16 else inst.vsrc1.fmt()
|
||||
suf = "" if name.endswith('_e32') else "_e32"
|
||||
return f"{name}{suf} vcc_lo, {s0}, {s1}" if has_vcc else f"{name}{suf} {s0}, {s1}"
|
||||
|
|
@ -253,10 +270,11 @@ def _disasm_sopp(inst: SOPP) -> str:
|
|||
p = [f"vmcnt({vm})" if vm != 0x3f else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""]
|
||||
return f"s_waitcnt {' '.join(x for x in p if x) or '0'}"
|
||||
if name == 's_delay_alu':
|
||||
deps = ['VALU_DEP_1','VALU_DEP_2','VALU_DEP_3','VALU_DEP_4','TRANS32_DEP_1','TRANS32_DEP_2','TRANS32_DEP_3','FMA_ACCUM_CYCLE_1','SALU_CYCLE_1','SALU_CYCLE_2','SALU_CYCLE_3']
|
||||
deps = ['VALU_DEP_1','VALU_DEP_2','VALU_DEP_3','VALU_DEP_4','TRANS32_DEP_1','TRANS32_DEP_2',
|
||||
'TRANS32_DEP_3','FMA_ACCUM_CYCLE_1','SALU_CYCLE_1','SALU_CYCLE_2','SALU_CYCLE_3']
|
||||
skips = ['SAME','NEXT','SKIP_1','SKIP_2','SKIP_3','SKIP_4']
|
||||
id0, skip, id1 = inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x7, (inst.simm16 >> 7) & 0xf
|
||||
dep = lambda v: deps[v-1] if 0 < v <= len(deps) else str(v)
|
||||
def dep(v): return deps[v-1] if 0 < v <= len(deps) else str(v)
|
||||
p = [f"instid0({dep(id0)})" if id0 else "", f"instskip({skips[skip]})" if skip else "", f"instid1({dep(id1)})" if id1 else ""]
|
||||
return f"s_delay_alu {' | '.join(x for x in p if x) or '0'}"
|
||||
if name.startswith(('s_cbranch', 's_branch')): return f"{name} {inst.simm16}"
|
||||
|
|
@ -267,7 +285,7 @@ def _disasm_smem(inst: SMEM) -> str:
|
|||
if name in ('s_gl1_inv', 's_dcache_inv', 's_dcache_inv_vol', 's_dcache_wb', 's_dcache_wb_vol', 's_icache_inv'): return name
|
||||
soe, imm = getattr(inst, 'soe', 0) or getattr(inst, 'soffset_en', 0), getattr(inst, 'imm', 1)
|
||||
is_rdna4 = _is_r4(inst)
|
||||
offset = inst.ioffset if is_rdna4 else getattr(inst, 'offset', 0)
|
||||
offset = inst.ioffset if is_rdna4 else getattr(inst, 'offset', 0) # type: ignore[attr-defined]
|
||||
if cdna:
|
||||
if soe and imm: off_s = f"{decode_src(inst.soffset, cdna)} offset:0x{offset:x}"
|
||||
elif imm: off_s = f"0x{offset:x}"
|
||||
|
|
@ -278,7 +296,9 @@ def _disasm_smem(inst: SMEM) -> str:
|
|||
else: off_s = decode_src(inst.soffset, cdna)
|
||||
is_buffer = 'buffer' in name or 's_atc_probe_buffer' == name
|
||||
sbase_idx, sbase_count = _unwrap(inst.sbase), 4 if is_buffer else 2
|
||||
sbase_str = _fmt_src(sbase_idx, sbase_count, cdna) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count)
|
||||
if sbase_count == 2: sbase_str = _fmt_src(sbase_idx, sbase_count, cdna)
|
||||
elif sbase_idx <= 105: sbase_str = _sreg(sbase_idx, sbase_count)
|
||||
else: sbase_str = _reg("ttmp", sbase_idx - 108, sbase_count)
|
||||
if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {_unwrap(inst.sdata)}, {sbase_str}, {off_s}"
|
||||
if 'prefetch' in name:
|
||||
off = getattr(inst, 'ioffset', getattr(inst, 'offset', 0))
|
||||
|
|
@ -312,7 +332,7 @@ def _disasm_flat(inst: FLAT) -> str:
|
|||
else: seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat'
|
||||
instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}"
|
||||
# Global/scratch uses 13-bit signed offset
|
||||
offset = inst.ioffset if r4 else inst.offset
|
||||
offset = inst.ioffset if r4 else inst.offset # type: ignore[attr-defined]
|
||||
if seg != 'flat':
|
||||
if cdna:
|
||||
# CDNA: bit 12 is sign bit but not in offset field
|
||||
|
|
@ -327,19 +347,20 @@ def _disasm_flat(inst: FLAT) -> str:
|
|||
regs = inst.canonical_op_regs
|
||||
w = regs.get('data', regs.get('d', 1)) if 'store' in name or 'atomic' in name else regs.get('d', 1)
|
||||
off_s = f" offset:{off_val}" if off_val else ""
|
||||
if cdna: mods = f"{off_s}{' sc0' if inst.sc0 else ''}{' nt' if inst.nt else ''}{' sc1' if getattr(inst, 'sc1', 0) else ''}"
|
||||
elif r4: mods = f"{off_s}{' scope' if inst.scope else ''}{' th' if inst.th else ''}"
|
||||
if cdna: mods = f"{off_s}{' sc0' if inst.sc0 else ''}{' nt' if inst.nt else ''}{' sc1' if getattr(inst, 'sc1', 0) else ''}" # type: ignore[attr-defined]
|
||||
elif r4: mods = f"{off_s}{' scope' if inst.scope else ''}{' th' if inst.th else ''}" # type: ignore[attr-defined]
|
||||
else: mods = f"{off_s}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
|
||||
if seg == 'flat': saddr_s = ""
|
||||
elif _unwrap(inst.saddr) in (0x7F, 124): saddr_s = ", off"
|
||||
elif seg == 'scratch': saddr_s = f", {decode_src(inst.saddr, cdna)}"
|
||||
elif _unwrap(inst.saddr) in (SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS): saddr_s = f", {(SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS)[_unwrap(inst.saddr)]}"
|
||||
elif _unwrap(inst.saddr) in (SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS):
|
||||
saddr_s = f", {(SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS)[_unwrap(inst.saddr)]}"
|
||||
elif t := _ttmp(inst.saddr, 2): saddr_s = f", {t}"
|
||||
else: saddr_s = f", {_sreg(inst.saddr, 2) if _unwrap(inst.saddr) < 106 else decode_src(_unwrap(inst.saddr), cdna)}"
|
||||
if 'addtid' in name: return f"{instr} {reg_fn(inst.data if 'store' in name else inst.vdst)}{saddr_s}{mods}"
|
||||
# RDNA4: vaddr instead of addr, vsrc instead of data
|
||||
addr = inst.vaddr if r4 else inst.addr
|
||||
data = inst.vsrc if r4 else inst.data
|
||||
# RDNA4: vaddr instead of addr, vsrc instead of data
|
||||
addr = inst.vaddr if r4 else inst.addr # type: ignore[attr-defined]
|
||||
data = inst.vsrc if r4 else inst.data # type: ignore[attr-defined]
|
||||
# load_lds_* instructions: vaddr, saddr (no vdst, data goes to LDS)
|
||||
if 'load_lds' in name:
|
||||
addr_w = 1 if seg == 'scratch' or (_unwrap(inst.saddr) not in (0x7F, 124)) else 2
|
||||
|
|
@ -351,13 +372,14 @@ def _disasm_flat(inst: FLAT) -> str:
|
|||
addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(addr, addr_w)
|
||||
data_s, vdst_s = reg_fn(data, w), reg_fn(inst.vdst, w // 2 if 'cmpswap' in name else w)
|
||||
if 'atomic' in name:
|
||||
glc_or_sc0 = inst.sc0 if cdna else inst.glc
|
||||
return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if glc_or_sc0 else f"{instr} {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}"
|
||||
glc_or_sc0 = inst.sc0 if cdna else inst.glc # type: ignore[attr-defined]
|
||||
sfx = f"{saddr_s if seg != 'flat' else ''}{mods}"
|
||||
return f"{instr} {vdst_s}, {addr_s}, {data_s}{sfx}" if glc_or_sc0 else f"{instr} {addr_s}, {data_s}{sfx}"
|
||||
if 'store' in name: return f"{instr} {addr_s}, {data_s}{saddr_s}{mods}"
|
||||
return f"{instr} {reg_fn(inst.vdst, w)}, {addr_s}{saddr_s}{mods}"
|
||||
|
||||
def _disasm_ds(inst: DS) -> str:
|
||||
op, name = inst.op, inst.op_name.lower()
|
||||
name = inst.op_name.lower()
|
||||
acc = getattr(inst, 'acc', 0)
|
||||
reg_fn = _areg if acc else _vreg
|
||||
gds = " gds" if getattr(inst, 'gds', 0) else ""
|
||||
|
|
@ -386,7 +408,8 @@ def _disasm_ds(inst: DS) -> str:
|
|||
if 'write2' in name: return f"{name} {addr}, {d0}, {d1}{off2}{gds}"
|
||||
if 'read2' in name: return f"{name} {reg_fn(inst.vdst, regs.get('d', 1))}, {addr}{off2}{gds}"
|
||||
if 'xchg2' in name: return f"{name} {reg_fn(inst.vdst, regs.get('d', 1))}, {addr}, {d0}, {d1}{off2}{gds}"
|
||||
if 'load' in name or ('read' in name and 'read2' not in name): return f"{name} {reg_fn(inst.vdst)}{off}{gds}" if 'addtid' in name else f"{name} {dst}, {addr}{off}{gds}"
|
||||
if 'load' in name or ('read' in name and 'read2' not in name):
|
||||
return f"{name} {reg_fn(inst.vdst)}{off}{gds}" if 'addtid' in name else f"{name} {dst}, {addr}{off}{gds}"
|
||||
if ('store' in name or 'write' in name) and not _has(name, 'cmp', 'xchg', 'write2'):
|
||||
return f"{name} {reg_fn(inst.data0)}{off}{gds}" if 'addtid' in name else f"{name} {addr}, {d0}{off}{gds}"
|
||||
if 'swizzle' in name or name == 'ds_ordered_count': return f"{name} {reg_fn(inst.vdst)}, {addr}{off}{gds}"
|
||||
|
|
@ -397,13 +420,15 @@ def _disasm_ds(inst: DS) -> str:
|
|||
return f"{name} {dst}, {addr}, {d0}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}{off}{gds}"
|
||||
|
||||
def _disasm_vop3(inst: VOP3) -> str:
|
||||
op, name = inst.op, inst.op_name.lower()
|
||||
n_up = name.upper()
|
||||
name = inst.op_name.lower()
|
||||
bits = inst.canonical_op_bits
|
||||
|
||||
# RDNA4 v_s_* scalar VOP3 instructions - vdst is SGPR (VGPRField adds 256)
|
||||
if name.startswith('v_s_'):
|
||||
src = _lit(inst, inst.src0) if _unwrap(inst.src0) == 255 else ("src_scc" if _unwrap(inst.src0) == 253 else _fmt_src(inst.src0, max(1, bits['s0'] // 32)))
|
||||
s0v = _unwrap(inst.src0)
|
||||
if s0v == 255: src = _lit(inst, inst.src0)
|
||||
elif s0v == 253: src = "src_scc"
|
||||
else: src = _fmt_src(inst.src0, max(1, bits['s0'] // 32))
|
||||
if inst.neg & 1: src = f"-{src}"
|
||||
if inst.abs & 1: src = f"|{src}|"
|
||||
clamp = getattr(inst, 'cm', None) or getattr(inst, 'clmp', 0)
|
||||
|
|
@ -412,7 +437,6 @@ def _disasm_vop3(inst: VOP3) -> str:
|
|||
|
||||
# Use get_field_bits for register sizes and 16-bit detection
|
||||
r0, r1, r2 = max(1, bits['s0'] // 32), max(1, bits['s1'] // 32), max(1, bits['s2'] // 32)
|
||||
dn = max(1, bits['d'] // 32)
|
||||
is16_d, is16_s, is16_s2 = bits['d'] == 16, bits['s0'] == 16, bits['s2'] == 16
|
||||
|
||||
s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, r0, is16_s)
|
||||
|
|
@ -428,7 +452,8 @@ def _disasm_vop3(inst: VOP3) -> str:
|
|||
|
||||
clamp = getattr(inst, 'cm', None) or getattr(inst, 'clmp', 0)
|
||||
cl, om = " clamp" if clamp else "", _omod(inst.omod)
|
||||
nonvgpr_opsel = (inst.src0.offset < 256 and (inst.opsel & 1)) or (inst.src1.offset < 256 and (inst.opsel & 2)) or (inst.src2.offset < 256 and (inst.opsel & 4))
|
||||
nonvgpr_opsel = ((inst.src0.offset < 256 and (inst.opsel & 1)) or (inst.src1.offset < 256 and (inst.opsel & 2))
|
||||
or (inst.src2.offset < 256 and (inst.opsel & 4)))
|
||||
need_opsel = nonvgpr_opsel or (inst.opsel and not is16_s)
|
||||
|
||||
op_val = inst.op.value if hasattr(inst.op, 'value') else inst.op
|
||||
|
|
@ -478,7 +503,7 @@ def _disasm_vopd(inst: VOPD) -> str:
|
|||
|
||||
def _disasm_vop3p(inst: VOP3P) -> str:
|
||||
name = inst.op_name.lower()
|
||||
is_wmma, is_swmmac, n, is_fma_mix = 'wmma' in name, 'swmmac' in name, inst.num_srcs() or 2, 'fma_mix' in name
|
||||
is_swmmac, n, is_fma_mix = 'swmmac' in name, inst.num_srcs() or 2, 'fma_mix' in name
|
||||
def get_src(reg):
|
||||
return _lit(inst, reg.offset) if reg.offset == 255 else reg.fmt()
|
||||
src0, src1, src2, dst = get_src(inst.src0), get_src(inst.src1), get_src(inst.src2), inst.vdst.fmt()
|
||||
|
|
@ -487,18 +512,22 @@ def _disasm_vop3p(inst: VOP3P) -> str:
|
|||
if is_fma_mix:
|
||||
def m(s, neg, abs_): return f"-{f'|{s}|' if abs_ else s}" if neg else (f"|{s}|" if abs_ else s)
|
||||
src0, src1, src2 = m(src0, inst.neg & 1, inst.neg_hi & 1), m(src1, inst.neg & 2, inst.neg_hi & 2), m(src2, inst.neg & 4, inst.neg_hi & 4)
|
||||
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi else []) + (["clamp"] if clamp else [])
|
||||
mods = (([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else [])
|
||||
+ ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi else []) + (["clamp"] if clamp else []))
|
||||
elif is_swmmac:
|
||||
mods = ([f"index_key:{inst.opsel}"] if inst.opsel else []) + ([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + \
|
||||
([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if clamp else [])
|
||||
else:
|
||||
opsel_hi_default = 7 if n == 3 else 3
|
||||
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != opsel_hi_default else []) + \
|
||||
([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if clamp else [])
|
||||
return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}"
|
||||
mods = (([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else [])
|
||||
+ ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != opsel_hi_default else [])
|
||||
+ ([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else [])
|
||||
+ ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if clamp else []))
|
||||
mod_s = ' ' + ' '.join(mods) if mods else ''
|
||||
return f"{name} {dst}, {src0}, {src1}, {src2}{mod_s}" if n == 3 else f"{name} {dst}, {src0}, {src1}{mod_s}"
|
||||
|
||||
def _disasm_sop1(inst: SOP1) -> str:
|
||||
op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst)
|
||||
name, cdna = inst.op_name.lower(), _is_cdna(inst)
|
||||
# Use get_field_bits for register sizes
|
||||
regs = inst.canonical_op_regs
|
||||
dst_regs, src_regs = regs.get('d', 1), regs.get('s0', 1)
|
||||
|
|
@ -512,8 +541,8 @@ def _disasm_sop1(inst: SOP1) -> str:
|
|||
try: msg_str = MSG(v).name if v != 255 else None # MSG_RTN_ILLEGAL_MSG (255) not supported by LLVM
|
||||
except ValueError: msg_str = None
|
||||
return f"{name} {_fmt_sdst(inst.sdst, dst_regs)}, sendmsg({msg_str})" if msg_str else f"{name} {_fmt_sdst(inst.sdst, dst_regs)}, 0x{v:x}"
|
||||
sop1_src_only = ('S_ALLOC_VGPR', 'S_SLEEP_VAR', 'S_BARRIER_SIGNAL', 'S_BARRIER_SIGNAL_ISFIRST', 'S_BARRIER_INIT', 'S_BARRIER_JOIN', 'S_SET_GPR_IDX_IDX',
|
||||
'S_CBRANCH_JOIN')
|
||||
sop1_src_only = ('S_ALLOC_VGPR', 'S_SLEEP_VAR', 'S_BARRIER_SIGNAL', 'S_BARRIER_SIGNAL_ISFIRST',
|
||||
'S_BARRIER_INIT', 'S_BARRIER_JOIN', 'S_SET_GPR_IDX_IDX', 'S_CBRANCH_JOIN')
|
||||
if inst.op_name in sop1_src_only: return f"{name} {src}"
|
||||
if cdna:
|
||||
if 'getpc_b64' in name: return f"{name} {_fmt_sdst(inst.sdst, 2, cdna)}"
|
||||
|
|
@ -551,7 +580,7 @@ _HWREG_BLACKLIST_CDNA = {'HW_REG_PC_LO', 'HW_REG_PC_HI', 'HW_REG_IB_DBG1', 'HW_R
|
|||
'HW_REG_SQ_SHADER_TMA_LO', 'HW_REG_SQ_SHADER_TMA_HI', 'HW_REG_SQ_PERF_SNAPSHOT_DATA', 'HW_REG_SQ_PERF_SNAPSHOT_DATA1',
|
||||
'HW_REG_SQ_PERF_SNAPSHOT_PC_LO', 'HW_REG_SQ_PERF_SNAPSHOT_PC_HI', 'HW_REG_XCC_ID'}
|
||||
def _disasm_sopk(inst: SOPK) -> str:
|
||||
op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst)
|
||||
name, cdna = inst.op_name.lower(), _is_cdna(inst)
|
||||
is_rdna4 = _is_r4(inst)
|
||||
hw = HWREG_CDNA if cdna else (HWREG_RDNA4 if is_rdna4 else HWREG)
|
||||
blacklist = _HWREG_BLACKLIST_CDNA if cdna else _HWREG_BLACKLIST
|
||||
|
|
@ -574,12 +603,14 @@ def _disasm_sopk(inst: SOPK) -> str:
|
|||
|
||||
def _disasm_vinterp(inst: VINTERP) -> str:
|
||||
mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp"))
|
||||
return f"{inst.op_name.lower()} {inst.vdst.fmt()}, {_lit(inst, inst.src0, inst.neg & 1)}, {_lit(inst, inst.src1, inst.neg & 2)}, {_lit(inst, inst.src2, inst.neg & 4)}" + (" " + mods if mods else "")
|
||||
s0, s1, s2 = _lit(inst, inst.src0, inst.neg & 1), _lit(inst, inst.src1, inst.neg & 2), _lit(inst, inst.src2, inst.neg & 4)
|
||||
return f"{inst.op_name.lower()} {inst.vdst.fmt()}, {s0}, {s1}, {s2}" + (" " + mods if mods else "")
|
||||
|
||||
DISASM_HANDLERS: dict[type, Callable[..., str]] = {
|
||||
VOP1: _disasm_vop1, VOP1_SDST: _disasm_vop1, VOP1_SDST_LIT: _disasm_vop1, VOP1_LIT: _disasm_vop1,
|
||||
VOP2: _disasm_vop2, VOP2_LIT: _disasm_vop2, VOPC: _disasm_vopc, VOPC_LIT: _disasm_vopc,
|
||||
VOP3: _disasm_vop3, VOP3_SDST: _disasm_vop3, VOP3_SDST_LIT: _disasm_vop3, VOP3_LIT: _disasm_vop3, VOP3SD: _disasm_vop3sd, VOP3SD_LIT: _disasm_vop3sd,
|
||||
VOP3: _disasm_vop3, VOP3_SDST: _disasm_vop3, VOP3_SDST_LIT: _disasm_vop3, VOP3_LIT: _disasm_vop3,
|
||||
VOP3SD: _disasm_vop3sd, VOP3SD_LIT: _disasm_vop3sd,
|
||||
VOPD: _disasm_vopd, VOPD_LIT: _disasm_vopd, VOP3P: _disasm_vop3p, VOP3P_LIT: _disasm_vop3p,
|
||||
VINTERP: _disasm_vinterp, SOPP: _disasm_sopp, SMEM: _disasm_smem, DS: _disasm_ds, FLAT: _disasm_flat, GLOBAL: _disasm_flat, SCRATCH: _disasm_flat,
|
||||
SOP1: _disasm_sop1, SOP1_LIT: _disasm_sop1, SOP2: _disasm_sop2, SOP2_LIT: _disasm_sop2,
|
||||
|
|
@ -634,7 +665,9 @@ def _disasm_vop3a(inst) -> str:
|
|||
else:
|
||||
regs = inst.canonical_op_regs
|
||||
dregs, r0, r1, r2 = regs['d'], regs['s0'], regs['s1'], regs['s2']
|
||||
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, r0), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, r1), _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, r2)
|
||||
s0 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, r0)
|
||||
s1 = _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, r1)
|
||||
s2 = _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, r2)
|
||||
dst = _vreg(inst.vdst, dregs) if dregs > 1 else _vreg(inst.vdst)
|
||||
if op_val >= 512:
|
||||
return f"{name} {dst}, {s0}, {s1}, {s2}{opsel}{cl}{om}" if n == 3 else f"{name} {dst}, {s0}, {s1}{opsel}{cl}{om}"
|
||||
|
|
@ -658,7 +691,9 @@ def _disasm_vop3b(inst) -> str:
|
|||
n = inst.num_srcs() or _num_srcs(inst)
|
||||
regs = inst.canonical_op_regs
|
||||
dregs, r0, r1, r2 = regs['d'], regs['s0'], regs['s1'], regs['s2']
|
||||
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, n=r0), _cdna_src(inst, inst.src1, inst.neg&2, n=r1), _cdna_src(inst, inst.src2, inst.neg&4, n=r2)
|
||||
s0 = _cdna_src(inst, inst.src0, inst.neg&1, n=r0)
|
||||
s1 = _cdna_src(inst, inst.src1, inst.neg&2, n=r1)
|
||||
s2 = _cdna_src(inst, inst.src2, inst.neg&4, n=r2)
|
||||
# CDNA VOP3_SDST uses vdst field for sdst (but vdst adds 256), RDNA uses separate sdst field
|
||||
sdst_val = getattr(inst, 'sdst', None)
|
||||
if sdst_val is None and hasattr(inst, 'vdst'):
|
||||
|
|
@ -680,7 +715,7 @@ def _disasm_cdna_vop3p(inst) -> str:
|
|||
name, n = inst.op_name.lower(), inst.num_srcs() or 2
|
||||
is_mfma = 'mfma' in name or 'smfmac' in name
|
||||
is_accvgpr = 'accvgpr' in name
|
||||
get_src = lambda v, sc: _lit(inst, v) if v == 255 else _fmt_src(v, sc, cdna=True)
|
||||
def get_src(v, sc): return _lit(inst, v) if v == 255 else _fmt_src(v, sc, cdna=True)
|
||||
|
||||
# Handle accvgpr read/write (accumulator register operations)
|
||||
if is_accvgpr:
|
||||
|
|
@ -742,9 +777,12 @@ def _disasm_cdna_vop3p(inst) -> str:
|
|||
src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), _vreg(inst.vdst)
|
||||
opsel_hi = inst.opsel_hi # CDNA VOP3P only has 2 bits for opsel_hi (no opsel_hi2)
|
||||
opsel_hi_default = 3 # CDNA default is 0b11 (2 bits), not 0b111 like RDNA
|
||||
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != opsel_hi_default else []) + \
|
||||
([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else [])
|
||||
return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}"
|
||||
mods = (([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else [])
|
||||
+ ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != opsel_hi_default else [])
|
||||
+ ([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else [])
|
||||
+ ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else []))
|
||||
mod_s = ' ' + ' '.join(mods) if mods else ''
|
||||
return f"{name} {dst}, {src0}, {src1}, {src2}{mod_s}" if n == 3 else f"{name} {dst}, {src0}, {src1}{mod_s}"
|
||||
|
||||
def _disasm_mubuf(inst) -> str:
|
||||
name = inst.op_name.lower()
|
||||
|
|
@ -903,5 +941,6 @@ DISASM_HANDLERS.update({CDNA_VOP1: _disasm_vop1, CDNA_VOP1_LIT: _disasm_vop1,
|
|||
CDNA_SOP1: _disasm_sop1, CDNA_SOP1_LIT: _disasm_sop1, CDNA_SOP2: _disasm_sop2, CDNA_SOP2_LIT: _disasm_sop2,
|
||||
CDNA_SOPC: _disasm_sopc, CDNA_SOPC_LIT: _disasm_sopc, CDNA_SOPK: _disasm_sopk, CDNA_SOPK_LIT: _disasm_sopk, CDNA_SOPP: _disasm_sopp,
|
||||
CDNA_SMEM: _disasm_smem, CDNA_DS: _disasm_ds, CDNA_FLAT: _disasm_flat, CDNA_GLOBAL: _disasm_flat, CDNA_SCRATCH: _disasm_flat,
|
||||
CDNA_VOP3: _disasm_vop3a, CDNA_VOP3_SDST: _disasm_vop3b, CDNA_VOP3SD: _disasm_vop3b, CDNA_VOP3P: _disasm_cdna_vop3p, CDNA_VOP3P_MFMA: _disasm_cdna_vop3p,
|
||||
CDNA_VOP3: _disasm_vop3a, CDNA_VOP3_SDST: _disasm_vop3b, CDNA_VOP3SD: _disasm_vop3b,
|
||||
CDNA_VOP3P: _disasm_cdna_vop3p, CDNA_VOP3P_MFMA: _disasm_cdna_vop3p,
|
||||
CDNA_MUBUF: _disasm_mubuf, CDNA_VOP3PX2: _disasm_vop3px2})
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ def get_gpu_target() -> tuple[int, int, int]:
|
|||
"""Get the GPU target as (major, minor, stepping) tuple."""
|
||||
if not USE_HW: return (0, 0, 0)
|
||||
from tinygrad.device import Device
|
||||
return Device["AMD"].target
|
||||
return Device["AMD"].target # type: ignore[attr-defined]
|
||||
|
||||
def skip_unless_gfx(min_major: int, min_minor: int = 0, reason: str = ""):
|
||||
"""Skip test if GPU target is below the minimum required version."""
|
||||
|
|
@ -171,7 +171,7 @@ def run_program_hw(instructions: list, n_lanes: int = 1) -> WaveState:
|
|||
from tinygrad.helpers import flat_mv
|
||||
|
||||
dev = Device["AMD"]
|
||||
compiler = HIPCompiler(dev.arch)
|
||||
compiler = HIPCompiler(dev.arch) # type: ignore[attr-defined]
|
||||
|
||||
prologue, epilogue = get_prologue_epilogue(n_lanes)
|
||||
code = assemble(prologue + instructions + epilogue)
|
||||
|
|
@ -218,7 +218,7 @@ amdhsa.kernels:
|
|||
"""
|
||||
|
||||
lib = compiler.compile(asm_src)
|
||||
prg = AMDProgram(dev, "test", lib)
|
||||
prg = AMDProgram(dev, "test", lib) # type: ignore[arg-type]
|
||||
|
||||
out_gpu = dev.allocator.alloc(OUT_BYTES)
|
||||
assert out_gpu.va_addr % 16 == 0, f"buffer not 16-byte aligned: 0x{out_gpu.va_addr:x}"
|
||||
|
|
@ -276,6 +276,6 @@ def run_program(instructions: list, n_lanes: int = 1, ulp_tolerance: int = 0) ->
|
|||
hw_st = run_program_hw(instructions, n_lanes)
|
||||
diffs = compare_wave_states(emu_st, hw_st, n_lanes, ulp_tolerance=ulp_tolerance)
|
||||
if diffs:
|
||||
raise AssertionError(f"Emulator vs Hardware mismatch:\n" + "\n".join(diffs))
|
||||
raise AssertionError("Emulator vs Hardware mismatch:\n" + "\n".join(diffs))
|
||||
return hw_st
|
||||
return emu_st
|
||||
|
|
|
|||
|
|
@ -601,7 +601,6 @@ class TestDS2AddrStride64(unittest.TestCase):
|
|||
self.assertEqual(st.vgpr[0][6], 0xAAAAAAAA, "new val 0")
|
||||
self.assertEqual(st.vgpr[0][7], 0xBBBBBBBB, "new val 1")
|
||||
|
||||
|
||||
def test_ds_storexchg_rtn_b64(self):
|
||||
"""DS_STOREXCHG_RTN_B64: exchange 64-bit value and return old."""
|
||||
instructions = [
|
||||
|
|
|
|||
|
|
@ -373,7 +373,6 @@ class TestF64Conversions(unittest.TestCase):
|
|||
|
||||
def test_v_cvt_f64_f32_pi(self):
|
||||
"""V_CVT_F64_F32 converts f32 pi to f64."""
|
||||
import math
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f2i(3.14159265)),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
|
|
|
|||
|
|
@ -725,7 +725,7 @@ class TestLaneOps(unittest.TestCase):
|
|||
# v[5] should have the value only in lane 1
|
||||
for lane in range(4):
|
||||
if lane == 1:
|
||||
self.assertEqual(st.vgpr[lane][5], 0x12345678, f"v[5] lane 1 should have 0x12345678")
|
||||
self.assertEqual(st.vgpr[lane][5], 0x12345678, "v[5] lane 1 should have 0x12345678")
|
||||
else:
|
||||
self.assertEqual(st.vgpr[lane][5], 0, f"v[5] lane {lane} should be 0")
|
||||
|
||||
|
|
@ -1082,7 +1082,6 @@ class TestF64Ops(unittest.TestCase):
|
|||
"""Full f64->i64 conversion sequence with negative value."""
|
||||
import struct
|
||||
val = f2i64(-8.0)
|
||||
lit = 0xC1F00000 # high 32 bits of f64 -2^32
|
||||
instructions = [
|
||||
s_mov_b32(s[0], val & 0xffffffff),
|
||||
s_mov_b32(s[1], (val >> 32) & 0xffffffff),
|
||||
|
|
@ -1138,7 +1137,6 @@ class TestF64Ops(unittest.TestCase):
|
|||
# v_fma_f64 v[7:8], v[17:18], v[7:8], v[15:16]
|
||||
# We need to capture the exact input values and verify output matches hardware
|
||||
# v[7:8] before = 0x3f80fdf3_d69db28f (0.008296875941334462)
|
||||
v78 = 0x3f80fdf3d69db28f
|
||||
# For the FMA to produce 0xbf457ef0_ab8c254d, we need v[17:18] and v[15:16]
|
||||
# Let's test with known precision-sensitive values
|
||||
a = 1.0000000001
|
||||
|
|
@ -1395,7 +1393,7 @@ class TestWMMAMore(unittest.TestCase):
|
|||
|
||||
def test_v_wmma_f32_16x16x16_f16_basic(self):
|
||||
"""V_WMMA_F32_16X16X16_F16 basic test - verify output is non-zero."""
|
||||
instructions = []
|
||||
instructions: list[Inst] = []
|
||||
instructions.append(s_mov_b32(s[0], 0x3c003c00))
|
||||
for i in range(16, 32):
|
||||
instructions.append(v_mov_b32_e32(v[i], s[0]))
|
||||
|
|
@ -1851,7 +1849,6 @@ class TestMed3(unittest.TestCase):
|
|||
|
||||
def test_v_med3_f32_with_nan(self):
|
||||
"""V_MED3_F32: NaN handling - returns min of non-NaN values."""
|
||||
import math
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x7fc00000), # NaN
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
|
|
@ -2490,7 +2487,6 @@ class TestDivScaleF64(unittest.TestCase):
|
|||
independently. This catches the bug where the emulator was setting VCC
|
||||
for all lanes to the same value.
|
||||
"""
|
||||
import math
|
||||
# Use lane-varying input: lane 0 gets 2.0, lane 1 gets 3.0, etc.
|
||||
# All normal values should result in VCC=0 for each lane
|
||||
instructions = [
|
||||
|
|
@ -2721,7 +2717,6 @@ class TestDivScaleFmasF64Integration(unittest.TestCase):
|
|||
This is the exact bug scenario: tan([2.0, 3.0, 4.0]) was failing because
|
||||
VCC from DIV_SCALE was being set incorrectly for all lanes.
|
||||
"""
|
||||
import math
|
||||
# Set up values like tan() would: different values per lane
|
||||
instructions = [
|
||||
# Create per-lane values: 2.0, 3.0, 4.0, 5.0
|
||||
|
|
|
|||
|
|
@ -418,7 +418,7 @@ class TestWMMAF16(unittest.TestCase):
|
|||
|
||||
def test_v_wmma_f16_16x16x16_f16_all_ones(self):
|
||||
"""V_WMMA_F16_16X16X16_F16 with all ones produces 16.0 in f16."""
|
||||
instructions = []
|
||||
instructions: list[Inst] = []
|
||||
instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0
|
||||
# Initialize A matrix in v[16:23] (8 regs)
|
||||
for i in range(16, 24):
|
||||
|
|
@ -442,7 +442,7 @@ class TestWMMAF16(unittest.TestCase):
|
|||
|
||||
def test_v_wmma_f16_16x16x16_f16_with_accumulator(self):
|
||||
"""V_WMMA_F16_16X16X16_F16 with non-zero accumulator."""
|
||||
instructions = []
|
||||
instructions: list[Inst] = []
|
||||
instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0
|
||||
instructions.append(s_mov_b32(s[1], 0x4500)) # f16 5.0 in lo bits only
|
||||
# Initialize A matrix in v[16:23] (8 regs)
|
||||
|
|
@ -471,7 +471,7 @@ class TestWMMAF16(unittest.TestCase):
|
|||
Regression test: WMMA was using static register indices instead of dynamic.
|
||||
This test uses v[64:71] for A, v[80:87] for B, v[96:103] for C/D.
|
||||
"""
|
||||
instructions = []
|
||||
instructions: list[Inst] = []
|
||||
instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0
|
||||
# Initialize A matrix in v[64:71] (8 regs)
|
||||
for i in range(64, 72):
|
||||
|
|
@ -502,7 +502,7 @@ class TestWMMA(unittest.TestCase):
|
|||
|
||||
def test_v_wmma_f32_16x16x16_f16_all_ones(self):
|
||||
"""V_WMMA_F32_16X16X16_F16 with all ones produces 16.0."""
|
||||
instructions = []
|
||||
instructions: list[Inst] = []
|
||||
instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0
|
||||
for i in range(16, 32):
|
||||
instructions.append(v_mov_b32_e32(v[i], s[0]))
|
||||
|
|
@ -518,7 +518,7 @@ class TestWMMA(unittest.TestCase):
|
|||
|
||||
def test_v_wmma_f32_16x16x16_f16_with_accumulator(self):
|
||||
"""V_WMMA_F32_16X16X16_F16 with non-zero accumulator."""
|
||||
instructions = []
|
||||
instructions: list[Inst] = []
|
||||
instructions.append(s_mov_b32(s[0], 0x3c003c00))
|
||||
instructions.append(s_mov_b32(s[1], f2i(5.0)))
|
||||
for i in range(16, 32):
|
||||
|
|
@ -540,7 +540,7 @@ class TestWMMA(unittest.TestCase):
|
|||
causing incorrect results when registers weren't at the default positions.
|
||||
This test uses v[64:71] for A, v[80:87] for B, v[96:103] for C/D.
|
||||
"""
|
||||
instructions = []
|
||||
instructions: list[Inst] = []
|
||||
instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0
|
||||
# Initialize A matrix in v[64:71]
|
||||
for i in range(64, 72):
|
||||
|
|
@ -569,7 +569,7 @@ class TestWMMABF16(unittest.TestCase):
|
|||
|
||||
def test_v_wmma_f32_16x16x16_bf16_all_ones(self):
|
||||
"""V_WMMA_F32_16X16X16_BF16 with all ones produces 16.0."""
|
||||
instructions = []
|
||||
instructions: list[Inst] = []
|
||||
# BF16 1.0 = 0x3f80, packed = 0x3f803f80
|
||||
instructions.append(s_mov_b32(s[0], 0x3f803f80))
|
||||
for i in range(16, 32):
|
||||
|
|
@ -586,7 +586,7 @@ class TestWMMABF16(unittest.TestCase):
|
|||
|
||||
def test_v_wmma_f32_16x16x16_bf16_with_accumulator(self):
|
||||
"""V_WMMA_F32_16X16X16_BF16 with non-zero accumulator."""
|
||||
instructions = []
|
||||
instructions: list[Inst] = []
|
||||
# BF16 1.0 = 0x3f80, packed = 0x3f803f80
|
||||
instructions.append(s_mov_b32(s[0], 0x3f803f80))
|
||||
instructions.append(s_mov_b32(s[1], f2i(5.0)))
|
||||
|
|
|
|||
|
|
@ -7,8 +7,7 @@ VOPD executes two operations simultaneously. Key behavior:
|
|||
- Op Y can use ops 0-18 (includes ADD_NC_U32, LSHLREV, AND)
|
||||
"""
|
||||
import unittest
|
||||
from extra.assembly.amd.test.hw.helpers import run_program, run_program_emu, run_program_hw, compare_wave_states, \
|
||||
v, s, v_mov_b32_e32, s_mov_b32
|
||||
from extra.assembly.amd.test.hw.helpers import run_program, v, v_mov_b32_e32
|
||||
from extra.assembly.amd.autogen.rdna3.ins import VOPD, VOPD_LIT, VOPDOp
|
||||
|
||||
class TestVOPDBasic(unittest.TestCase):
|
||||
|
|
|
|||
|
|
@ -81,7 +81,9 @@ class RustEmulator:
|
|||
return snap.to_snapshot()
|
||||
|
||||
def free(self):
|
||||
if self.ctx: self.lib.wave_free(self.ctx); self.ctx = None
|
||||
if self.ctx:
|
||||
self.lib.wave_free(self.ctx)
|
||||
self.ctx = None
|
||||
|
||||
class PythonEmulator:
|
||||
def __init__(self):
|
||||
|
|
@ -114,8 +116,8 @@ class PythonEmulator:
|
|||
if pc == 0xFFFFFFFFFFFFFFFF or pc not in self.program: return -1
|
||||
name, fxn, globals_list, _runner = self.program[pc]
|
||||
if fxn is None: return 1 # unsupported instruction
|
||||
buf_addrs = {0: self.state.sgpr_buf._buf.va_addr, 1: self.state.vgpr_buf._buf.va_addr,
|
||||
2: self.vmem_buf._buf.va_addr, 3: self.lds_buf._buf.va_addr}
|
||||
buf_addrs = {0: self.state.sgpr_buf._buf.va_addr, 1: self.state.vgpr_buf._buf.va_addr, # type: ignore[union-attr]
|
||||
2: self.vmem_buf._buf.va_addr, 3: self.lds_buf._buf.va_addr} # type: ignore[union-attr]
|
||||
# Direct ctypes call - bypasses HCQ overhead
|
||||
fxn(*[ctypes.c_uint64(buf_addrs[g]) for g in globals_list], ctypes.c_int32(0))
|
||||
return -1 if self.state.pc == 0xFFFFFFFFFFFFFFFF else 0
|
||||
|
|
@ -178,6 +180,7 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t
|
|||
rust_before = rust.get_snapshot()
|
||||
python_before = python.get_snapshot()
|
||||
|
||||
assert python.program is not None
|
||||
inst_info = python.program.get(python.lib_addr + python_before.pc * 4) # Convert word offset to actual address
|
||||
inst_hex_name = inst_info[0] if inst_info else f"unknown at PC={python_before.pc}"
|
||||
# Decode the instruction to get mnemonic for sync_after checks
|
||||
|
|
@ -188,7 +191,7 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t
|
|||
inst_bytes = bytes.fromhex(inst_bytes_hex) if inst_bytes_hex else b''
|
||||
decoded = decode_inst(inst_bytes) if inst_bytes else None
|
||||
inst_mnemonic = repr(decoded).split('(')[0] if decoded else ""
|
||||
except:
|
||||
except Exception:
|
||||
inst_mnemonic = ""
|
||||
# For generic instructions, use function name for sync_after check
|
||||
if not inst_mnemonic: inst_mnemonic = inst_hex_name
|
||||
|
|
@ -220,16 +223,18 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t
|
|||
python_diffs = pb.diff(next_pb, n_lanes, "->")
|
||||
if rust_diffs: trace_lines.append(f" rust: {', '.join(rust_diffs[:5])}")
|
||||
if python_diffs: trace_lines.append(f" python: {', '.join(python_diffs[:5])}")
|
||||
elif rust_diffs: trace_lines.append(f" python: (no changes)")
|
||||
elif rust_diffs: trace_lines.append(" python: (no changes)")
|
||||
else:
|
||||
# Last traced instruction - compare with current state
|
||||
rust_diffs = rb.diff(rust_before, n_lanes, "->")
|
||||
python_diffs = pb.diff(python_before, n_lanes, "->")
|
||||
if rust_diffs: trace_lines.append(f" rust: {', '.join(rust_diffs[:5])}")
|
||||
if python_diffs: trace_lines.append(f" python: {', '.join(python_diffs[:5])}")
|
||||
elif rust_diffs: trace_lines.append(f" python: (no changes)")
|
||||
elif rust_diffs: trace_lines.append(" python: (no changes)")
|
||||
trace_str = "\n".join(trace_lines)
|
||||
return False, f"K{kernel_idx} WG({gidx},{gidy},{gidz}) Step {step} before inst '{inst_str}': states differ (rust vs python):\n " + "\n ".join(diffs[:10]) + f"\n Recent instructions:\n{trace_str}", total_steps
|
||||
msg = f"K{kernel_idx} WG({gidx},{gidy},{gidz}) Step {step} before inst '{inst_str}': states differ (rust vs python):\n "
|
||||
msg += "\n ".join(diffs[:10]) + f"\n Recent instructions:\n{trace_str}"
|
||||
return False, msg, total_steps
|
||||
|
||||
rust_result = rust.step()
|
||||
python_result = python.step()
|
||||
|
|
@ -239,7 +244,9 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t
|
|||
if rust_result == 1 and python_result == 0:
|
||||
raise unittest.SkipTest(f"Rust emulator doesn't support instruction: {inst_str}")
|
||||
trace_str = "\n".join(f" step {s}: PC={pc:3d} {d}" for s, pc, d, _, _ in trace)
|
||||
return False, f"K{kernel_idx} WG({gidx},{gidy},{gidz}) Step {step}: different return codes: rust={rust_result}, python={python_result}, inst={inst_str}\n Recent instructions:\n{trace_str}", total_steps
|
||||
msg = (f"K{kernel_idx} WG({gidx},{gidy},{gidz}) Step {step}: different return codes: "
|
||||
f"rust={rust_result}, python={python_result}, inst={inst_str}\n Recent instructions:\n{trace_str}")
|
||||
return False, msg, total_steps
|
||||
|
||||
# Sync Python state to Rust after instructions with known Rust emulator differences
|
||||
if sync_after:
|
||||
|
|
@ -429,7 +436,8 @@ class TestTinygradKernels(unittest.TestCase):
|
|||
def test_cast(self): self._test_kernel(lambda T: T.empty(32).half().float() + T.empty(32).int().float())
|
||||
|
||||
# Pooling - regression for VCC wave32 mode
|
||||
def test_pool2d(self): self._test_kernel(lambda T: T.empty(1, 1, 8, 8).avg_pool2d(kernel_size=(4,4)) + T.empty(1, 1, 8, 8).max_pool2d(kernel_size=(4,4)))
|
||||
def test_pool2d(self):
|
||||
self._test_kernel(lambda T: T.empty(1, 1, 8, 8).avg_pool2d(kernel_size=(4,4)) + T.empty(1, 1, 8, 8).max_pool2d(kernel_size=(4,4)))
|
||||
|
||||
# Convolution
|
||||
def test_conv2d(self): self._test_kernel(lambda T: T.empty(1, 2, 8, 8).conv2d(T.empty(2, 2, 3, 3)), max_steps=50000)
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ def custom_add_var(A:UOp, B:UOp, arch:str) -> UOp:
|
|||
class TestCustomKernel(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
a = Tensor.full((16, 16), 1.).contiguous().realize()
|
||||
a = Tensor.custom_kernel(a, fxn=functools.partial(custom_add_one, arch=Device[Device.DEFAULT].renderer.arch))[0]
|
||||
a = Tensor.custom_kernel(a, fxn=functools.partial(custom_add_one, arch=Device[Device.DEFAULT].renderer.arch))[0] # type: ignore[attr-defined]
|
||||
ei = a.schedule()[-1].lower()
|
||||
self.assertEqual(ei.prg.estimates.ops, a.numel())
|
||||
self.assertEqual(ei.prg.estimates.mem, a.nbytes()*2)
|
||||
|
|
@ -68,7 +68,7 @@ class TestCustomKernel(unittest.TestCase):
|
|||
def test_variable(self):
|
||||
b = Tensor.full((16, 16), 1, dtype=dtypes.uint32).contiguous().realize()
|
||||
a = Tensor.zeros_like(b).contiguous().realize()
|
||||
a = Tensor.custom_kernel(a, b, fxn=functools.partial(custom_add_var, arch=Device[Device.DEFAULT].renderer.arch))[0]
|
||||
a = Tensor.custom_kernel(a, b, fxn=functools.partial(custom_add_var, arch=Device[Device.DEFAULT].renderer.arch))[0] # type: ignore[attr-defined]
|
||||
ei = a.schedule()[-1].lower()
|
||||
for i in range(4):
|
||||
ei.run({"var":i})
|
||||
|
|
|
|||
|
|
@ -7,11 +7,11 @@ from tinygrad.uop.ops import UOp, Ops
|
|||
from extra.assembly.amd.emu import parse_pcode
|
||||
from extra.assembly.amd.pcode import parse_expr
|
||||
from extra.assembly.amd.autogen.rdna3.str_pcode import PCODE
|
||||
from extra.assembly.amd.autogen.rdna3.enum import VOP1Op, VOP2Op, VOP3Op, SOP1Op, SOP2Op, DSOp
|
||||
from extra.assembly.amd.autogen.rdna3.enum import VOP1Op, VOP2Op, SOP2Op, DSOp
|
||||
|
||||
def _srcs():
|
||||
"""Create minimal source variables for pcode parsing."""
|
||||
u32 = lambda v=0: UOp.const(dtypes.uint32, v)
|
||||
def u32(v=0): return UOp.const(dtypes.uint32, v)
|
||||
return {'S0': u32(), 'S1': u32(), 'S2': u32(), 'SCC': u32(), 'VCC': UOp.const(dtypes.uint64, 0), 'laneId': u32()}
|
||||
|
||||
class TestBasicParsing(unittest.TestCase):
|
||||
|
|
@ -90,16 +90,16 @@ class TestParseExpr(unittest.TestCase):
|
|||
|
||||
def test_variable_lookup(self):
|
||||
"""Test variable lookup in parse_expr."""
|
||||
vars = {'x': UOp.const(dtypes.uint32, 42)}
|
||||
result = parse_expr('x', vars)
|
||||
vrs = {'x': UOp.const(dtypes.uint32, 42)}
|
||||
result = parse_expr('x', vrs)
|
||||
self.assertEqual(result.arg, 42)
|
||||
|
||||
def test_binary_ops(self):
|
||||
"""Test parsing binary operations."""
|
||||
vars = {'a': UOp.const(dtypes.uint32, 10), 'b': UOp.const(dtypes.uint32, 5)}
|
||||
vrs = {'a': UOp.const(dtypes.uint32, 10), 'b': UOp.const(dtypes.uint32, 5)}
|
||||
|
||||
# Addition
|
||||
result = parse_expr('a + b', vars)
|
||||
result = parse_expr('a + b', vrs)
|
||||
self.assertEqual(result.op, Ops.ADD)
|
||||
|
||||
# Subtraction with constant folding
|
||||
|
|
@ -109,8 +109,8 @@ class TestParseExpr(unittest.TestCase):
|
|||
|
||||
def test_ternary(self):
|
||||
"""Test parsing ternary expressions."""
|
||||
vars = {'cond': UOp.const(dtypes.bool, True), 'a': UOp.const(dtypes.uint32, 1), 'b': UOp.const(dtypes.uint32, 0)}
|
||||
result = parse_expr('cond ? a : b', vars)
|
||||
vrs = {'cond': UOp.const(dtypes.bool, True), 'a': UOp.const(dtypes.uint32, 1), 'b': UOp.const(dtypes.uint32, 0)}
|
||||
result = parse_expr('cond ? a : b', vrs)
|
||||
self.assertEqual(result.op, Ops.WHERE)
|
||||
|
||||
class TestForLoopParsing(unittest.TestCase):
|
||||
|
|
@ -120,13 +120,14 @@ class TestForLoopParsing(unittest.TestCase):
|
|||
"""Verify CLZ pcode is available."""
|
||||
pcode = PCODE.get(VOP1Op.V_CLZ_I32_U32_E32)
|
||||
self.assertIsNotNone(pcode)
|
||||
assert pcode is not None
|
||||
self.assertIn('for', pcode.lower())
|
||||
|
||||
def test_clz_parsing(self):
|
||||
"""Test CLZ pcode parsing produces correct structure."""
|
||||
pcode = PCODE[VOP1Op.V_CLZ_I32_U32_E32]
|
||||
S0 = UOp.const(dtypes.uint32, 0xFFFFFFFF) # All ones - CLZ should be 0
|
||||
vars, assigns = parse_pcode(pcode, {'S0': S0})
|
||||
_vrs, assigns = parse_pcode(pcode, {'S0': S0})
|
||||
|
||||
self.assertEqual(len(assigns), 1)
|
||||
dest, val = assigns[0]
|
||||
|
|
@ -138,7 +139,7 @@ class TestForLoopParsing(unittest.TestCase):
|
|||
"""Test CLZ with input 0 - should return -1."""
|
||||
pcode = PCODE[VOP1Op.V_CLZ_I32_U32_E32]
|
||||
S0 = UOp.const(dtypes.uint32, 0)
|
||||
vars, assigns = parse_pcode(pcode, {'S0': S0})
|
||||
_vrs, assigns = parse_pcode(pcode, {'S0': S0})
|
||||
|
||||
# Check that the innermost value (default) is -1 (may be wrapped in CAST)
|
||||
val = assigns[0][1]
|
||||
|
|
@ -157,7 +158,7 @@ class TestForLoopParsing(unittest.TestCase):
|
|||
self.skipTest("V_CTZ_I32_B32_E32 pcode not available")
|
||||
|
||||
S0 = UOp.const(dtypes.uint32, 1) # LSB set - CTZ should be 0
|
||||
vars, assigns = parse_pcode(pcode, {'S0': S0})
|
||||
_vrs, assigns = parse_pcode(pcode, {'S0': S0})
|
||||
self.assertEqual(len(assigns), 1)
|
||||
|
||||
class TestDSPcodePatterns(unittest.TestCase):
|
||||
|
|
@ -167,6 +168,7 @@ class TestDSPcodePatterns(unittest.TestCase):
|
|||
"""Test DS_LOAD_B32 pcode is parseable."""
|
||||
pcode = PCODE.get(DSOp.DS_LOAD_B32)
|
||||
self.assertIsNotNone(pcode)
|
||||
assert pcode is not None
|
||||
self.assertIn('RETURN_DATA', pcode)
|
||||
self.assertIn('MEM[', pcode)
|
||||
|
||||
|
|
@ -174,6 +176,7 @@ class TestDSPcodePatterns(unittest.TestCase):
|
|||
"""Test DS_STORE_B32 pcode is parseable."""
|
||||
pcode = PCODE.get(DSOp.DS_STORE_B32)
|
||||
self.assertIsNotNone(pcode)
|
||||
assert pcode is not None
|
||||
self.assertIn('MEM[', pcode)
|
||||
self.assertIn('DATA', pcode)
|
||||
|
||||
|
|
@ -182,9 +185,9 @@ class TestDSPcodePatterns(unittest.TestCase):
|
|||
# Create a mock LDS buffer
|
||||
lds = UOp(Ops.PARAM, dtypes.uint32.ptr(16384), arg=3)
|
||||
addr = UOp.const(dtypes.uint32, 0)
|
||||
vars = {'_lds': lds, 'ADDR': addr, 'OFFSET': UOp.const(dtypes.uint32, 0)}
|
||||
vrs = {'_lds': lds, 'ADDR': addr, 'OFFSET': UOp.const(dtypes.uint32, 0)}
|
||||
|
||||
result = parse_expr('MEM[ADDR + OFFSET].b32', vars)
|
||||
result = parse_expr('MEM[ADDR + OFFSET].b32', vrs)
|
||||
# Should be an INDEX operation into LDS
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
|
|
@ -192,6 +195,7 @@ class TestDSPcodePatterns(unittest.TestCase):
|
|||
"""Test DS_STORE_2ADDR_B32 pcode parsing produces MEM writes."""
|
||||
pcode = PCODE.get(DSOp.DS_STORE_2ADDR_B32)
|
||||
self.assertIsNotNone(pcode)
|
||||
assert pcode is not None
|
||||
srcs = {
|
||||
'ADDR': UOp.const(dtypes.uint32, 0),
|
||||
'OFFSET0': UOp.const(dtypes.uint32, 0),
|
||||
|
|
@ -207,12 +211,13 @@ class TestDSPcodePatterns(unittest.TestCase):
|
|||
self.assertTrue(dest.startswith('MEM['))
|
||||
# val should be (addr, write_val) tuple
|
||||
self.assertIsInstance(val, tuple)
|
||||
self.assertEqual(len(val), 2)
|
||||
self.assertEqual(len(val), 2) # type: ignore[arg-type]
|
||||
|
||||
def test_ds_load_2addr_b32_parsing(self):
|
||||
"""Test DS_LOAD_2ADDR_B32 pcode parsing produces RETURN_DATA assignments."""
|
||||
pcode = PCODE.get(DSOp.DS_LOAD_2ADDR_B32)
|
||||
self.assertIsNotNone(pcode)
|
||||
assert pcode is not None
|
||||
lds = UOp(Ops.PARAM, dtypes.uint32.ptr(16384), arg=3)
|
||||
srcs = {
|
||||
'ADDR': UOp.const(dtypes.uint32, 0),
|
||||
|
|
@ -230,6 +235,7 @@ class TestDSPcodePatterns(unittest.TestCase):
|
|||
def test_ds_store_address_calculation(self):
|
||||
"""Test DS_STORE_2ADDR_B32 calculates correct addresses (offset * 4)."""
|
||||
pcode = PCODE.get(DSOp.DS_STORE_2ADDR_B32)
|
||||
assert pcode is not None
|
||||
srcs = {
|
||||
'ADDR': UOp.const(dtypes.uint32, 100),
|
||||
'OFFSET0': UOp.const(dtypes.uint32, 2),
|
||||
|
|
@ -240,14 +246,14 @@ class TestDSPcodePatterns(unittest.TestCase):
|
|||
srcs['laneId'] = UOp.const(dtypes.uint32, 0)
|
||||
_, assigns = parse_pcode(pcode, srcs)
|
||||
# Check addresses: 100 + 2*4 = 108, 100 + 5*4 = 120
|
||||
addr0, _ = assigns[0][1]
|
||||
addr1, _ = assigns[1][1]
|
||||
self.assertEqual(addr0.simplify().arg, 108)
|
||||
self.assertEqual(addr1.simplify().arg, 120)
|
||||
# assigns[i][1] is (addr, val) tuple for MEM writes; mypy sees UOp
|
||||
self.assertEqual(assigns[0][1][0].simplify().arg, 108) # type: ignore[index]
|
||||
self.assertEqual(assigns[1][1][0].simplify().arg, 120) # type: ignore[index]
|
||||
|
||||
def test_ds_store_data_values(self):
|
||||
"""Test DS_STORE_2ADDR_B32 uses correct data values."""
|
||||
pcode = PCODE.get(DSOp.DS_STORE_2ADDR_B32)
|
||||
assert pcode is not None
|
||||
srcs = {
|
||||
'ADDR': UOp.const(dtypes.uint32, 0),
|
||||
'OFFSET0': UOp.const(dtypes.uint32, 0),
|
||||
|
|
@ -257,11 +263,10 @@ class TestDSPcodePatterns(unittest.TestCase):
|
|||
}
|
||||
srcs['laneId'] = UOp.const(dtypes.uint32, 0)
|
||||
_, assigns = parse_pcode(pcode, srcs)
|
||||
_, val0 = assigns[0][1]
|
||||
_, val1 = assigns[1][1]
|
||||
# assigns[i][1] is (addr, val) tuple for MEM writes; mypy sees UOp
|
||||
# DATA[31:0] should preserve the value
|
||||
self.assertEqual(val0.simplify().arg, 0xAAAAAAAA)
|
||||
self.assertEqual(val1.simplify().arg, 0xBBBBBBBB)
|
||||
self.assertEqual(assigns[0][1][1].simplify().arg, 0xAAAAAAAA) # type: ignore[index]
|
||||
self.assertEqual(assigns[1][1][1].simplify().arg, 0xBBBBBBBB) # type: ignore[index]
|
||||
|
||||
class TestConditionalParsing(unittest.TestCase):
|
||||
"""Test conditional (if/elsif/else) pcode parsing."""
|
||||
|
|
@ -273,7 +278,7 @@ class TestConditionalParsing(unittest.TestCase):
|
|||
s0 = UOp.const(dtypes.uint32, 10)
|
||||
s1 = UOp.const(dtypes.uint32, 20)
|
||||
scc = UOp.const(dtypes.uint32, 1)
|
||||
vars, assigns = parse_pcode(pcode, {'S0': s0, 'S1': s1, 'SCC': scc})
|
||||
_vrs, assigns = parse_pcode(pcode, {'S0': s0, 'S1': s1, 'SCC': scc})
|
||||
self.assertEqual(len(assigns), 1)
|
||||
dest, val = assigns[0]
|
||||
self.assertTrue(dest.startswith('D0'))
|
||||
|
|
@ -294,7 +299,8 @@ class TestAllPcode(unittest.TestCase):
|
|||
'ADDR': u32(), 'ADDR_BASE': u32(), 'TADDR': u32(), 'DATA': u32(), 'DATA0': u32(), 'DATA1': u32(), 'DATA2': u32(),
|
||||
'VDATA': u32(), 'VDATA0': u32(), 'VDATA1': u32(), 'VDATA2': u32(), 'VDATA3': u32(),
|
||||
'OPSEL': u32(), 'OPSEL_HI': u32(), 'NEG': u32(), 'NEG_HI': u32(), 'CLAMP': u32(),
|
||||
'M0': u32(), 'PC': u64(), 'DENORM': u32(1), 'ROUND_MODE': u32(), 'ROUND_TOWARD_ZERO': u32(), 'ROUND_NEAREST_EVEN': u32(), 'WAVE_STATUS': u32(),
|
||||
'M0': u32(), 'PC': u64(), 'DENORM': u32(1), 'ROUND_MODE': u32(), 'ROUND_TOWARD_ZERO': u32(),
|
||||
'ROUND_NEAREST_EVEN': u32(), 'WAVE_STATUS': u32(),
|
||||
'MAX_FLOAT_F32': u32(0x7f7fffff), 'Unsigned': u32(1), 'clampedLOD': u32(),
|
||||
'_lds': lds, '_vmem': lds, '_active': UOp.const(dtypes.bool, True)}
|
||||
|
||||
|
|
@ -306,7 +312,9 @@ class TestAllPcode(unittest.TestCase):
|
|||
try:
|
||||
parse_pcode(pcode, srcs)
|
||||
passed += 1
|
||||
except RuntimeError as e: skipped += 1; errors[str(e)].append(op.name)
|
||||
except RuntimeError as e:
|
||||
skipped += 1
|
||||
errors[str(e)].append(op.name)
|
||||
except Exception as e: self.fail(f"[{arch}] {op.name}: {e}\nPcode: {pcode[:200]}")
|
||||
total = len(pcode_dict)
|
||||
pct = 100 * passed / total
|
||||
|
|
|
|||
|
|
@ -127,15 +127,17 @@ def _make_test(f: str, arch: str, test_type: str):
|
|||
self.assertEqual(skipped, 0, f"{name}: {skipped} tests skipped, expected 0")
|
||||
elif test_type == "repr":
|
||||
# Test that eval(repr(inst)) reproduces the instruction
|
||||
if arch == "rdna3": import extra.assembly.amd.autogen.rdna3.ins as ins
|
||||
elif arch == "rdna4": import extra.assembly.amd.autogen.rdna4.ins as ins
|
||||
elif arch == "cdna": import extra.assembly.amd.autogen.cdna.ins as ins
|
||||
if arch == "rdna3": import extra.assembly.amd.autogen.rdna3.ins as ins # type: ignore[no-redef]
|
||||
elif arch == "rdna4": import extra.assembly.amd.autogen.rdna4.ins as ins # type: ignore[no-redef]
|
||||
elif arch == "cdna": import extra.assembly.amd.autogen.cdna.ins as ins # type: ignore[no-redef]
|
||||
ns = {k: getattr(ins, k) for k in dir(ins) if not k.startswith('_')}
|
||||
passed, skipped = 0, 0
|
||||
for _, data in tests:
|
||||
try:
|
||||
decoded = detect_format(data, arch).from_bytes(data)
|
||||
if decoded.to_bytes()[:len(data)] != data: skipped += 1; continue # skip if binary roundtrip fails
|
||||
if decoded.to_bytes()[:len(data)] != data:
|
||||
skipped += 1
|
||||
continue # skip if binary roundtrip fails
|
||||
r = repr(decoded)
|
||||
try:
|
||||
decoded2 = eval(r, ns) # noqa: S307
|
||||
|
|
@ -153,7 +155,7 @@ def _make_test(f: str, arch: str, test_type: str):
|
|||
enc = decoded.to_bytes()[:len(data)]
|
||||
# Skip if roundtrip fails, disasm fails, or op_name is missing (disasm starts with space)
|
||||
if enc == data and (d := disasm(decoded)) and not d.startswith(' '): to_test.append((enc, d))
|
||||
except: pass
|
||||
except Exception: pass
|
||||
skipped = len(tests) - len(to_test)
|
||||
print(f"{name}: {len(to_test)} passed, {skipped} skipped")
|
||||
self.assertEqual(skipped, 0, f"{name}: {skipped} tests skipped, expected 0")
|
||||
|
|
|
|||
|
|
@ -6,6 +6,10 @@ from extra.assembly.amd.generate import extract_pdf_text, extract_pcode, parse_x
|
|||
EXPECTED_PAGES = {"rdna3": 655, "rdna4": 711, "cdna": 610}
|
||||
|
||||
class TestPcodePDF(unittest.TestCase):
|
||||
pages: dict
|
||||
enums: dict
|
||||
pcode: dict
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.pages = {arch: extract_pdf_text(cfg["pdf"]) for arch, cfg in ARCHS.items()}
|
||||
|
|
@ -33,7 +37,8 @@ class TestPcodePDF(unittest.TestCase):
|
|||
'tmp = MEM[ADDR].u64;\nsrc = DATA.u64;\nMEM[ADDR].u64 = src >= tmp ? src : tmp;\nRETURN_DATA.u64 = tmp')
|
||||
# GLOBAL_STORE_B128: should have 4 MEM stores (not truncated)
|
||||
self.assertEqual(pcode[('GLOBAL_STORE_B128', 29)],
|
||||
'MEM[ADDR].b32 = VDATA[31 : 0];\nMEM[ADDR + 4U].b32 = VDATA[63 : 32];\nMEM[ADDR + 8U].b32 = VDATA[95 : 64];\nMEM[ADDR + 12U].b32 = VDATA[127 : 96]')
|
||||
'MEM[ADDR].b32 = VDATA[31 : 0];\nMEM[ADDR + 4U].b32 = VDATA[63 : 32];\n'
|
||||
'MEM[ADDR + 8U].b32 = VDATA[95 : 64];\nMEM[ADDR + 12U].b32 = VDATA[127 : 96]')
|
||||
# S_CMOVK_I32: should have full if/endif block
|
||||
self.assertEqual(pcode[('S_CMOVK_I32', 2)],
|
||||
"if SCC then\nD0.i32 = 32'I(signext(SIMM16.i16))\nendif")
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from tinygrad.device import Buffer, BufferSpec
|
|||
from tinygrad.dtype import dtypes
|
||||
|
||||
class TestRDNA4Emu(unittest.TestCase):
|
||||
def _run(self, insts: list, sgprs: dict[int, int] = None, vgprs: dict[tuple[int, int], int] = None) -> WaveState:
|
||||
def _run(self, insts: list, sgprs: dict[int, int] | None = None, vgprs: dict[tuple[int, int], int] | None = None) -> WaveState:
|
||||
"""Run instructions and return final WaveState."""
|
||||
# Add S_ENDPGM if not present
|
||||
if not any(isinstance(i, ir4.SOPP) and i.op == ir4.SOPPOp.S_ENDPGM for i in insts):
|
||||
|
|
@ -22,10 +22,8 @@ class TestRDNA4Emu(unittest.TestCase):
|
|||
# Setup wave state
|
||||
st = WaveState(n_lanes=1)
|
||||
st.pc = code_addr
|
||||
if sgprs:
|
||||
for idx, val in sgprs.items(): st._write_sgpr(idx, val)
|
||||
if vgprs:
|
||||
for (reg, lane), val in vgprs.items(): st._write_vgpr(reg, lane, val)
|
||||
for idx, val in (sgprs or {}).items(): st._write_sgpr(idx, val)
|
||||
for (reg, lane), val in (vgprs or {}).items(): st._write_vgpr(reg, lane, val)
|
||||
|
||||
# Setup vmem buffer with external_ptr=0 (maps to address 0, allows any pointer access)
|
||||
vmem_buf = Buffer('CPU', 1 << 40, dtypes.uint32, options=BufferSpec(external_ptr=0)).ensure_allocated()
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Roundtrip tests: generate tinygrad kernels, decode instructions, re-encode, verify match."""
|
||||
import unittest, io, sys, re, subprocess, os
|
||||
from extra.assembly.amd.dsl import Inst
|
||||
from extra.assembly.amd import decode_inst, detect_format
|
||||
from extra.assembly.amd import detect_format
|
||||
from extra.assembly.amd.test.helpers import get_llvm_mc, get_llvm_objdump, get_target, get_mattr
|
||||
from extra.assembly.amd.test.disasm import disasm
|
||||
|
||||
|
|
@ -100,11 +99,6 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
|
|||
while offset < len(code):
|
||||
remaining = code[offset:]
|
||||
fmt = detect_format(remaining, arch)
|
||||
if fmt is None:
|
||||
decoded_instrs.append((ki, offset, None, None, None, False, "no format"))
|
||||
offset += 4
|
||||
continue
|
||||
|
||||
base_size = fmt._size()
|
||||
if len(remaining) < base_size:
|
||||
break
|
||||
|
|
@ -190,8 +184,8 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
|
|||
print(f"[{arch}] decode roundtrip: {decode_passed} passed, {decode_failed} failed, {decode_skipped} skipped")
|
||||
print(f"[{arch}] asm via llvm: {asm_passed} passed, {asm_failed} failed, {asm_skipped} skipped")
|
||||
print(f"[{arch}] disasm vs llvm: {disasm_passed} passed, {disasm_failed} failed, {disasm_skipped} skipped")
|
||||
self.assertEqual(decode_failed, 0, f"Decode failures:\n" + "\n".join(decode_failures[:20]))
|
||||
self.assertEqual(asm_failed, 0, f"Asm failures:\n" + "\n".join(asm_failures[:20]))
|
||||
self.assertEqual(decode_failed, 0, "Decode failures:\n" + "\n".join(decode_failures[:20]))
|
||||
self.assertEqual(asm_failed, 0, "Asm failures:\n" + "\n".join(asm_failures[:20]))
|
||||
# Note: disasm string comparison is informational only - formatting differences between LLVM versions are expected
|
||||
|
||||
# Basic unary ops
|
||||
|
|
|
|||
|
|
@ -8,8 +8,9 @@ from tinygrad.runtime.support.elf import elf_loader
|
|||
from extra.assembly.amd import decode_inst
|
||||
from extra.assembly.amd.autogen.rdna3.ins import SOPP
|
||||
from extra.assembly.amd.autogen.rdna3.enum import SOPPOp
|
||||
from extra.assembly.amd.sqtt import (decode, LAYOUT_HEADER, WAVESTART, WAVESTART_RDNA4, WAVEEND, INST, INST_RDNA4, VALUINST, IMMEDIATE, IMMEDIATE_MASK,
|
||||
ALUEXEC, VMEMEXEC, PACKET_TYPES_RDNA3, PACKET_TYPES_RDNA4, InstOp, InstOpRDNA4, print_packets)
|
||||
from extra.assembly.amd.sqtt import (decode, LAYOUT_HEADER, WAVESTART, WAVESTART_RDNA4, WAVEEND, INST, INST_RDNA4, VALUINST,
|
||||
IMMEDIATE, IMMEDIATE_MASK, PACKET_TYPES_RDNA3, PACKET_TYPES_RDNA4,
|
||||
InstOp, InstOpRDNA4, print_packets)
|
||||
from extra.assembly.amd.test.helpers import TARGET_TO_ARCH
|
||||
|
||||
EXAMPLES_DIR = Path(__file__).parent.parent.parent.parent / "sqtt/examples"
|
||||
|
|
@ -32,18 +33,18 @@ def run_rocprof_decoder(blobs: list[bytes], lib: bytes, base: int, target: str):
|
|||
assert text is not None, "no .text section found"
|
||||
text_off, text_size = text.header.sh_addr, text.header.sh_size
|
||||
|
||||
blob_iter, current_blob = iter(blobs), [None]
|
||||
blob_iter, current_blob = iter(blobs), [None] # type: ignore[var-annotated]
|
||||
occupancy_records: list[tuple[int, int, int, int, bool]] = [] # (wave_id, simd, cu, time, is_start)
|
||||
wave_insts: list[list[tuple[int, int]]] = [] # per-wave list of (time, stall)
|
||||
|
||||
@rocprof.rocprof_trace_decoder_se_data_callback_t
|
||||
def copy_cb(buf, buf_size, _):
|
||||
def copy_cb(buf, buf_size, _): # type: ignore[no-untyped-def]
|
||||
blob = next(blob_iter, None)
|
||||
if blob is None: return 0
|
||||
current_blob[0] = (ctypes.c_ubyte * len(blob)).from_buffer_copy(blob)
|
||||
buf[0] = ctypes.cast(current_blob[0], ctypes.POINTER(ctypes.c_ubyte))
|
||||
buf_size[0] = len(current_blob[0])
|
||||
return len(current_blob[0])
|
||||
current_blob[0] = (ctypes.c_ubyte * len(blob)).from_buffer_copy(blob) # type: ignore[call-overload]
|
||||
buf[0] = ctypes.cast(current_blob[0], ctypes.POINTER(ctypes.c_ubyte)) # type: ignore[arg-type]
|
||||
buf_size[0] = len(current_blob[0]) # type: ignore[arg-type]
|
||||
return len(current_blob[0]) # type: ignore[arg-type]
|
||||
|
||||
@rocprof.rocprof_trace_decoder_trace_callback_t
|
||||
def trace_cb(record_type, events_ptr, n, _):
|
||||
|
|
@ -94,6 +95,7 @@ def run_rocprof_decoder(blobs: list[bytes], lib: bytes, base: int, target: str):
|
|||
|
||||
class SQTTExamplesTestBase(unittest.TestCase):
|
||||
target: str
|
||||
examples: dict
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
|
@ -115,7 +117,9 @@ class SQTTExamplesTestBase(unittest.TestCase):
|
|||
for i, event in enumerate(events):
|
||||
with self.subTest(example=name, event=i):
|
||||
packets = list(decode(event.blob))
|
||||
if DEBUG >= 2: print(f"\n=== {name} event {i} ==="); print_packets(packets)
|
||||
if DEBUG >= 2:
|
||||
print(f"\n=== {name} event {i} ===")
|
||||
print_packets(packets)
|
||||
self.assertGreater(len(packets), 0, f"no packets decoded from {name} event {i}")
|
||||
self.assertIsInstance(packets[0], LAYOUT_HEADER, f"first packet should be LAYOUT_HEADER in {name}")
|
||||
|
||||
|
|
|
|||
|
|
@ -94,12 +94,13 @@ def extract_cdna_packet_sizes():
|
|||
rw_base, rw_offset = _find_segment('rw-p')
|
||||
if not (head := ctypes.c_void_p.from_address(rw_base + (0x2d4f0 - rw_offset)).value if rw_base else None): return None
|
||||
|
||||
pkt_sizes, node, seen = {}, head, set()
|
||||
pkt_sizes: dict[int, int] = {}
|
||||
node, seen = head, set()
|
||||
while node and node not in seen and len(pkt_sizes) < 20:
|
||||
seen.add(node)
|
||||
key, val = ctypes.c_uint32.from_address(node + 8).value, ctypes.c_uint32.from_address(node + 12).value
|
||||
if key < 16 and val in (0x10, 0x20, 0x30, 0x40): pkt_sizes[key] = {0x10: 2, 0x20: 4, 0x30: 6, 0x40: 8}[val]
|
||||
node = ctypes.c_void_p.from_address(node).value
|
||||
node = ctypes.c_void_p.from_address(node).value # type: ignore[assignment]
|
||||
return pkt_sizes if len(pkt_sizes) == 16 else None
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
|
@ -127,14 +128,14 @@ class TestSQTTMatchesBinary(unittest.TestCase):
|
|||
for pkt_fmt, pkt_cls in PACKET_TYPES_CDNA.items():
|
||||
with self.subTest(packet=pkt_cls.__name__):
|
||||
self.assertEqual(pkt_cls.encoding.default, pkt_fmt)
|
||||
self.assertEqual(CDNA_PKT_SIZES[pkt_fmt] * 2, pkt_cls._size_nibbles)
|
||||
self.assertEqual(CDNA_PKT_SIZES[pkt_fmt] * 2, pkt_cls._size_nibbles) # type: ignore[attr-defined]
|
||||
|
||||
def _test_bit_counts(self, layout: int):
|
||||
if not (tables := extract_bit_tables()): self.skipTest("rocprof-trace-decoder not installed")
|
||||
from extra.assembly.amd.sqtt import PACKET_TYPES_RDNA3, PACKET_TYPES_RDNA4
|
||||
for type_id, pkt_cls in {3: PACKET_TYPES_RDNA3, 4: PACKET_TYPES_RDNA4}[layout].items():
|
||||
with self.subTest(packet=pkt_cls.__name__):
|
||||
self.assertEqual(pkt_cls._size_nibbles * 4, tables[layout - 2][type_id])
|
||||
self.assertEqual(pkt_cls._size_nibbles * 4, tables[layout - 2][type_id]) # type: ignore[attr-defined]
|
||||
|
||||
def _test_encodings(self, layout: int):
|
||||
if not (encodings := extract_packet_encodings()): self.skipTest("rocprof-trace-decoder not installed")
|
||||
|
|
@ -164,14 +165,16 @@ if __name__ == "__main__":
|
|||
|
||||
print("L2:", tables[0], "\nL3:", tables[1], "\nL4:", tables[2])
|
||||
if encodings and tables:
|
||||
print(f"\n{'TypeID':>6} {'Name':>18} {'L2 enc':>12} {'L3 enc':>12} {'L4 enc':>12} {'L2':>4} {'L3':>4} {'L4':>4} {'L2 delta':>12} {'L3 delta':>12} {'L4 delta':>12}")
|
||||
print(f"\n{'TypeID':>6} {'Name':>18} {'L2 enc':>12} {'L3 enc':>12} {'L4 enc':>12}"
|
||||
f" {'L2':>4} {'L3':>4} {'L4':>4} {'L2 delta':>12} {'L3 delta':>12} {'L4 delta':>12}")
|
||||
print("-" * 140)
|
||||
for type_id in sorted(set(encodings[0]) | set(encodings[1]) | set(encodings[2])):
|
||||
name = TYPE_NAMES.get(type_id, f'UNK_{type_id}')
|
||||
bits = [tables[i][type_id] if type_id < len(tables[i]) else 0 for i in range(3)]
|
||||
enc_strs = [f"0x{encodings[i][type_id][0]:02x}/0x{encodings[i][type_id][1]:02x}" if type_id in encodings[i] else "-" for i in range(3)]
|
||||
delta_strs = [f"[{d[1]-1}:{d[0]}]" if (d := deltas[i].get(type_id, (0, 0)))[1] > d[0] else "-" for i in range(3)]
|
||||
print(f"{type_id:6d} {name:>18} {enc_strs[0]:>12} {enc_strs[1]:>12} {enc_strs[2]:>12} {bits[0]:4d} {bits[1]:4d} {bits[2]:4d} {delta_strs[0]:>12} {delta_strs[1]:>12} {delta_strs[2]:>12}")
|
||||
print(f"{type_id:6d} {name:>18} {enc_strs[0]:>12} {enc_strs[1]:>12} {enc_strs[2]:>12}"
|
||||
f" {bits[0]:4d} {bits[1]:4d} {bits[2]:4d} {delta_strs[0]:>12} {delta_strs[1]:>12} {delta_strs[2]:>12}")
|
||||
|
||||
cdna = extract_cdna_packet_sizes()
|
||||
if cdna: print(f"\nCDNA packet sizes: {cdna}")
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ def rocprof_inst_traces_match(sqtt, prg, target):
|
|||
|
||||
class TestSQTTMapBase(unittest.TestCase):
|
||||
target: str
|
||||
examples: dict
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
|
|
|||
|
|
@ -128,7 +128,7 @@ debug = true
|
|||
|
||||
[tool.mypy]
|
||||
warn_unused_configs = true
|
||||
files = ["tinygrad"]
|
||||
files = ["tinygrad", "extra/assembly/amd"]
|
||||
ignore_missing_imports = true
|
||||
check_untyped_defs = true
|
||||
explicit_package_bases = true
|
||||
|
|
@ -142,6 +142,10 @@ strict_equality = true
|
|||
module = "extra.*"
|
||||
follow_imports = "skip"
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "extra.assembly.amd.*"
|
||||
follow_imports = "normal"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
norecursedirs = [
|
||||
"extra",
|
||||
|
|
@ -180,6 +184,7 @@ exclude = [
|
|||
".git/",
|
||||
"docs/",
|
||||
"extra/",
|
||||
"!extra/assembly/amd/",
|
||||
"test/external/mlperf_resnet",
|
||||
"test/external/mlperf_unet3d",
|
||||
]
|
||||
|
|
@ -245,6 +250,8 @@ select = [
|
|||
"F841",
|
||||
]
|
||||
"tinygrad/runtime/autogen/**/*.py" = ["E501", "F401", "E722", "E731", "F821", "A006", "A002", "F811"]
|
||||
"extra/assembly/amd/autogen/**/*.py" = ["E501"]
|
||||
"extra/assembly/amd/test/**/*.py" = ["F403", "F405"]
|
||||
|
||||
[tool.ruff.format]
|
||||
exclude = ["*"]
|
||||
|
|
|
|||
|
|
@ -325,7 +325,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:int) -> list[ProfileEvent]:
|
|||
name, width = (op_name, 10 if "BARRIER" in op_name else 1)
|
||||
add(name, p, width=width, idx=int("OTHER" in name), info=info)
|
||||
if isinstance(p, (VALUINST, IMMEDIATE)): add(p.__class__.__name__, p, info=info)
|
||||
if isinstance(p, IMMEDIATE_MASK): add("IMMEDIATE", p, wave=unwrap(info.wave), info=info)
|
||||
if isinstance(p, IMMEDIATE_MASK): add("IMMEDIATE", p, wave=unwrap(info.wave), info=info) # type: ignore[union-attr]
|
||||
if isinstance(p, (VMEMEXEC, ALUEXEC)):
|
||||
name = str(p.src).split('.')[1]
|
||||
if name == "VALU_SALU":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue