mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
62 lines
3.6 KiB
Python
62 lines
3.6 KiB
Python
# Instruction format detection and decoding
|
|
from __future__ import annotations
|
|
from extra.assembly.amd.dsl import Inst
|
|
from extra.assembly.amd.autogen.rdna3.ins import VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP
|
|
from extra.assembly.amd.autogen.rdna4.ins import (VOP1 as R4_VOP1, VOP2 as R4_VOP2, VOP3 as R4_VOP3, VOP3SD as R4_VOP3SD, VOP3P as R4_VOP3P,
|
|
VOPC as R4_VOPC, VOPD as R4_VOPD, VINTERP as R4_VINTERP, SOP1 as R4_SOP1, SOP2 as R4_SOP2, SOPC as R4_SOPC, SOPK as R4_SOPK, SOPP as R4_SOPP,
|
|
SMEM as R4_SMEM, DS as R4_DS, VBUFFER as R4_VBUFFER, VEXPORT as R4_VEXPORT)
|
|
from extra.assembly.amd.autogen.cdna.ins import (VOP1 as C_VOP1, VOP2 as C_VOP2, VOPC as C_VOPC, VOP3A, VOP3B, VOP3P as C_VOP3P,
|
|
SOP1 as C_SOP1, SOP2 as C_SOP2, SOPC as C_SOPC, SOPK as C_SOPK, SOPP as C_SOPP, SMEM as C_SMEM, DS as C_DS,
|
|
FLAT as C_FLAT, MUBUF as C_MUBUF, MTBUF as C_MTBUF, SDWA, DPP)
|
|
|
|
def _matches_encoding(word: int, cls: type[Inst]) -> bool:
|
|
"""Check if word matches the encoding pattern of an instruction class."""
|
|
if cls._encoding is None: return False
|
|
bf, val = cls._encoding
|
|
return ((word >> bf.lo) & bf.mask()) == val
|
|
|
|
# Order matters: more specific encodings first, VOP2 last (it's a catch-all for bit31=0)
|
|
_RDNA_FORMATS_64 = [VOPD, VOP3P, VINTERP, VOP3, DS, FLAT, MUBUF, MTBUF, MIMG, SMEM, EXP]
|
|
_RDNA_FORMATS_32 = [SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2] # SOP2/VOP2 are catch-alls
|
|
_CDNA_FORMATS_64 = [C_VOP3P, VOP3A, C_DS, C_FLAT, C_MUBUF, C_MTBUF, C_SMEM]
|
|
_CDNA_FORMATS_32 = [SDWA, DPP, C_SOP1, C_SOPC, C_SOPP, C_SOPK, C_VOPC, C_VOP1, C_SOP2, C_VOP2]
|
|
_CDNA_VOP3B_OPS = {281, 282, 283, 284, 285, 286, 480, 481, 488, 489} # VOP3B opcodes
|
|
_RDNA4_FORMATS_64 = [R4_VOPD, R4_VOP3P, R4_VINTERP, R4_VOP3, R4_DS, R4_VBUFFER, R4_SMEM, R4_VEXPORT]
|
|
_RDNA4_FORMATS_32 = [R4_SOP1, R4_SOPC, R4_SOPP, R4_SOPK, R4_VOPC, R4_VOP1, R4_SOP2, R4_VOP2]
|
|
_RDNA4_VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
|
|
|
|
def detect_format(data: bytes, arch: str = "rdna3") -> type[Inst]:
|
|
"""Detect instruction format from machine code bytes."""
|
|
assert len(data) >= 4, f"need at least 4 bytes, got {len(data)}"
|
|
word = int.from_bytes(data[:4], 'little')
|
|
if arch == "cdna":
|
|
if (word >> 30) == 0b11:
|
|
for cls in _CDNA_FORMATS_64:
|
|
if _matches_encoding(word, cls):
|
|
return VOP3B if cls is VOP3A and ((word >> 16) & 0x3ff) in _CDNA_VOP3B_OPS else cls
|
|
raise ValueError(f"unknown CDNA 64-bit format word={word:#010x}")
|
|
for cls in _CDNA_FORMATS_32:
|
|
if _matches_encoding(word, cls): return cls
|
|
raise ValueError(f"unknown CDNA 32-bit format word={word:#010x}")
|
|
if arch == "rdna4":
|
|
if (word >> 30) == 0b11:
|
|
for cls in _RDNA4_FORMATS_64:
|
|
if _matches_encoding(word, cls):
|
|
return R4_VOP3SD if cls is R4_VOP3 and ((word >> 16) & 0x3ff) in _RDNA4_VOP3SD_OPS else cls
|
|
raise ValueError(f"unknown RDNA4 64-bit format word={word:#010x}")
|
|
for cls in _RDNA4_FORMATS_32:
|
|
if _matches_encoding(word, cls): return cls
|
|
raise ValueError(f"unknown RDNA4 32-bit format word={word:#010x}")
|
|
# RDNA3 (default)
|
|
if (word >> 30) == 0b11:
|
|
for cls in _RDNA_FORMATS_64:
|
|
if _matches_encoding(word, cls):
|
|
return VOP3SD if cls is VOP3 and ((word >> 16) & 0x3ff) in Inst._VOP3SD_OPS else cls
|
|
raise ValueError(f"unknown 64-bit format word={word:#010x}")
|
|
for cls in _RDNA_FORMATS_32:
|
|
if _matches_encoding(word, cls): return cls
|
|
raise ValueError(f"unknown 32-bit format word={word:#010x}")
|
|
|
|
def decode_inst(data: bytes, arch: str = "rdna3") -> Inst:
|
|
"""Decode machine code bytes into an instruction."""
|
|
return detect_format(data, arch).from_bytes(data)
|