mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
assembly/amd: factor out pdf generation
This commit is contained in:
parent
04c79505ec
commit
ef5ee0f723
10 changed files with 826 additions and 927 deletions
|
|
@ -1,4 +1,4 @@
|
|||
# autogenerated from AMD CDNA3+CDNA4 ISA PDF by dsl.py - do not edit
|
||||
# autogenerated from AMD CDNA3+CDNA4 ISA PDF by generate.py - do not edit
|
||||
from enum import IntEnum
|
||||
from typing import Annotated
|
||||
from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# autogenerated by pcode.py - do not edit
|
||||
# to regenerate: python -m extra.assembly.amd.pcode --arch cdna
|
||||
# autogenerated by generate.py - do not edit
|
||||
# to regenerate: python -m extra.assembly.amd.generate --arch cdna
|
||||
# ruff: noqa: E501,F405,F403
|
||||
# mypy: ignore-errors
|
||||
from extra.assembly.amd.autogen.cdna import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3POp, VOPCOp, VOP3AOp, VOP3BOp
|
||||
|
|
@ -9678,32 +9678,13 @@ def _VOPCOp_V_CMPX_T_U64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGP
|
|||
# Set the per-lane condition code to 1. Store the result into the EXEC mask and to VCC or a scalar register.
|
||||
# EXEC.u64[laneId] = D0.u64[laneId] = 1'1U;
|
||||
# // D0 = VCC in VOPC encoding.
|
||||
# addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32);
|
||||
# tmp = MEM[addr].u32;
|
||||
# addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32);
|
||||
# tmp = MEM[addr].u32;
|
||||
# addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32);
|
||||
# tmp = MEM[addr].u32;
|
||||
# addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32);
|
||||
# tmp = MEM[addr].u32;
|
||||
# src = DATA.u32;
|
||||
D0 = Reg(d0)
|
||||
VCC = Reg(vcc)
|
||||
EXEC = Reg(exec_mask)
|
||||
tmp = Reg(0)
|
||||
laneId = lane
|
||||
PC = Reg(pc)
|
||||
# --- compiled pseudocode ---
|
||||
EXEC.u64[laneId] = D0.u64[laneId] = 1
|
||||
addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32)
|
||||
tmp = Reg(MEM[addr].u32)
|
||||
addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32)
|
||||
tmp = Reg(MEM[addr].u32)
|
||||
addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32)
|
||||
tmp = Reg(MEM[addr].u32)
|
||||
addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32)
|
||||
tmp = Reg(MEM[addr].u32)
|
||||
src = DATA.u32
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
if VCC._val != vcc: result['vcc_lane'] = (VCC._val >> lane) & 1
|
||||
|
|
@ -14177,32 +14158,13 @@ def _VOP3AOp_V_CMPX_T_U64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
|
|||
# Set the per-lane condition code to 1. Store the result into the EXEC mask and to VCC or a scalar register.
|
||||
# EXEC.u64[laneId] = D0.u64[laneId] = 1'1U;
|
||||
# // D0 = VCC in VOPC encoding.
|
||||
# addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32);
|
||||
# tmp = MEM[addr].u32;
|
||||
# addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32);
|
||||
# tmp = MEM[addr].u32;
|
||||
# addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32);
|
||||
# tmp = MEM[addr].u32;
|
||||
# addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32);
|
||||
# tmp = MEM[addr].u32;
|
||||
# src = DATA.u32;
|
||||
D0 = Reg(d0)
|
||||
VCC = Reg(vcc)
|
||||
EXEC = Reg(exec_mask)
|
||||
tmp = Reg(0)
|
||||
laneId = lane
|
||||
PC = Reg(pc)
|
||||
# --- compiled pseudocode ---
|
||||
EXEC.u64[laneId] = D0.u64[laneId] = 1
|
||||
addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32)
|
||||
tmp = Reg(MEM[addr].u32)
|
||||
addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32)
|
||||
tmp = Reg(MEM[addr].u32)
|
||||
addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32)
|
||||
tmp = Reg(MEM[addr].u32)
|
||||
addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32)
|
||||
tmp = Reg(MEM[addr].u32)
|
||||
src = DATA.u32
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
if VCC._val != vcc: result['vcc_lane'] = (VCC._val >> lane) & 1
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
# autogenerated from AMD RDNA3.5 ISA PDF by dsl.py - do not edit
|
||||
# autogenerated from AMD RDNA3.5 ISA PDF by generate.py - do not edit
|
||||
from enum import IntEnum
|
||||
from typing import Annotated
|
||||
from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# autogenerated by pcode.py - do not edit
|
||||
# to regenerate: python -m extra.assembly.amd.pcode --arch rdna3
|
||||
# autogenerated by generate.py - do not edit
|
||||
# to regenerate: python -m extra.assembly.amd.generate --arch rdna3
|
||||
# ruff: noqa: E501,F405,F403
|
||||
# mypy: ignore-errors
|
||||
from extra.assembly.amd.autogen.rdna3 import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
# autogenerated from AMD RDNA4 ISA PDF by dsl.py - do not edit
|
||||
# autogenerated from AMD RDNA4 ISA PDF by generate.py - do not edit
|
||||
from enum import IntEnum
|
||||
from typing import Annotated
|
||||
from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# autogenerated by pcode.py - do not edit
|
||||
# to regenerate: python -m extra.assembly.amd.pcode --arch rdna4
|
||||
# autogenerated by generate.py - do not edit
|
||||
# to regenerate: python -m extra.assembly.amd.generate --arch rdna4
|
||||
# ruff: noqa: E501,F405,F403
|
||||
# mypy: ignore-errors
|
||||
from extra.assembly.amd.autogen.rdna4 import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp
|
||||
|
|
|
|||
|
|
@ -352,296 +352,3 @@ class Inst:
|
|||
|
||||
class Inst32(Inst): pass
|
||||
class Inst64(Inst): pass
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# CODE GENERATION: generates autogen/__init__.py by parsing AMD ISA PDFs
|
||||
# Supports both RDNA3.5 and CDNA4 instruction set PDFs - auto-detects format
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
PDF_URLS = {
|
||||
"rdna3": "https://docs.amd.com/api/khub/documents/UVVZM22UN7tMUeiW_4ShTQ/content", # RDNA3.5
|
||||
"rdna4": "https://docs.amd.com/api/khub/documents/uQpkEvk3pv~kfAb2x~j4uw/content",
|
||||
"cdna": ["https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/amd-instinct-mi300-cdna3-instruction-set-architecture.pdf",
|
||||
"https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/amd-instinct-cdna4-instruction-set-architecture.pdf"],
|
||||
}
|
||||
FIELD_TYPES = {'SSRC0': 'SSrc', 'SSRC1': 'SSrc', 'SOFFSET': 'SSrc', 'SADDR': 'SSrc', 'SRC0': 'Src', 'SRC1': 'Src', 'SRC2': 'Src',
|
||||
'SDST': 'SGPRField', 'SBASE': 'SGPRField', 'SDATA': 'SGPRField', 'SRSRC': 'SGPRField', 'VDST': 'VGPRField', 'VSRC1': 'VGPRField', 'VDATA': 'VGPRField',
|
||||
'VADDR': 'VGPRField', 'ADDR': 'VGPRField', 'DATA': 'VGPRField', 'DATA0': 'VGPRField', 'DATA1': 'VGPRField', 'SIMM16': 'SImm', 'OFFSET': 'Imm',
|
||||
'OPX': 'VOPDOp', 'OPY': 'VOPDOp', 'SRCX0': 'Src', 'SRCY0': 'Src', 'VSRCX1': 'VGPRField', 'VSRCY1': 'VGPRField', 'VDSTX': 'VGPRField', 'VDSTY': 'VDSTYEnc'}
|
||||
FIELD_ORDER = {
|
||||
'SOP2': ['op', 'sdst', 'ssrc0', 'ssrc1'], 'SOP1': ['op', 'sdst', 'ssrc0'], 'SOPC': ['op', 'ssrc0', 'ssrc1'],
|
||||
'SOPK': ['op', 'sdst', 'simm16'], 'SOPP': ['op', 'simm16'], 'VOP1': ['op', 'vdst', 'src0'], 'VOPC': ['op', 'src0', 'vsrc1'],
|
||||
'VOP2': ['op', 'vdst', 'src0', 'vsrc1'], 'VOP3SD': ['op', 'vdst', 'sdst', 'src0', 'src1', 'src2', 'clmp'],
|
||||
'SMEM': ['op', 'sdata', 'sbase', 'soffset', 'offset', 'glc', 'dlc'], 'DS': ['op', 'vdst', 'addr', 'data0', 'data1'],
|
||||
'VOP3': ['op', 'vdst', 'src0', 'src1', 'src2', 'omod', 'neg', 'abs', 'clmp', 'opsel'],
|
||||
'VOP3P': ['op', 'vdst', 'src0', 'src1', 'src2', 'neg', 'neg_hi', 'opsel', 'opsel_hi', 'clmp'],
|
||||
'FLAT': ['op', 'vdst', 'addr', 'data', 'saddr', 'offset', 'seg', 'dlc', 'glc', 'slc'],
|
||||
'MUBUF': ['op', 'vdata', 'vaddr', 'srsrc', 'soffset', 'offset', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe'],
|
||||
'MTBUF': ['op', 'vdata', 'vaddr', 'srsrc', 'soffset', 'offset', 'format', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe'],
|
||||
'MIMG': ['op', 'vdata', 'vaddr', 'srsrc', 'ssamp', 'dmask', 'dim', 'unrm', 'dlc', 'glc', 'slc'],
|
||||
'EXP': ['en', 'target', 'vsrc0', 'vsrc1', 'vsrc2', 'vsrc3', 'done', 'row'],
|
||||
'VINTERP': ['op', 'vdst', 'src0', 'src1', 'src2', 'waitexp', 'clmp', 'opsel', 'neg'],
|
||||
'VOPD': ['opx', 'opy', 'vdstx', 'vdsty', 'srcx0', 'vsrcx1', 'srcy0', 'vsrcy1'],
|
||||
'LDSDIR': ['op', 'vdst', 'attr', 'attr_chan', 'wait_va']}
|
||||
SRC_EXTRAS = {233: 'DPP8', 234: 'DPP8FI', 250: 'DPP16', 251: 'VCCZ', 252: 'EXECZ', 254: 'LDS_DIRECT'}
|
||||
FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'NEG_ONE', '2.0': 'POS_TWO', '-2.0': 'NEG_TWO',
|
||||
'4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
|
||||
|
||||
def _parse_bits(s: str) -> tuple[int, int] | None:
|
||||
import re
|
||||
return (int(m.group(1)), int(m.group(2) or m.group(1))) if (m := re.match(r'\[(\d+)(?::(\d+))?\]', s)) else None
|
||||
|
||||
def _parse_fields_table(table: list, fmt: str, enums: set[str]) -> list[tuple]:
|
||||
import re
|
||||
fields = []
|
||||
for row in table[1:]:
|
||||
if not row or not row[0]: continue
|
||||
name, bits_str = row[0].split('\n')[0].strip(), (row[1] or '').split('\n')[0].strip()
|
||||
if not (bits := _parse_bits(bits_str)): continue
|
||||
enc_val, hi, lo = None, bits[0], bits[1]
|
||||
if name == 'ENCODING' and row[2]:
|
||||
# Handle both RDNA3 ('bXX) and CDNA4 (Must be: XX) encoding formats
|
||||
if m := re.search(r"(?:'b|Must be:\s*)([01_]+)", row[2]):
|
||||
enc_bits = m.group(1).replace('_', '')
|
||||
enc_val = int(enc_bits, 2)
|
||||
declared_width, actual_width = hi - lo + 1, len(enc_bits)
|
||||
if actual_width > declared_width: lo = hi - actual_width + 1
|
||||
ftype = f"{fmt}Op" if name == 'OP' and f"{fmt}Op" in enums else FIELD_TYPES.get(name.upper())
|
||||
fields.append((name, hi, lo, enc_val, ftype))
|
||||
return fields
|
||||
|
||||
def _parse_single_pdf(url: str) -> dict:
|
||||
"""Parse a single PDF and return raw data (formats, enums, src_enum, doc_name, is_cdna)."""
|
||||
import re, pdfplumber
|
||||
from tinygrad.helpers import fetch
|
||||
|
||||
pdf = pdfplumber.open(fetch(url))
|
||||
|
||||
# Auto-detect document type from first page
|
||||
first_page_text = pdf.pages[0].extract_text() or ''
|
||||
is_cdna4 = 'CDNA4' in first_page_text or 'CDNA 4' in first_page_text
|
||||
is_cdna3 = 'CDNA3' in first_page_text or 'CDNA 3' in first_page_text or 'MI300' in first_page_text
|
||||
is_cdna = is_cdna3 or is_cdna4
|
||||
is_rdna4 = 'RDNA4' in first_page_text or 'RDNA 4' in first_page_text
|
||||
is_rdna35 = 'RDNA3.5' in first_page_text or 'RDNA 3.5' in first_page_text # Check 3.5 before 3
|
||||
is_rdna3 = not is_rdna35 and ('RDNA3' in first_page_text or 'RDNA 3' in first_page_text)
|
||||
doc_name = "CDNA4" if is_cdna4 else "CDNA3" if is_cdna3 else "RDNA4" if is_rdna4 else "RDNA3.5" if is_rdna35 else "RDNA3" if is_rdna3 else "Unknown"
|
||||
|
||||
# Find the "Microcode Formats" section - search for SOP2 format definition
|
||||
microcode_start = None
|
||||
total_pages = len(pdf.pages)
|
||||
# Search from likely locations (formats are typically 20-95% through the document - RDNA3 has them at ~25%)
|
||||
for i in range(int(total_pages * 0.2), total_pages):
|
||||
text = pdf.pages[i].extract_text() or ''
|
||||
# Look for "X.Y.Z. SOP2" section header or "Chapter X. Microcode Formats"
|
||||
if re.search(r'\d+\.\d+\.\d+\.\s+SOP2\b', text) or re.search(r'Chapter \d+\.\s+Microcode Formats', text):
|
||||
microcode_start = i
|
||||
break
|
||||
if microcode_start is None: microcode_start = int(total_pages * 0.9)
|
||||
|
||||
pages = pdf.pages[microcode_start:microcode_start + 50]
|
||||
page_texts = [p.extract_text() or '' for p in pages]
|
||||
page_tables = [[t.extract() for t in p.find_tables()] for p in pages]
|
||||
full_text = '\n'.join(page_texts)
|
||||
|
||||
# parse SSRC encoding from first page with VCC_LO
|
||||
src_enum = dict(SRC_EXTRAS)
|
||||
for text in page_texts[:10]:
|
||||
if 'SSRC0' in text and 'VCC_LO' in text:
|
||||
for m in re.finditer(r'^(\d+)\s+(\S+)', text, re.M):
|
||||
val, name = int(m.group(1)), m.group(2).rstrip('.:')
|
||||
if name in FLOAT_MAP: src_enum[val] = FLOAT_MAP[name]
|
||||
elif re.match(r'^[A-Z][A-Z0-9_]*$', name): src_enum[val] = name
|
||||
break
|
||||
|
||||
# parse opcode tables
|
||||
enums: dict[str, dict[int, str]] = {}
|
||||
for m in re.finditer(r'Table \d+\. (\w+) Opcodes(.*?)(?=Table \d+\.|\n\d+\.\d+\.\d+\.\s+\w+\s*\nDescription|$)', full_text, re.S):
|
||||
if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+([A-Z][A-Z0-9_]+)', m.group(2))}:
|
||||
enums[m.group(1) + "Op"] = ops
|
||||
if vopd_m := re.search(r'Table \d+\. VOPD Y-Opcodes\n(.*?)(?=Table \d+\.|15\.\d)', full_text, re.S):
|
||||
if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+(V_DUAL_\w+)', vopd_m.group(1))}:
|
||||
enums["VOPDOp"] = ops
|
||||
enum_names = set(enums.keys())
|
||||
|
||||
def is_fields_table(t) -> bool: return t and len(t) > 1 and t[0] and 'Field' in str(t[0][0] or '')
|
||||
def has_encoding(fields) -> bool: return any(f[0] == 'ENCODING' for f in fields)
|
||||
def has_header_before_fields(text) -> bool:
|
||||
return (pos := text.find('Field Name')) != -1 and bool(re.search(r'\d+\.\d+\.\d+\.\s+\w+\s*\n', text[:pos]))
|
||||
|
||||
# find format headers with their page indices
|
||||
format_headers = []
|
||||
for i, text in enumerate(page_texts):
|
||||
for m in re.finditer(r'\d+\.\d+\.\d+\.\s+(\w+)\s*\n?Description', text): format_headers.append((m.group(1), i, m.start()))
|
||||
for m in re.finditer(r'\d+\.\d+\.\d+\.\s+(\w+)\s*\n', text):
|
||||
fmt_name = m.group(1)
|
||||
if is_cdna and fmt_name.isupper() and len(fmt_name) >= 2:
|
||||
format_headers.append((fmt_name, i, m.start()))
|
||||
elif m.start() > len(text) - 200 and 'Description' not in text[m.end():] and i + 1 < len(page_texts):
|
||||
next_text = page_texts[i + 1].lstrip()
|
||||
if next_text.startswith('Description') or (next_text.startswith('"RDNA') and 'Description' in next_text[:200]):
|
||||
format_headers.append((fmt_name, i, m.start()))
|
||||
|
||||
# parse instruction formats
|
||||
formats: dict[str, list] = {}
|
||||
for fmt_name, page_idx, header_pos in format_headers:
|
||||
if fmt_name in formats: continue
|
||||
text, tables = page_texts[page_idx], page_tables[page_idx]
|
||||
field_pos = text.find('Field Name', header_pos)
|
||||
|
||||
fields = None
|
||||
for offset in range(3):
|
||||
if page_idx + offset >= len(pages): break
|
||||
if offset > 0 and has_header_before_fields(page_texts[page_idx + offset]): break
|
||||
for t in page_tables[page_idx + offset] if offset > 0 or field_pos > header_pos else []:
|
||||
if is_fields_table(t) and (f := _parse_fields_table(t, fmt_name, enum_names)) and has_encoding(f):
|
||||
fields = f
|
||||
break
|
||||
if fields: break
|
||||
|
||||
if not fields and field_pos > header_pos:
|
||||
for t in tables:
|
||||
if is_fields_table(t) and (f := _parse_fields_table(t, fmt_name, enum_names)):
|
||||
fields = f
|
||||
break
|
||||
|
||||
if not fields: continue
|
||||
field_names = {f[0] for f in fields}
|
||||
|
||||
for pg_offset in range(1, 3):
|
||||
if page_idx + pg_offset >= len(pages) or has_header_before_fields(page_texts[page_idx + pg_offset]): break
|
||||
for t in page_tables[page_idx + pg_offset]:
|
||||
if is_fields_table(t) and (extra := _parse_fields_table(t, fmt_name, enum_names)) and not has_encoding(extra):
|
||||
for ef in extra:
|
||||
if ef[0] not in field_names:
|
||||
fields.append(ef)
|
||||
field_names.add(ef[0])
|
||||
break
|
||||
formats[fmt_name] = fields
|
||||
|
||||
# fix known PDF errors
|
||||
if 'SMEM' in formats:
|
||||
formats['SMEM'] = [(n, 13 if n == 'DLC' else 14 if n == 'GLC' else h, 13 if n == 'DLC' else 14 if n == 'GLC' else l, e, t)
|
||||
for n, h, l, e, t in formats['SMEM']]
|
||||
|
||||
return {"formats": formats, "enums": enums, "src_enum": src_enum, "doc_name": doc_name, "is_cdna": is_cdna}
|
||||
|
||||
def _merge_results(results: list[dict]) -> dict:
|
||||
"""Merge multiple PDF parse results into a superset. Asserts if any conflicts."""
|
||||
merged = {"formats": {}, "enums": {}, "src_enum": dict(SRC_EXTRAS), "doc_names": [], "is_cdna": False}
|
||||
for r in results:
|
||||
merged["doc_names"].append(r["doc_name"])
|
||||
merged["is_cdna"] = merged["is_cdna"] or r["is_cdna"]
|
||||
# Merge src_enum (union, assert no conflicts)
|
||||
for val, name in r["src_enum"].items():
|
||||
if val in merged["src_enum"]:
|
||||
assert merged["src_enum"][val] == name, f"SrcEnum conflict: {val} = {merged['src_enum'][val]} vs {name}"
|
||||
else:
|
||||
merged["src_enum"][val] = name
|
||||
# Merge enums (union of ops per enum, assert no conflicts)
|
||||
for enum_name, ops in r["enums"].items():
|
||||
if enum_name not in merged["enums"]: merged["enums"][enum_name] = {}
|
||||
for val, name in ops.items():
|
||||
if val in merged["enums"][enum_name]:
|
||||
assert merged["enums"][enum_name][val] == name, f"{enum_name} conflict: {val} = {merged['enums'][enum_name][val]} vs {name}"
|
||||
else:
|
||||
merged["enums"][enum_name][val] = name
|
||||
# Merge formats (union of fields, assert no bit position conflicts for same field name)
|
||||
for fmt_name, fields in r["formats"].items():
|
||||
if fmt_name not in merged["formats"]:
|
||||
merged["formats"][fmt_name] = list(fields)
|
||||
else:
|
||||
existing = {f[0]: (f[1], f[2]) for f in merged["formats"][fmt_name]} # name -> (hi, lo)
|
||||
for f in fields:
|
||||
name, hi, lo = f[0], f[1], f[2]
|
||||
if name in existing:
|
||||
assert existing[name] == (hi, lo), f"Format {fmt_name} field {name} conflict: bits {existing[name]} vs ({hi}, {lo})"
|
||||
else:
|
||||
merged["formats"][fmt_name].append(f)
|
||||
return merged
|
||||
|
||||
def generate(output_path: str | None = None, arch: str = "rdna3") -> dict:
|
||||
"""Generate instruction definitions from AMD ISA PDF(s). Returns dict with formats for testing."""
|
||||
urls = PDF_URLS[arch]
|
||||
if isinstance(urls, str): urls = [urls]
|
||||
|
||||
# Parse all PDFs and merge
|
||||
results = [_parse_single_pdf(url) for url in urls]
|
||||
if len(results) == 1:
|
||||
merged = results[0]
|
||||
doc_name = merged["doc_name"]
|
||||
else:
|
||||
merged = _merge_results(results)
|
||||
doc_name = "+".join(merged["doc_names"])
|
||||
|
||||
formats, enums, src_enum = merged["formats"], merged["enums"], merged["src_enum"]
|
||||
|
||||
# generate output
|
||||
def enum_lines(name, items):
|
||||
return [f"class {name}(IntEnum):"] + [f" {n} = {v}" for v, n in sorted(items.items())] + [""]
|
||||
def field_key(f): return order.index(f[0].lower()) if f[0].lower() in order else 1000
|
||||
lines = [f"# autogenerated from AMD {doc_name} ISA PDF by dsl.py - do not edit", "from enum import IntEnum",
|
||||
"from typing import Annotated",
|
||||
"from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField",
|
||||
"import functools", ""]
|
||||
lines += enum_lines("SrcEnum", src_enum) + sum([enum_lines(n, ops) for n, ops in sorted(enums.items())], [])
|
||||
# Format-specific field defaults (verified against LLVM test vectors)
|
||||
format_defaults = {'VOP3P': {'opsel_hi': 3, 'opsel_hi2': 1}}
|
||||
lines.append("# instruction formats")
|
||||
for fmt_name, fields in sorted(formats.items()):
|
||||
base = "Inst64" if max(f[1] for f in fields) > 31 or fmt_name == 'VOP3SD' else "Inst32"
|
||||
order = FIELD_ORDER.get(fmt_name, [])
|
||||
lines.append(f"class {fmt_name}({base}):")
|
||||
if enc := next((f for f in fields if f[0] == 'ENCODING'), None):
|
||||
enc_str = f"bits[{enc[1]}:{enc[2]}] == 0b{enc[3]:b}" if enc[1] != enc[2] else f"bits[{enc[1]}] == {enc[3]}"
|
||||
lines.append(f" encoding = {enc_str}")
|
||||
if defaults := format_defaults.get(fmt_name):
|
||||
lines.append(f" _defaults = {defaults}")
|
||||
for name, hi, lo, _, ftype in sorted([f for f in fields if f[0] != 'ENCODING'], key=field_key):
|
||||
if ftype and ftype.endswith('Op'):
|
||||
ann = f":Annotated[BitField, {ftype}]"
|
||||
else:
|
||||
ann = f":{ftype}" if ftype else ""
|
||||
lines.append(f" {name.lower()}{ann} = bits[{hi}]" if hi == lo else f" {name.lower()}{ann} = bits[{hi}:{lo}]")
|
||||
lines.append("")
|
||||
lines.append("# instruction helpers")
|
||||
for cls_name, ops in sorted(enums.items()):
|
||||
fmt = cls_name[:-2]
|
||||
for op_val, name in sorted(ops.items()):
|
||||
seg = {"GLOBAL": ", seg=2", "SCRATCH": ", seg=2"}.get(fmt, "")
|
||||
tgt = {"GLOBAL": "FLAT, GLOBALOp", "SCRATCH": "FLAT, SCRATCHOp"}.get(fmt, f"{fmt}, {cls_name}")
|
||||
if fmt in formats or fmt in ("GLOBAL", "SCRATCH"):
|
||||
if fmt in ("VOP1", "VOP2", "VOPC"):
|
||||
suffix = "_e32"
|
||||
elif fmt == "VOP3" and op_val < 512:
|
||||
suffix = "_e64"
|
||||
else:
|
||||
suffix = ""
|
||||
if name in ('V_FMAMK_F32', 'V_FMAMK_F16'):
|
||||
lines.append(f"def {name.lower()}{suffix}(vdst, src0, K, vsrc1): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
|
||||
elif name in ('V_FMAAK_F32', 'V_FMAAK_F16'):
|
||||
lines.append(f"def {name.lower()}{suffix}(vdst, src0, vsrc1, K): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
|
||||
else:
|
||||
lines.append(f"{name.lower()}{suffix} = functools.partial({tgt}.{name}{seg})")
|
||||
skip_exports = {'DPP8', 'DPP16'}
|
||||
src_names = {name for _, name in src_enum.items()}
|
||||
lines += [""] + [f"{name} = SrcEnum.{name}" for _, name in sorted(src_enum.items()) if name not in skip_exports]
|
||||
if "NULL" in src_names: lines.append("OFF = NULL\n")
|
||||
|
||||
if output_path is not None:
|
||||
import pathlib
|
||||
pathlib.Path(output_path).write_text('\n'.join(lines))
|
||||
return {"formats": formats, "enums": enums, "src_enum": src_enum}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="Generate instruction definitions from AMD ISA PDF")
|
||||
parser.add_argument("--arch", choices=list(PDF_URLS.keys()) + ["all"], default="rdna3", help="Target architecture (default: rdna3)")
|
||||
args = parser.parse_args()
|
||||
if args.arch == "all":
|
||||
for arch in PDF_URLS.keys():
|
||||
result = generate(f"extra/assembly/amd/autogen/{arch}/__init__.py", arch=arch)
|
||||
print(f"{arch}: generated SrcEnum ({len(result['src_enum'])}) + {len(result['enums'])} opcode enums + {len(result['formats'])} format classes")
|
||||
else:
|
||||
result = generate(f"extra/assembly/amd/autogen/{args.arch}/__init__.py", arch=args.arch)
|
||||
print(f"generated SrcEnum ({len(result['src_enum'])}) + {len(result['enums'])} opcode enums + {len(result['formats'])} format classes")
|
||||
|
|
|
|||
772
extra/assembly/amd/generate.py
Normal file
772
extra/assembly/amd/generate.py
Normal file
|
|
@ -0,0 +1,772 @@
|
|||
# PDF parsing and code generation for AMD ISA
|
||||
# Generates both autogen/__init__.py (instruction formats) and gen_pcode.py (pseudocode functions)
|
||||
# Usage: python -m extra.assembly.amd.pdf --arch rdna3
|
||||
import re
|
||||
|
||||
PDF_URLS = {
|
||||
"rdna3": "https://docs.amd.com/api/khub/documents/UVVZM22UN7tMUeiW_4ShTQ/content", # RDNA3.5
|
||||
"rdna4": "https://docs.amd.com/api/khub/documents/uQpkEvk3pv~kfAb2x~j4uw/content",
|
||||
"cdna": ["https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/amd-instinct-mi300-cdna3-instruction-set-architecture.pdf",
|
||||
"https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/amd-instinct-cdna4-instruction-set-architecture.pdf"],
|
||||
}
|
||||
|
||||
FIELD_TYPES = {'SSRC0': 'SSrc', 'SSRC1': 'SSrc', 'SOFFSET': 'SSrc', 'SADDR': 'SSrc', 'SRC0': 'Src', 'SRC1': 'Src', 'SRC2': 'Src',
|
||||
'SDST': 'SGPRField', 'SBASE': 'SGPRField', 'SDATA': 'SGPRField', 'SRSRC': 'SGPRField', 'VDST': 'VGPRField', 'VSRC1': 'VGPRField', 'VDATA': 'VGPRField',
|
||||
'VADDR': 'VGPRField', 'ADDR': 'VGPRField', 'DATA': 'VGPRField', 'DATA0': 'VGPRField', 'DATA1': 'VGPRField', 'SIMM16': 'SImm', 'OFFSET': 'Imm',
|
||||
'OPX': 'VOPDOp', 'OPY': 'VOPDOp', 'SRCX0': 'Src', 'SRCY0': 'Src', 'VSRCX1': 'VGPRField', 'VSRCY1': 'VGPRField', 'VDSTX': 'VGPRField', 'VDSTY': 'VDSTYEnc'}
|
||||
FIELD_ORDER = {
|
||||
'SOP2': ['op', 'sdst', 'ssrc0', 'ssrc1'], 'SOP1': ['op', 'sdst', 'ssrc0'], 'SOPC': ['op', 'ssrc0', 'ssrc1'],
|
||||
'SOPK': ['op', 'sdst', 'simm16'], 'SOPP': ['op', 'simm16'], 'VOP1': ['op', 'vdst', 'src0'], 'VOPC': ['op', 'src0', 'vsrc1'],
|
||||
'VOP2': ['op', 'vdst', 'src0', 'vsrc1'], 'VOP3SD': ['op', 'vdst', 'sdst', 'src0', 'src1', 'src2', 'clmp'],
|
||||
'SMEM': ['op', 'sdata', 'sbase', 'soffset', 'offset', 'glc', 'dlc'], 'DS': ['op', 'vdst', 'addr', 'data0', 'data1'],
|
||||
'VOP3': ['op', 'vdst', 'src0', 'src1', 'src2', 'omod', 'neg', 'abs', 'clmp', 'opsel'],
|
||||
'VOP3P': ['op', 'vdst', 'src0', 'src1', 'src2', 'neg', 'neg_hi', 'opsel', 'opsel_hi', 'clmp'],
|
||||
'FLAT': ['op', 'vdst', 'addr', 'data', 'saddr', 'offset', 'seg', 'dlc', 'glc', 'slc'],
|
||||
'MUBUF': ['op', 'vdata', 'vaddr', 'srsrc', 'soffset', 'offset', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe'],
|
||||
'MTBUF': ['op', 'vdata', 'vaddr', 'srsrc', 'soffset', 'offset', 'format', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe'],
|
||||
'MIMG': ['op', 'vdata', 'vaddr', 'srsrc', 'ssamp', 'dmask', 'dim', 'unrm', 'dlc', 'glc', 'slc'],
|
||||
'EXP': ['en', 'target', 'vsrc0', 'vsrc1', 'vsrc2', 'vsrc3', 'done', 'row'],
|
||||
'VINTERP': ['op', 'vdst', 'src0', 'src1', 'src2', 'waitexp', 'clmp', 'opsel', 'neg'],
|
||||
'VOPD': ['opx', 'opy', 'vdstx', 'vdsty', 'srcx0', 'vsrcx1', 'srcy0', 'vsrcy1'],
|
||||
'LDSDIR': ['op', 'vdst', 'attr', 'attr_chan', 'wait_va']}
|
||||
SRC_EXTRAS = {233: 'DPP8', 234: 'DPP8FI', 250: 'DPP16', 251: 'VCCZ', 252: 'EXECZ', 254: 'LDS_DIRECT'}
|
||||
FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'NEG_ONE', '2.0': 'POS_TWO', '-2.0': 'NEG_TWO',
|
||||
'4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
|
||||
INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# SHARED PDF PARSING INFRASTRUCTURE
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class _LazyPageCache:
|
||||
"""Lazy page text/table extractor with caching to avoid redundant PDF parsing."""
|
||||
__slots__ = ('_pdf', '_offset', '_text_cache', '_table_cache')
|
||||
def __init__(self, pdf, offset: int = 0):
|
||||
self._pdf, self._offset, self._text_cache, self._table_cache = pdf, offset, {}, {}
|
||||
def text(self, idx: int) -> str:
|
||||
if idx not in self._text_cache: self._text_cache[idx] = self._pdf.pages[self._offset + idx].extract_text() or ''
|
||||
return self._text_cache[idx]
|
||||
def tables(self, idx: int) -> list:
|
||||
if idx not in self._table_cache: self._table_cache[idx] = [t.extract() for t in self._pdf.pages[self._offset + idx].find_tables()]
|
||||
return self._table_cache[idx]
|
||||
def texts_range(self, start: int, end: int) -> list[str]: return [self.text(i) for i in range(start, end)]
|
||||
|
||||
def _detect_doc_type(first_page_text: str) -> tuple[bool, str]:
|
||||
"""Detect document type from first page text. Returns (is_cdna, doc_name)."""
|
||||
is_cdna4 = 'CDNA4' in first_page_text or 'CDNA 4' in first_page_text
|
||||
is_cdna3 = 'CDNA3' in first_page_text or 'CDNA 3' in first_page_text or 'MI300' in first_page_text
|
||||
is_cdna = is_cdna3 or is_cdna4
|
||||
is_rdna4 = 'RDNA4' in first_page_text or 'RDNA 4' in first_page_text
|
||||
is_rdna35 = 'RDNA3.5' in first_page_text or 'RDNA 3.5' in first_page_text
|
||||
is_rdna3 = not is_rdna35 and ('RDNA3' in first_page_text or 'RDNA 3' in first_page_text)
|
||||
doc_name = "CDNA4" if is_cdna4 else "CDNA3" if is_cdna3 else "RDNA4" if is_rdna4 else "RDNA3.5" if is_rdna35 else "RDNA3" if is_rdna3 else "Unknown"
|
||||
return is_cdna, doc_name
|
||||
|
||||
def _find_chapter(cache: _LazyPageCache, total_pages: int, pattern: str, sample_pcts: list[float]) -> int | None:
|
||||
"""Find chapter page by sampling at likely positions then searching nearby. Returns page index or None."""
|
||||
for pct in sample_pcts:
|
||||
idx = int(total_pages * pct)
|
||||
if 0 <= idx < total_pages and re.search(pattern, cache.text(idx)): return idx
|
||||
for pct in sample_pcts:
|
||||
base = int(total_pages * pct)
|
||||
for offset in range(-10, 11):
|
||||
idx = base + offset
|
||||
if 0 <= idx < total_pages and re.search(pattern, cache.text(idx)): return idx
|
||||
return None
|
||||
|
||||
def _parse_bits(s: str) -> tuple[int, int] | None:
|
||||
return (int(m.group(1)), int(m.group(2) or m.group(1))) if (m := re.match(r'\[(\d+)(?::(\d+))?\]', s)) else None
|
||||
|
||||
def _parse_fields_table(table: list, fmt: str, enums: set[str]) -> list[tuple]:
|
||||
fields = []
|
||||
for row in table[1:]:
|
||||
if not row or not row[0]: continue
|
||||
name, bits_str = row[0].split('\n')[0].strip(), (row[1] or '').split('\n')[0].strip()
|
||||
if not (bits := _parse_bits(bits_str)): continue
|
||||
enc_val, hi, lo = None, bits[0], bits[1]
|
||||
if name == 'ENCODING' and row[2]:
|
||||
if m := re.search(r"(?:'b|Must be:\s*)([01_]+)", row[2]):
|
||||
enc_bits = m.group(1).replace('_', '')
|
||||
enc_val = int(enc_bits, 2)
|
||||
declared_width, actual_width = hi - lo + 1, len(enc_bits)
|
||||
if actual_width > declared_width: lo = hi - actual_width + 1
|
||||
ftype = f"{fmt}Op" if name == 'OP' and f"{fmt}Op" in enums else FIELD_TYPES.get(name.upper())
|
||||
fields.append((name, hi, lo, enc_val, ftype))
|
||||
return fields
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# PARSE SINGLE PDF - extracts both format definitions AND pseudocode
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def _parse_single_pdf(url: str) -> dict:
|
||||
"""Parse a single PDF and return raw data for both dsl and pcode generation."""
|
||||
import pdfplumber
|
||||
from tinygrad.helpers import fetch
|
||||
|
||||
pdf = pdfplumber.open(fetch(url))
|
||||
total_pages = len(pdf.pages)
|
||||
cache = _LazyPageCache(pdf)
|
||||
|
||||
# Auto-detect document type from first page
|
||||
is_cdna, doc_name = _detect_doc_type(cache.text(0))
|
||||
|
||||
# Find chapter positions using sampling
|
||||
instr_pattern = r'Chapter \d+\.\s+Instructions\b'
|
||||
microcode_pattern = r'Chapter \d+\.\s+Microcode Formats'
|
||||
|
||||
def is_microcode_page(text):
|
||||
return (re.search(r'\d+\.\d+\.\d+\.\s+SOP2\s*\n.*Description', text) or
|
||||
re.search(r'Chapter \d+\.\s+Microcode Formats\s*\n.*This section', text))
|
||||
|
||||
# Find microcode section with sampling (RDNA ~23%, CDNA ~93%)
|
||||
microcode_start = None
|
||||
for pct in [0.233, 0.93, 0.24, 0.92, 0.25, 0.22]:
|
||||
sample = int(total_pages * pct)
|
||||
if is_microcode_page(cache.text(sample)):
|
||||
microcode_start = sample
|
||||
while microcode_start > 0 and is_microcode_page(cache.text(microcode_start - 1)):
|
||||
microcode_start -= 1
|
||||
break
|
||||
if microcode_start is None:
|
||||
for i in range(int(total_pages * 0.15), total_pages):
|
||||
if is_microcode_page(cache.text(i)):
|
||||
microcode_start = i
|
||||
break
|
||||
if microcode_start is None: microcode_start = int(total_pages * 0.9)
|
||||
|
||||
# Find instructions chapter for pseudocode extraction
|
||||
if is_cdna:
|
||||
instr_start = _find_chapter(cache, total_pages, instr_pattern, [0.17, 0.18, 0.16, 0.19, 0.15])
|
||||
instr_end = _find_chapter(cache, total_pages, microcode_pattern, [0.93, 0.92, 0.94, 0.91, 0.95])
|
||||
else:
|
||||
instr_start = _find_chapter(cache, total_pages, instr_pattern, [0.30, 0.31, 0.29, 0.32, 0.28])
|
||||
instr_end = None # RDNA: Instructions goes to end
|
||||
if instr_start is None: instr_start = int(total_pages * (0.17 if is_cdna else 0.30))
|
||||
if instr_end is None: instr_end = total_pages
|
||||
|
||||
# ─── Parse format definitions from Microcode Formats chapter ───
|
||||
fmt_cache = _LazyPageCache(pdf, microcode_start)
|
||||
page_count = min(45, total_pages - microcode_start)
|
||||
for idx, text in cache._text_cache.items():
|
||||
if microcode_start <= idx < microcode_start + page_count: fmt_cache._text_cache[idx - microcode_start] = text
|
||||
|
||||
# Parse SSRC encoding
|
||||
src_enum = dict(SRC_EXTRAS)
|
||||
for i in range(2, 12):
|
||||
text = fmt_cache.text(i)
|
||||
if 'SSRC0' in text and 'VCC_LO' in text:
|
||||
for m in re.finditer(r'^(\d+)\s+(\S+)', text, re.M):
|
||||
val, name = int(m.group(1)), m.group(2).rstrip('.:')
|
||||
if name in FLOAT_MAP: src_enum[val] = FLOAT_MAP[name]
|
||||
elif re.match(r'^[A-Z][A-Z0-9_]*$', name): src_enum[val] = name
|
||||
break
|
||||
|
||||
# Parse opcode tables
|
||||
full_text = '\n'.join(fmt_cache.texts_range(2, page_count))
|
||||
enums: dict[str, dict[int, str]] = {}
|
||||
for m in re.finditer(r'Table \d+\. (\w+) Opcodes(.*?)(?=Table \d+\.|\n\d+\.\d+\.\d+\.\s+\w+\s*\nDescription|$)', full_text, re.S):
|
||||
if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+([A-Z][A-Z0-9_]+)', m.group(2))}:
|
||||
enums[m.group(1) + "Op"] = ops
|
||||
if vopd_m := re.search(r'Table \d+\. VOPD Y-Opcodes\n(.*?)(?=Table \d+\.|15\.\d)', full_text, re.S):
|
||||
if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+(V_DUAL_\w+)', vopd_m.group(1))}:
|
||||
enums["VOPDOp"] = ops
|
||||
enum_names = set(enums.keys())
|
||||
|
||||
# Parse format field tables
|
||||
def is_fields_table(t) -> bool: return t and len(t) > 1 and t[0] and 'Field' in str(t[0][0] or '')
|
||||
def has_encoding(fields) -> bool: return any(f[0] == 'ENCODING' for f in fields)
|
||||
def has_header_before_fields(text) -> bool:
|
||||
return (pos := text.find('Field Name')) != -1 and bool(re.search(r'\d+\.\d+\.\d+\.\s+\w+\s*\n', text[:pos]))
|
||||
|
||||
# Find format headers with their page indices (CDNA has formats from page 1, RDNA from page 2)
|
||||
format_headers = []
|
||||
for i in range(1 if is_cdna else 2, page_count):
|
||||
text = fmt_cache.text(i)
|
||||
for m in re.finditer(r'\d+\.\d+\.\d+\.\s+(\w+)\s*\n?Description', text): format_headers.append((m.group(1), i, m.start()))
|
||||
for m in re.finditer(r'\d+\.\d+\.\d+\.\s+(\w+)\s*\n', text):
|
||||
fmt_name = m.group(1)
|
||||
if is_cdna and fmt_name.isupper() and len(fmt_name) >= 2:
|
||||
format_headers.append((fmt_name, i, m.start()))
|
||||
elif m.start() > len(text) - 200 and 'Description' not in text[m.end():] and i + 1 < page_count:
|
||||
next_text = fmt_cache.text(i + 1).lstrip()
|
||||
if next_text.startswith('Description') or (next_text.startswith('"RDNA') and 'Description' in next_text[:200]):
|
||||
format_headers.append((fmt_name, i, m.start()))
|
||||
|
||||
# Parse instruction formats
|
||||
formats: dict[str, list[tuple]] = {}
|
||||
for fmt_name, page_idx, header_pos in format_headers:
|
||||
if fmt_name in formats: continue
|
||||
text, tables = fmt_cache.text(page_idx), fmt_cache.tables(page_idx)
|
||||
field_pos = text.find('Field Name', header_pos)
|
||||
|
||||
fields = None
|
||||
for offset in range(3):
|
||||
if page_idx + offset >= page_count: break
|
||||
if offset > 0 and has_header_before_fields(fmt_cache.text(page_idx + offset)): break
|
||||
for t in fmt_cache.tables(page_idx + offset) if offset > 0 or field_pos > header_pos else []:
|
||||
if is_fields_table(t) and (f := _parse_fields_table(t, fmt_name, enum_names)) and has_encoding(f):
|
||||
fields = f
|
||||
break
|
||||
if fields: break
|
||||
|
||||
if not fields and field_pos > header_pos:
|
||||
for t in tables:
|
||||
if is_fields_table(t) and (f := _parse_fields_table(t, fmt_name, enum_names)):
|
||||
fields = f
|
||||
break
|
||||
|
||||
if not fields: continue
|
||||
field_names = {f[0] for f in fields}
|
||||
|
||||
# Look for continuation tables on subsequent pages
|
||||
for pg_offset in range(1, 3):
|
||||
if page_idx + pg_offset >= page_count or has_header_before_fields(fmt_cache.text(page_idx + pg_offset)): break
|
||||
for t in fmt_cache.tables(page_idx + pg_offset):
|
||||
if is_fields_table(t) and (extra := _parse_fields_table(t, fmt_name, enum_names)) and not has_encoding(extra):
|
||||
for ef in extra:
|
||||
if ef[0] not in field_names:
|
||||
fields.append(ef)
|
||||
field_names.add(ef[0])
|
||||
break
|
||||
formats[fmt_name] = fields
|
||||
|
||||
# Fix known PDF errors
|
||||
if 'SMEM' in formats:
|
||||
formats['SMEM'] = [(n, 13 if n == 'DLC' else 14 if n == 'GLC' else h, 13 if n == 'DLC' else 14 if n == 'GLC' else l, e, t)
|
||||
for n, h, l, e, t in formats['SMEM']]
|
||||
|
||||
# ─── Parse pseudocode from Instructions chapter ───
|
||||
instr_text = '\n'.join(cache.text(i) for i in range(instr_start, instr_end))
|
||||
|
||||
return {"formats": formats, "enums": enums, "src_enum": src_enum, "doc_name": doc_name, "is_cdna": is_cdna,
|
||||
"instr_text": instr_text}
|
||||
|
||||
def _merge_results(results: list[dict]) -> dict:
|
||||
"""Merge results from multiple PDFs (e.g., CDNA3 + CDNA4)."""
|
||||
merged = {"formats": {}, "enums": {}, "src_enum": {}, "doc_names": [], "instr_texts": []}
|
||||
for r in results:
|
||||
merged["doc_names"].append(r["doc_name"])
|
||||
merged["instr_texts"].append(r["instr_text"])
|
||||
for val, name in r["src_enum"].items():
|
||||
if val in merged["src_enum"]:
|
||||
assert merged["src_enum"][val] == name, f"SrcEnum conflict: {val} = {merged['src_enum'][val]} vs {name}"
|
||||
else: merged["src_enum"][val] = name
|
||||
for enum_name, ops in r["enums"].items():
|
||||
if enum_name not in merged["enums"]: merged["enums"][enum_name] = {}
|
||||
for val, name in ops.items():
|
||||
if val in merged["enums"][enum_name]:
|
||||
assert merged["enums"][enum_name][val] == name, f"{enum_name} conflict: {val} = {merged['enums'][enum_name][val]} vs {name}"
|
||||
else: merged["enums"][enum_name][val] = name
|
||||
for fmt_name, fields in r["formats"].items():
|
||||
if fmt_name not in merged["formats"]: merged["formats"][fmt_name] = list(fields)
|
||||
else:
|
||||
existing = {f[0]: (f[1], f[2]) for f in merged["formats"][fmt_name]}
|
||||
for f in fields:
|
||||
name, hi, lo = f[0], f[1], f[2]
|
||||
if name in existing:
|
||||
assert existing[name] == (hi, lo), f"Format {fmt_name} field {name} conflict: bits {existing[name]} vs ({hi}, {lo})"
|
||||
else: merged["formats"][fmt_name].append(f)
|
||||
return merged
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# GENERATE __init__.py (instruction formats and enums)
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def _generate_dsl(merged: dict, doc_name: str, output_path: str | None = None) -> str:
|
||||
"""Generate instruction definitions code."""
|
||||
formats, enums, src_enum = merged["formats"], merged["enums"], merged["src_enum"]
|
||||
|
||||
def enum_lines(name, items):
|
||||
return [f"class {name}(IntEnum):"] + [f" {n} = {v}" for v, n in sorted(items.items())] + [""]
|
||||
def field_key(f):
|
||||
order = FIELD_ORDER.get(fmt_name, [])
|
||||
return order.index(f[0].lower()) if f[0].lower() in order else 1000
|
||||
|
||||
lines = [f"# autogenerated from AMD {doc_name} ISA PDF by generate.py - do not edit", "from enum import IntEnum",
|
||||
"from typing import Annotated",
|
||||
"from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField",
|
||||
"import functools", ""]
|
||||
lines += enum_lines("SrcEnum", src_enum) + sum([enum_lines(n, ops) for n, ops in sorted(enums.items())], [])
|
||||
|
||||
format_defaults = {'VOP3P': {'opsel_hi': 3, 'opsel_hi2': 1}}
|
||||
lines.append("# instruction formats")
|
||||
for fmt_name, fields in sorted(formats.items()):
|
||||
base = "Inst64" if max(f[1] for f in fields) > 31 or fmt_name == 'VOP3SD' else "Inst32"
|
||||
lines.append(f"class {fmt_name}({base}):")
|
||||
if enc := next((f for f in fields if f[0] == 'ENCODING'), None):
|
||||
enc_str = f"bits[{enc[1]}:{enc[2]}] == 0b{enc[3]:b}" if enc[1] != enc[2] else f"bits[{enc[1]}] == {enc[3]}"
|
||||
lines.append(f" encoding = {enc_str}")
|
||||
if defaults := format_defaults.get(fmt_name):
|
||||
lines.append(f" _defaults = {defaults}")
|
||||
for name, hi, lo, _, ftype in sorted([f for f in fields if f[0] != 'ENCODING'], key=field_key):
|
||||
ann = f":Annotated[BitField, {ftype}]" if ftype and ftype.endswith('Op') else f":{ftype}" if ftype else ""
|
||||
lines.append(f" {name.lower()}{ann} = bits[{hi}]" if hi == lo else f" {name.lower()}{ann} = bits[{hi}:{lo}]")
|
||||
lines.append("")
|
||||
|
||||
lines.append("# instruction helpers")
|
||||
for cls_name, ops in sorted(enums.items()):
|
||||
fmt = cls_name[:-2]
|
||||
for op_val, name in sorted(ops.items()):
|
||||
seg = {"GLOBAL": ", seg=2", "SCRATCH": ", seg=2"}.get(fmt, "")
|
||||
tgt = {"GLOBAL": "FLAT, GLOBALOp", "SCRATCH": "FLAT, SCRATCHOp"}.get(fmt, f"{fmt}, {cls_name}")
|
||||
if fmt in formats or fmt in ("GLOBAL", "SCRATCH"):
|
||||
suffix = "_e32" if fmt in ("VOP1", "VOP2", "VOPC") else "_e64" if fmt == "VOP3" and op_val < 512 else ""
|
||||
if name in ('V_FMAMK_F32', 'V_FMAMK_F16'):
|
||||
lines.append(f"def {name.lower()}{suffix}(vdst, src0, K, vsrc1): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
|
||||
elif name in ('V_FMAAK_F32', 'V_FMAAK_F16'):
|
||||
lines.append(f"def {name.lower()}{suffix}(vdst, src0, vsrc1, K): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
|
||||
else:
|
||||
lines.append(f"{name.lower()}{suffix} = functools.partial({tgt}.{name}{seg})")
|
||||
|
||||
skip_exports = {'DPP8', 'DPP16'}
|
||||
src_names = {name for _, name in src_enum.items()}
|
||||
lines += [""] + [f"{name} = SrcEnum.{name}" for _, name in sorted(src_enum.items()) if name not in skip_exports]
|
||||
if "NULL" in src_names: lines.append("OFF = NULL\n")
|
||||
|
||||
content = '\n'.join(lines)
|
||||
if output_path is not None:
|
||||
import pathlib
|
||||
pathlib.Path(output_path).write_text(content)
|
||||
return content
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# PSEUDOCODE COMPILER: pseudocode -> Python
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def _expr(e: str) -> str:
|
||||
"""Expression transform: minimal - just fix syntax differences."""
|
||||
e = e.strip()
|
||||
e = e.replace('&&', ' and ').replace('||', ' or ').replace('<>', ' != ')
|
||||
e = re.sub(r'!([^=])', r' not \1', e)
|
||||
|
||||
# Pack: { hi, lo } -> _pack(hi, lo)
|
||||
e = re.sub(r'\{\s*(\w+\.u32)\s*,\s*(\w+\.u32)\s*\}', r'_pack32(\1, \2)', e)
|
||||
def pack(m):
|
||||
hi, lo = _expr(m[1].strip()), _expr(m[2].strip())
|
||||
return f'_pack({hi}, {lo})'
|
||||
e = re.sub(r'\{\s*([^,{}]+)\s*,\s*([^,{}]+)\s*\}', pack, e)
|
||||
|
||||
# Special constant: 1201'B(2.0 / PI) -> TWO_OVER_PI_1201 (precomputed 1201-bit 2/pi)
|
||||
e = re.sub(r"1201'B\(2\.0\s*/\s*PI\)", "TWO_OVER_PI_1201", e)
|
||||
|
||||
# Literals: 1'0U -> 0, 32'I(x) -> (x), B(x) -> (x)
|
||||
e = re.sub(r"\d+'([0-9a-fA-Fx]+)[UuFf]*", r'\1', e)
|
||||
e = re.sub(r"\d+'[FIBU]\(", "(", e)
|
||||
e = re.sub(r'\bB\(', '(', e)
|
||||
e = re.sub(r'([0-9a-fA-Fx])ULL\b', r'\1', e)
|
||||
e = re.sub(r'([0-9a-fA-Fx])LL\b', r'\1', e)
|
||||
e = re.sub(r'([0-9a-fA-Fx])U\b', r'\1', e)
|
||||
e = re.sub(r'(\d\.?\d*)F\b', r'\1', e)
|
||||
e = re.sub(r'(\[laneId\])\.[uib]\d+', r'\1', e)
|
||||
|
||||
# Constants
|
||||
e = e.replace('+INF', 'INF').replace('-INF', '(-INF)')
|
||||
e = re.sub(r'NAN\.f\d+', 'float("nan")', e)
|
||||
|
||||
# Verilog bit slice: [start +: width] -> [start + width - 1 : start]
|
||||
def convert_verilog_slice(m):
|
||||
start, width = m.group(1).strip(), m.group(2).strip()
|
||||
return f'[({start}) + ({width}) - 1 : ({start})]'
|
||||
e = re.sub(r'\[([^:\[\]]+)\s*\+:\s*([^:\[\]]+)\]', convert_verilog_slice, e)
|
||||
|
||||
# Recursively process bracket contents
|
||||
def process_brackets(s):
|
||||
result, i = [], 0
|
||||
while i < len(s):
|
||||
if s[i] == '[':
|
||||
depth, start = 1, i + 1
|
||||
j = start
|
||||
while j < len(s) and depth > 0:
|
||||
if s[j] == '[': depth += 1
|
||||
elif s[j] == ']': depth -= 1
|
||||
j += 1
|
||||
inner = _expr(s[start:j-1])
|
||||
result.append('[' + inner + ']')
|
||||
i = j
|
||||
else:
|
||||
result.append(s[i])
|
||||
i += 1
|
||||
return ''.join(result)
|
||||
e = process_brackets(e)
|
||||
|
||||
# Ternary: a ? b : c -> (b if a else c)
|
||||
while '?' in e:
|
||||
depth, bracket, q = 0, 0, -1
|
||||
for i, c in enumerate(e):
|
||||
if c == '(': depth += 1
|
||||
elif c == ')': depth -= 1
|
||||
elif c == '[': bracket += 1
|
||||
elif c == ']': bracket -= 1
|
||||
elif c == '?' and depth == 0 and bracket == 0: q = i; break
|
||||
if q < 0: break
|
||||
depth, bracket, col = 0, 0, -1
|
||||
for i in range(q + 1, len(e)):
|
||||
if e[i] == '(': depth += 1
|
||||
elif e[i] == ')': depth -= 1
|
||||
elif e[i] == '[': bracket += 1
|
||||
elif e[i] == ']': bracket -= 1
|
||||
elif e[i] == ':' and depth == 0 and bracket == 0: col = i; break
|
||||
if col < 0: break
|
||||
cond, t, f = e[:q].strip(), e[q+1:col].strip(), e[col+1:].strip()
|
||||
e = f'(({t}) if ({cond}) else ({f}))'
|
||||
return e
|
||||
|
||||
def _assign(lhs: str, rhs: str) -> str:
|
||||
"""Generate assignment. Bare tmp/SCC/etc get wrapped in Reg()."""
|
||||
if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec', 'PC'):
|
||||
return f"{lhs} = Reg({rhs})"
|
||||
return f"{lhs} = {rhs}"
|
||||
|
||||
def compile_pseudocode(pseudocode: str) -> str:
|
||||
"""Compile pseudocode to Python. Transforms are minimal - most syntax just works."""
|
||||
raw_lines = pseudocode.strip().split('\n')
|
||||
joined_lines: list[str] = []
|
||||
for line in raw_lines:
|
||||
line = line.strip()
|
||||
if joined_lines and (joined_lines[-1].rstrip().endswith(('||', '&&', '(', ',')) or
|
||||
(joined_lines[-1].count('(') > joined_lines[-1].count(')'))):
|
||||
joined_lines[-1] = joined_lines[-1].rstrip() + ' ' + line
|
||||
else:
|
||||
joined_lines.append(line)
|
||||
|
||||
lines = []
|
||||
indent, need_pass, in_first_match_loop = 0, False, False
|
||||
for line in joined_lines:
|
||||
line = line.strip()
|
||||
if not line or line.startswith('//'): continue
|
||||
|
||||
if line.startswith('if '):
|
||||
lines.append(' ' * indent + f"if {_expr(line[3:].rstrip(' then'))}:")
|
||||
indent += 1
|
||||
need_pass = True
|
||||
elif line.startswith('elsif '):
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
indent -= 1
|
||||
lines.append(' ' * indent + f"elif {_expr(line[6:].rstrip(' then'))}:")
|
||||
indent += 1
|
||||
need_pass = True
|
||||
elif line == 'else':
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
indent -= 1
|
||||
lines.append(' ' * indent + "else:")
|
||||
indent += 1
|
||||
need_pass = True
|
||||
elif line.startswith('endif'):
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
indent -= 1
|
||||
need_pass = False
|
||||
elif line.startswith('endfor'):
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
indent -= 1
|
||||
need_pass, in_first_match_loop = False, False
|
||||
elif line.startswith('declare '):
|
||||
pass
|
||||
elif m := re.match(r'for (\w+) in (.+?)\s*:\s*(.+?) do', line):
|
||||
start, end = _expr(m[2].strip()), _expr(m[3].strip())
|
||||
lines.append(' ' * indent + f"for {m[1]} in range({start}, int({end})+1):")
|
||||
indent += 1
|
||||
need_pass, in_first_match_loop = True, True
|
||||
elif '=' in line and not line.startswith('=='):
|
||||
need_pass = False
|
||||
line = line.rstrip(';')
|
||||
if m := re.match(r'\{\s*D1\.[ui]1\s*,\s*D0\.[ui]64\s*\}\s*=\s*(.+)', line):
|
||||
rhs = _expr(m[1])
|
||||
lines.append(' ' * indent + f"_full = {rhs}")
|
||||
lines.append(' ' * indent + f"D0.u64 = int(_full) & 0xffffffffffffffff")
|
||||
lines.append(' ' * indent + f"D1 = Reg((int(_full) >> 64) & 1)")
|
||||
elif any(op in line for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^=')):
|
||||
for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^='):
|
||||
if op in line:
|
||||
lhs, rhs = line.split(op, 1)
|
||||
lines.append(' ' * indent + f"{lhs.strip()} {op} {_expr(rhs.strip())}")
|
||||
break
|
||||
else:
|
||||
lhs, rhs = line.split('=', 1)
|
||||
lhs_s, rhs_s = lhs.strip(), rhs.strip()
|
||||
stmt = _assign(lhs_s, _expr(rhs_s))
|
||||
if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'):
|
||||
stmt += "; break"
|
||||
lines.append(' ' * indent + stmt)
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
return '\n'.join(lines)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# GENERATE gen_pcode.py (compiled pseudocode functions)
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def _extract_pseudocode(text: str) -> str | None:
|
||||
"""Extract pseudocode from an instruction description snippet."""
|
||||
lines, result, depth, in_lambda = text.split('\n'), [], 0, 0
|
||||
for line in lines:
|
||||
s = line.strip()
|
||||
if not s: continue
|
||||
if re.match(r'^\d+ of \d+$', s): continue
|
||||
if re.match(r'^\d+\.\d+\..*Instructions', s): continue
|
||||
if s.startswith('"RDNA') or s.startswith('AMD ') or s.startswith('CDNA'): continue
|
||||
if s.startswith('Notes') or s.startswith('Functional examples'): break
|
||||
if '= lambda(' in s: in_lambda += 1; continue
|
||||
if in_lambda > 0:
|
||||
if s.endswith(');'): in_lambda -= 1
|
||||
continue
|
||||
if s.startswith('if '): depth += 1
|
||||
elif s.startswith('endif'): depth = max(0, depth - 1)
|
||||
if s.endswith('.') and not any(p in s for p in ['D0', 'D1', 'S0', 'S1', 'S2', 'SCC', 'VCC', 'tmp', '=']): continue
|
||||
if re.match(r'^[a-z].*\.$', s) and '=' not in s: continue
|
||||
is_code = (
|
||||
any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =', 'PC =']) or
|
||||
any(p in s for p in ['D0[', 'D1[', 'S0[', 'S1[', 'S2[']) or
|
||||
s.startswith(('if ', 'else', 'elsif', 'endif', 'declare ', 'for ', 'endfor', '//')) or
|
||||
re.match(r'^[a-z_]+\s*=', s) or re.match(r'^[a-z_]+\[', s) or (depth > 0 and '=' in s)
|
||||
)
|
||||
if is_code: result.append(s)
|
||||
return '\n'.join(result) if result else None
|
||||
|
||||
def _parse_pseudocode(instr_texts: list[str], defined_ops: dict, OP_ENUMS: list) -> dict:
|
||||
"""Parse pseudocode from instruction text(s). Returns {enum_cls: {op: pseudocode}}."""
|
||||
instructions: dict = {cls: {} for cls in OP_ENUMS}
|
||||
|
||||
# Process in reverse order so newer PDFs take priority
|
||||
for instr_text in reversed(instr_texts):
|
||||
matches = list(INST_PATTERN.finditer(instr_text))
|
||||
for i, match in enumerate(matches):
|
||||
name, opcode = match.group(1), int(match.group(2))
|
||||
key = (name, opcode)
|
||||
if key not in defined_ops: continue
|
||||
start = match.end()
|
||||
end = matches[i + 1].start() if i + 1 < len(matches) else start + 2000
|
||||
snippet = instr_text[start:end].strip()
|
||||
if (pseudocode := _extract_pseudocode(snippet)):
|
||||
for enum_cls, enum_val in defined_ops[key]:
|
||||
if enum_val not in instructions[enum_cls]:
|
||||
instructions[enum_cls][enum_val] = pseudocode
|
||||
|
||||
return instructions
|
||||
|
||||
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS', 'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
|
||||
'CVT_OFF_TABLE', 'ThreadMask', 'S1[i', 'C.i32', 'S[i]', 'in[', 'if n.', 'DST.u32', 'addrd = DST', 'addr = DST']
|
||||
|
||||
def _generate_pcode(instr_texts: list[str], arch: str, output_path: str | None = None) -> tuple[int, int]:
|
||||
"""Generate compiled pseudocode functions. Returns (compiled_count, skipped_count)."""
|
||||
import importlib
|
||||
|
||||
# Load op enums from autogen module
|
||||
autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}")
|
||||
OP_ENUMS = []
|
||||
for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp']:
|
||||
if hasattr(autogen, name): OP_ENUMS.append(getattr(autogen, name))
|
||||
|
||||
# Build defined_ops mapping
|
||||
defined_ops: dict[tuple, list] = {}
|
||||
for enum_cls in OP_ENUMS:
|
||||
for op in enum_cls:
|
||||
if op.name.startswith(('S_', 'V_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
|
||||
|
||||
by_cls = _parse_pseudocode(instr_texts, defined_ops, OP_ENUMS)
|
||||
|
||||
# Report coverage
|
||||
total_found, total_ops = 0, 0
|
||||
for enum_cls in OP_ENUMS:
|
||||
total = sum(1 for op in enum_cls if op.name.startswith(('S_', 'V_')))
|
||||
found = len(by_cls.get(enum_cls, {}))
|
||||
total_found += found
|
||||
total_ops += total
|
||||
print(f"{enum_cls.__name__}: {found}/{total} ({100*found//total if total else 0}%)")
|
||||
print(f"Total: {total_found}/{total_ops} ({100*total_found//total_ops}%)")
|
||||
|
||||
# Generate code
|
||||
enum_names = [e.__name__ for e in OP_ENUMS]
|
||||
lines = [f'''# autogenerated by generate.py - do not edit
|
||||
# to regenerate: python -m extra.assembly.amd.generate --arch {arch}
|
||||
# ruff: noqa: E501,F405,F403
|
||||
# mypy: ignore-errors
|
||||
from extra.assembly.amd.autogen.{arch} import {", ".join(enum_names)}
|
||||
from extra.assembly.amd.pcode import *
|
||||
''']
|
||||
|
||||
compiled_count, skipped_count = 0, 0
|
||||
|
||||
for enum_cls in OP_ENUMS:
|
||||
cls_name = enum_cls.__name__
|
||||
pseudocode_dict = by_cls.get(enum_cls, {})
|
||||
if not pseudocode_dict: continue
|
||||
|
||||
fn_entries = []
|
||||
for op, pc in pseudocode_dict.items():
|
||||
if any(p in pc for p in UNSUPPORTED):
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
code = compile_pseudocode(pc)
|
||||
# Hardware behavior fixes (see pcode.py for detailed comments)
|
||||
if op.name == 'V_DIV_FMAS_F32':
|
||||
code = code.replace('D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)',
|
||||
'D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)')
|
||||
if op.name == 'V_DIV_FMAS_F64':
|
||||
code = code.replace('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
|
||||
'D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64)')
|
||||
if op.name == 'V_DIV_SCALE_F32':
|
||||
code = code.replace('D0.f32 = float("nan")', 'VCC = Reg(0x1); D0.f32 = float("nan")')
|
||||
code = code.replace('elif S1.f32 == DENORM.f32:\n D0.f32 = ldexp(S0.f32, 64)', 'elif False:\n pass # denorm check moved to end')
|
||||
code += '\nif S1.f32 == DENORM.f32:\n D0.f32 = float("nan")'
|
||||
code = code.replace('elif exponent(S2.f32) <= 23:\n D0.f32 = ldexp(S0.f32, 64)',
|
||||
'elif exponent(S2.f32) <= 23:\n VCC = Reg(0x1); D0.f32 = ldexp(S0.f32, 64)')
|
||||
code = code.replace('elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)\n if S0.f32 == S2.f32:\n D0.f32 = ldexp(S0.f32, 64)',
|
||||
'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)')
|
||||
if op.name == 'V_DIV_SCALE_F64':
|
||||
code = code.replace('D0.f64 = float("nan")', 'VCC = Reg(0x1); D0.f64 = float("nan")')
|
||||
code = code.replace('elif S1.f64 == DENORM.f64:\n D0.f64 = ldexp(S0.f64, 128)', 'elif False:\n pass # denorm check moved to end')
|
||||
code += '\nif S1.f64 == DENORM.f64:\n D0.f64 = float("nan")'
|
||||
code = code.replace('elif exponent(S2.f64) <= 52:\n D0.f64 = ldexp(S0.f64, 128)',
|
||||
'elif exponent(S2.f64) <= 52:\n VCC = Reg(0x1); D0.f64 = ldexp(S0.f64, 128)')
|
||||
code = code.replace('elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)\n if S0.f64 == S2.f64:\n D0.f64 = ldexp(S0.f64, 128)',
|
||||
'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)')
|
||||
if op.name == 'V_DIV_FIXUP_F32':
|
||||
code = code.replace('D0.f32 = ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))',
|
||||
'D0.f32 = ((-OVERFLOW_F32) if (sign_out) else (OVERFLOW_F32)) if isNAN(S0.f32) else ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))')
|
||||
if op.name == 'V_DIV_FIXUP_F64':
|
||||
code = code.replace('D0.f64 = ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))',
|
||||
'D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) if isNAN(S0.f64) else ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))')
|
||||
if op.name == 'V_TRIG_PREOP_F64':
|
||||
code = code.replace('result = F((TWO_OVER_PI_1201[1200 : 0] << shift.u32) & 0x1fffffffffffff)',
|
||||
'result = float(((TWO_OVER_PI_1201[1200 : 0] << int(shift)) >> (1201 - 53)) & 0x1fffffffffffff)')
|
||||
|
||||
# Detect flags for result handling
|
||||
is_64 = any(p in pc for p in ['D0.u64', 'D0.b64', 'D0.f64', 'D0.i64', 'D1.u64', 'D1.b64', 'D1.f64', 'D1.i64'])
|
||||
has_d1 = '{ D1' in pc
|
||||
if has_d1: is_64 = True
|
||||
is_cmp = (cls_name == 'VOPCOp' or cls_name == 'VOP3Op') and 'D0.u64[laneId]' in pc
|
||||
is_cmpx = (cls_name == 'VOPCOp' or cls_name == 'VOP3Op') and 'EXEC.u64[laneId]' in pc
|
||||
is_div_scale = 'DIV_SCALE' in op.name
|
||||
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
|
||||
has_pc = 'PC' in pc
|
||||
|
||||
fn_name = f"_{cls_name}_{op.name}"
|
||||
lines.append(f"def {fn_name}(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0, pc=0):")
|
||||
for pc_line in pc.split('\n'):
|
||||
lines.append(f" # {pc_line}")
|
||||
|
||||
combined = code + pc
|
||||
regs = [('S0', 'Reg(s0)'), ('S1', 'Reg(s1)'), ('S2', 'Reg(s2)'),
|
||||
('D0', 'Reg(s0)' if is_div_scale else 'Reg(d0)'), ('D1', 'Reg(0)'),
|
||||
('SCC', 'Reg(scc)'), ('VCC', 'Reg(vcc)'), ('EXEC', 'Reg(exec_mask)'),
|
||||
('tmp', 'Reg(0)'), ('saveexec', 'Reg(exec_mask)'), ('laneId', 'lane'),
|
||||
('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
|
||||
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)'), ('PC', 'Reg(pc)')]
|
||||
used = {name for name, _ in regs if name in combined}
|
||||
if 'EXEC_LO' in combined or 'EXEC_HI' in combined: used.add('EXEC')
|
||||
if 'VCCZ' in combined: used.add('VCC')
|
||||
if 'EXECZ' in combined: used.add('EXEC')
|
||||
for name, init in regs:
|
||||
if name in used: lines.append(f" {name} = {init}")
|
||||
if 'EXEC_LO' in combined: lines.append(" EXEC_LO = SliceProxy(EXEC, 31, 0)")
|
||||
if 'EXEC_HI' in combined: lines.append(" EXEC_HI = SliceProxy(EXEC, 63, 32)")
|
||||
if 'VCCZ' in combined: lines.append(" VCCZ = Reg(1 if VCC._val == 0 else 0)")
|
||||
if 'EXECZ' in combined: lines.append(" EXECZ = Reg(1 if EXEC._val == 0 else 0)")
|
||||
|
||||
lines.append(" # --- compiled pseudocode ---")
|
||||
for line in code.split('\n'):
|
||||
lines.append(f" {line}")
|
||||
lines.append(" # --- end pseudocode ---")
|
||||
|
||||
d0_val = "D0._val" if 'D0' in used else "d0"
|
||||
scc_val = "SCC._val & 1" if 'SCC' in used else "scc & 1"
|
||||
lines.append(f" result = {{'d0': {d0_val}, 'scc': {scc_val}}}")
|
||||
if has_sdst:
|
||||
lines.append(" result['vcc_lane'] = (VCC._val >> lane) & 1")
|
||||
elif 'VCC' in used:
|
||||
lines.append(" if VCC._val != vcc: result['vcc_lane'] = (VCC._val >> lane) & 1")
|
||||
if is_cmpx:
|
||||
lines.append(" result['exec_lane'] = (EXEC._val >> lane) & 1")
|
||||
elif 'EXEC' in used:
|
||||
lines.append(" if EXEC._val != exec_mask: result['exec'] = EXEC._val")
|
||||
if is_cmp:
|
||||
lines.append(" result['vcc_lane'] = (D0._val >> lane) & 1")
|
||||
if is_64:
|
||||
lines.append(" result['d0_64'] = True")
|
||||
if has_d1:
|
||||
lines.append(" result['d1'] = D1._val & 1")
|
||||
if has_pc:
|
||||
lines.append(" _pc = PC._val if PC._val < 0x8000000000000000 else PC._val - 0x10000000000000000")
|
||||
lines.append(" result['new_pc'] = _pc # absolute byte address")
|
||||
lines.append(" return result")
|
||||
lines.append("")
|
||||
|
||||
fn_entries.append((op, fn_name))
|
||||
compiled_count += 1
|
||||
except Exception as e:
|
||||
print(f" Warning: Failed to compile {op.name}: {e}")
|
||||
skipped_count += 1
|
||||
|
||||
if fn_entries:
|
||||
lines.append(f'{cls_name}_FUNCTIONS = {{')
|
||||
for op, fn_name in fn_entries:
|
||||
lines.append(f" {cls_name}.{op.name}: {fn_name},")
|
||||
lines.append('}')
|
||||
lines.append('')
|
||||
|
||||
# Add V_WRITELANE_B32 for RDNA
|
||||
if 'VOP3Op' in enum_names:
|
||||
lines.append('''
|
||||
# V_WRITELANE_B32: Write scalar to specific lane's VGPR (not in PDF pseudocode)
|
||||
def _VOP3Op_V_WRITELANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
|
||||
wr_lane = s1 & 0x1f # lane select (5 bits for wave32)
|
||||
return {'d0': d0, 'scc': scc, 'vgpr_write': (wr_lane, vdst_idx, s0 & 0xffffffff)}
|
||||
VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32
|
||||
''')
|
||||
|
||||
lines.append('COMPILED_FUNCTIONS = {')
|
||||
for enum_cls in OP_ENUMS:
|
||||
cls_name = enum_cls.__name__
|
||||
if by_cls.get(enum_cls): lines.append(f' {cls_name}: {cls_name}_FUNCTIONS,')
|
||||
lines.append('}')
|
||||
lines.append('')
|
||||
lines.append('def get_compiled_functions(): return COMPILED_FUNCTIONS')
|
||||
|
||||
if output_path is not None:
|
||||
from pathlib import Path
|
||||
Path(output_path).write_text('\n'.join(lines))
|
||||
return compiled_count, skipped_count
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# MAIN ENTRY POINT
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def generate(arch: str = "rdna3", output_dir: str = "extra/assembly/amd/autogen"):
|
||||
"""Generate both __init__.py and gen_pcode.py for the given architecture."""
|
||||
urls = PDF_URLS[arch]
|
||||
if isinstance(urls, str): urls = [urls]
|
||||
|
||||
print(f"Parsing PDF(s) for {arch}...")
|
||||
results = [_parse_single_pdf(url) for url in urls]
|
||||
|
||||
if len(results) == 1:
|
||||
merged = results[0]
|
||||
doc_name = merged["doc_name"]
|
||||
instr_texts = [merged["instr_text"]]
|
||||
else:
|
||||
merged = _merge_results(results)
|
||||
doc_name = "+".join(merged["doc_names"])
|
||||
instr_texts = merged["instr_texts"]
|
||||
|
||||
# Generate __init__.py first (needed for pcode generation)
|
||||
init_path = f"{output_dir}/{arch}/__init__.py"
|
||||
_generate_dsl(merged, doc_name, init_path)
|
||||
print(f"Generated {init_path}: SrcEnum ({len(merged['src_enum'])}) + {len(merged['enums'])} opcode enums + {len(merged['formats'])} format classes")
|
||||
|
||||
# Generate gen_pcode.py
|
||||
print("\nCompiling pseudocode functions...")
|
||||
pcode_path = f"{output_dir}/{arch}/gen_pcode.py"
|
||||
compiled, skipped = _generate_pcode(instr_texts, arch, pcode_path)
|
||||
print(f"Generated {pcode_path}: {compiled} compiled, {skipped} skipped")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse, subprocess, sys
|
||||
parser = argparse.ArgumentParser(description="Generate AMD ISA definitions from PDF")
|
||||
parser.add_argument("--arch", choices=list(PDF_URLS.keys()) + ["all"], default="rdna3", help="Target architecture")
|
||||
args = parser.parse_args()
|
||||
if args.arch == "all":
|
||||
procs = [subprocess.Popen([sys.executable, "-m", "extra.assembly.amd.generate", "--arch", arch]) for arch in PDF_URLS.keys()]
|
||||
for p in procs: p.wait()
|
||||
else:
|
||||
generate(arch=args.arch)
|
||||
|
|
@ -295,7 +295,7 @@ def signext_from_bit(val, bit):
|
|||
|
||||
__all__ = [
|
||||
# Classes
|
||||
'Reg', 'SliceProxy', 'TypedView', 'ExecContext', 'compile_pseudocode',
|
||||
'Reg', 'SliceProxy', 'TypedView',
|
||||
# Pack functions
|
||||
'_pack', '_pack32', 'pack', 'pack32',
|
||||
# Constants
|
||||
|
|
@ -623,588 +623,3 @@ class Reg:
|
|||
def __ge__(s, o): return s._val >= int(o)
|
||||
def __eq__(s, o): return s._val == int(o)
|
||||
def __ne__(s, o): return s._val != int(o)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# COMPILER: pseudocode -> Python (minimal transforms)
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def compile_pseudocode(pseudocode: str) -> str:
|
||||
"""Compile pseudocode to Python. Transforms are minimal - most syntax just works."""
|
||||
# Join continuation lines (lines ending with || or && or open paren)
|
||||
raw_lines = pseudocode.strip().split('\n')
|
||||
joined_lines: list[str] = []
|
||||
for line in raw_lines:
|
||||
line = line.strip()
|
||||
if joined_lines and (joined_lines[-1].rstrip().endswith(('||', '&&', '(', ',')) or
|
||||
(joined_lines[-1].count('(') > joined_lines[-1].count(')'))):
|
||||
joined_lines[-1] = joined_lines[-1].rstrip() + ' ' + line
|
||||
else:
|
||||
joined_lines.append(line)
|
||||
|
||||
lines = []
|
||||
indent, need_pass, in_first_match_loop = 0, False, False
|
||||
for line in joined_lines:
|
||||
line = line.strip()
|
||||
if not line or line.startswith('//'): continue
|
||||
|
||||
# Control flow - only need pass before outdent (endif/endfor/else/elsif)
|
||||
if line.startswith('if '):
|
||||
lines.append(' ' * indent + f"if {_expr(line[3:].rstrip(' then'))}:")
|
||||
indent += 1
|
||||
need_pass = True
|
||||
elif line.startswith('elsif '):
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
indent -= 1
|
||||
lines.append(' ' * indent + f"elif {_expr(line[6:].rstrip(' then'))}:")
|
||||
indent += 1
|
||||
need_pass = True
|
||||
elif line == 'else':
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
indent -= 1
|
||||
lines.append(' ' * indent + "else:")
|
||||
indent += 1
|
||||
need_pass = True
|
||||
elif line.startswith('endif'):
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
indent -= 1
|
||||
need_pass = False
|
||||
elif line.startswith('endfor'):
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
indent -= 1
|
||||
need_pass, in_first_match_loop = False, False
|
||||
elif line.startswith('declare '):
|
||||
pass
|
||||
elif m := re.match(r'for (\w+) in (.+?)\s*:\s*(.+?) do', line):
|
||||
start, end = _expr(m[2].strip()), _expr(m[3].strip())
|
||||
lines.append(' ' * indent + f"for {m[1]} in range({start}, int({end})+1):")
|
||||
indent += 1
|
||||
need_pass, in_first_match_loop = True, True
|
||||
elif '=' in line and not line.startswith('=='):
|
||||
need_pass = False
|
||||
line = line.rstrip(';')
|
||||
# Handle tuple unpacking: { D1.u1, D0.u64 } = expr
|
||||
if m := re.match(r'\{\s*D1\.[ui]1\s*,\s*D0\.[ui]64\s*\}\s*=\s*(.+)', line):
|
||||
rhs = _expr(m[1])
|
||||
lines.append(' ' * indent + f"_full = {rhs}")
|
||||
lines.append(' ' * indent + f"D0.u64 = int(_full) & 0xffffffffffffffff")
|
||||
lines.append(' ' * indent + f"D1 = Reg((int(_full) >> 64) & 1)")
|
||||
# Compound assignment
|
||||
elif any(op in line for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^=')):
|
||||
for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^='):
|
||||
if op in line:
|
||||
lhs, rhs = line.split(op, 1)
|
||||
lines.append(' ' * indent + f"{lhs.strip()} {op} {_expr(rhs.strip())}")
|
||||
break
|
||||
else:
|
||||
lhs, rhs = line.split('=', 1)
|
||||
lhs_s, rhs_s = lhs.strip(), rhs.strip()
|
||||
stmt = _assign(lhs_s, _expr(rhs_s))
|
||||
# CLZ/CTZ pattern: assignment of loop var to tmp/D0.i32 in first-match loop needs break
|
||||
if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'):
|
||||
stmt += "; break"
|
||||
lines.append(' ' * indent + stmt)
|
||||
# If we ended with a control statement that needs a body, add pass
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
return '\n'.join(lines)
|
||||
|
||||
def _assign(lhs: str, rhs: str) -> str:
|
||||
"""Generate assignment. Bare tmp/SCC/etc get wrapped in Reg()."""
|
||||
if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec', 'PC'):
|
||||
return f"{lhs} = Reg({rhs})"
|
||||
return f"{lhs} = {rhs}"
|
||||
|
||||
def _expr(e: str) -> str:
|
||||
"""Expression transform: minimal - just fix syntax differences."""
|
||||
e = e.strip()
|
||||
e = e.replace('&&', ' and ').replace('||', ' or ').replace('<>', ' != ')
|
||||
e = re.sub(r'!([^=])', r' not \1', e)
|
||||
|
||||
# Pack: { hi, lo } -> _pack(hi, lo)
|
||||
e = re.sub(r'\{\s*(\w+\.u32)\s*,\s*(\w+\.u32)\s*\}', r'_pack32(\1, \2)', e)
|
||||
def pack(m):
|
||||
hi, lo = _expr(m[1].strip()), _expr(m[2].strip())
|
||||
return f'_pack({hi}, {lo})'
|
||||
e = re.sub(r'\{\s*([^,{}]+)\s*,\s*([^,{}]+)\s*\}', pack, e)
|
||||
|
||||
# Special constant: 1201'B(2.0 / PI) -> TWO_OVER_PI_1201 (precomputed 1201-bit 2/pi)
|
||||
e = re.sub(r"1201'B\(2\.0\s*/\s*PI\)", "TWO_OVER_PI_1201", e)
|
||||
|
||||
# Literals: 1'0U -> 0, 32'I(x) -> (x), B(x) -> (x)
|
||||
e = re.sub(r"\d+'([0-9a-fA-Fx]+)[UuFf]*", r'\1', e)
|
||||
e = re.sub(r"\d+'[FIBU]\(", "(", e)
|
||||
e = re.sub(r'\bB\(', '(', e) # Bare B( without digit prefix
|
||||
e = re.sub(r'([0-9a-fA-Fx])ULL\b', r'\1', e)
|
||||
e = re.sub(r'([0-9a-fA-Fx])LL\b', r'\1', e)
|
||||
e = re.sub(r'([0-9a-fA-Fx])U\b', r'\1', e)
|
||||
e = re.sub(r'(\d\.?\d*)F\b', r'\1', e)
|
||||
# Remove redundant type suffix after lane access: VCC.u64[laneId].u64 -> VCC.u64[laneId]
|
||||
e = re.sub(r'(\[laneId\])\.[uib]\d+', r'\1', e)
|
||||
|
||||
# Constants - INF is defined as an object supporting .f32/.f64 access
|
||||
e = e.replace('+INF', 'INF').replace('-INF', '(-INF)')
|
||||
e = re.sub(r'NAN\.f\d+', 'float("nan")', e)
|
||||
|
||||
# Verilog bit slice syntax: [start +: width] -> extract width bits starting at start
|
||||
# Convert to Python slice: [start + width - 1 : start]
|
||||
def convert_verilog_slice(m):
|
||||
start, width = m.group(1).strip(), m.group(2).strip()
|
||||
# Convert to high:low slice format
|
||||
return f'[({start}) + ({width}) - 1 : ({start})]'
|
||||
e = re.sub(r'\[([^:\[\]]+)\s*\+:\s*([^:\[\]]+)\]', convert_verilog_slice, e)
|
||||
|
||||
# Recursively process bracket contents to handle nested ternaries like S1.u32[x ? a : b]
|
||||
def process_brackets(s):
|
||||
result, i = [], 0
|
||||
while i < len(s):
|
||||
if s[i] == '[':
|
||||
# Find matching ]
|
||||
depth, start = 1, i + 1
|
||||
j = start
|
||||
while j < len(s) and depth > 0:
|
||||
if s[j] == '[': depth += 1
|
||||
elif s[j] == ']': depth -= 1
|
||||
j += 1
|
||||
inner = _expr(s[start:j-1]) # Recursively process bracket content
|
||||
result.append('[' + inner + ']')
|
||||
i = j
|
||||
else:
|
||||
result.append(s[i])
|
||||
i += 1
|
||||
return ''.join(result)
|
||||
e = process_brackets(e)
|
||||
|
||||
# Ternary: a ? b : c -> (b if a else c)
|
||||
while '?' in e:
|
||||
depth, bracket, q = 0, 0, -1
|
||||
for i, c in enumerate(e):
|
||||
if c == '(': depth += 1
|
||||
elif c == ')': depth -= 1
|
||||
elif c == '[': bracket += 1
|
||||
elif c == ']': bracket -= 1
|
||||
elif c == '?' and depth == 0 and bracket == 0: q = i; break
|
||||
if q < 0: break
|
||||
depth, bracket, col = 0, 0, -1
|
||||
for i in range(q + 1, len(e)):
|
||||
if e[i] == '(': depth += 1
|
||||
elif e[i] == ')': depth -= 1
|
||||
elif e[i] == '[': bracket += 1
|
||||
elif e[i] == ']': bracket -= 1
|
||||
elif e[i] == ':' and depth == 0 and bracket == 0: col = i; break
|
||||
if col < 0: break
|
||||
cond, t, f = e[:q].strip(), e[q+1:col].strip(), e[col+1:].strip()
|
||||
e = f'(({t}) if ({cond}) else ({f}))'
|
||||
return e
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# EXECUTION CONTEXT
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class ExecContext:
|
||||
"""Context for running compiled pseudocode."""
|
||||
def __init__(self, s0=0, s1=0, s2=0, d0=0, scc=0, vcc=0, lane=0, exec_mask=MASK32, literal=0, vgprs=None, src0_idx=0, vdst_idx=0):
|
||||
self.S0, self.S1, self.S2 = Reg(s0), Reg(s1), Reg(s2)
|
||||
self.D0, self.D1 = Reg(d0), Reg(0)
|
||||
self.SCC, self.VCC, self.EXEC = Reg(scc), Reg(vcc), Reg(exec_mask)
|
||||
self.tmp, self.saveexec = Reg(0), Reg(exec_mask)
|
||||
self.lane, self.laneId, self.literal = lane, lane, literal
|
||||
self.SIMM16, self.SIMM32 = Reg(literal), Reg(literal)
|
||||
self.VGPR = vgprs if vgprs is not None else {}
|
||||
self.SRC0, self.VDST = Reg(src0_idx), Reg(vdst_idx)
|
||||
|
||||
def run(self, code: str):
|
||||
"""Execute compiled code."""
|
||||
# Start with module globals (helpers, aliases), then add instance-specific bindings
|
||||
ns = dict(globals())
|
||||
ns.update({
|
||||
'S0': self.S0, 'S1': self.S1, 'S2': self.S2, 'D0': self.D0, 'D1': self.D1,
|
||||
'SCC': self.SCC, 'VCC': self.VCC, 'EXEC': self.EXEC,
|
||||
'EXEC_LO': SliceProxy(self.EXEC, 31, 0), 'EXEC_HI': SliceProxy(self.EXEC, 63, 32),
|
||||
'tmp': self.tmp, 'saveexec': self.saveexec,
|
||||
'lane': self.lane, 'laneId': self.laneId, 'literal': self.literal,
|
||||
'SIMM16': self.SIMM16, 'SIMM32': self.SIMM32,
|
||||
'VGPR': self.VGPR, 'SRC0': self.SRC0, 'VDST': self.VDST,
|
||||
})
|
||||
exec(code, ns)
|
||||
# Sync rebinds: if register was reassigned to new Reg or value, copy it back
|
||||
def _sync(ctx_reg, ns_val):
|
||||
if isinstance(ns_val, Reg): ctx_reg._val = ns_val._val
|
||||
else: ctx_reg._val = int(ns_val) & MASK64
|
||||
if ns.get('SCC') is not self.SCC: _sync(self.SCC, ns['SCC'])
|
||||
if ns.get('VCC') is not self.VCC: _sync(self.VCC, ns['VCC'])
|
||||
if ns.get('EXEC') is not self.EXEC: _sync(self.EXEC, ns['EXEC'])
|
||||
if ns.get('D0') is not self.D0: _sync(self.D0, ns['D0'])
|
||||
if ns.get('D1') is not self.D1: _sync(self.D1, ns['D1'])
|
||||
if ns.get('tmp') is not self.tmp: _sync(self.tmp, ns['tmp'])
|
||||
if ns.get('saveexec') is not self.saveexec: _sync(self.saveexec, ns['saveexec'])
|
||||
|
||||
def result(self) -> dict:
|
||||
return {"d0": self.D0._val, "scc": self.SCC._val & 1}
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# PDF EXTRACTION AND CODE GENERATION
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
from extra.assembly.amd.dsl import PDF_URLS
|
||||
INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
|
||||
|
||||
# Patterns that can't be handled by the DSL (require special handling in emu.py)
|
||||
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
|
||||
'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
|
||||
'CVT_OFF_TABLE', 'ThreadMask',
|
||||
'S1[i', 'C.i32', 'S[i]', 'in[',
|
||||
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST'] # Malformed pseudocode from PDF
|
||||
|
||||
def extract_pseudocode(text: str) -> str | None:
|
||||
"""Extract pseudocode from an instruction description snippet."""
|
||||
lines, result, depth, in_lambda = text.split('\n'), [], 0, 0
|
||||
for line in lines:
|
||||
s = line.strip()
|
||||
if not s: continue
|
||||
if re.match(r'^\d+ of \d+$', s): continue
|
||||
if re.match(r'^\d+\.\d+\..*Instructions', s): continue
|
||||
# Skip document headers (RDNA or CDNA)
|
||||
if s.startswith('"RDNA') or s.startswith('AMD ') or s.startswith('CDNA'): continue
|
||||
if s.startswith('Notes') or s.startswith('Functional examples'): break
|
||||
# Track lambda definitions (e.g., BYTE_PERMUTE = lambda(data, sel) (...))
|
||||
if '= lambda(' in s: in_lambda += 1; continue
|
||||
if in_lambda > 0:
|
||||
if s.endswith(');'): in_lambda -= 1
|
||||
continue
|
||||
if s.startswith('if '): depth += 1
|
||||
elif s.startswith('endif'): depth = max(0, depth - 1)
|
||||
if s.endswith('.') and not any(p in s for p in ['D0', 'D1', 'S0', 'S1', 'S2', 'SCC', 'VCC', 'tmp', '=']): continue
|
||||
if re.match(r'^[a-z].*\.$', s) and '=' not in s: continue
|
||||
is_code = (
|
||||
any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =', 'PC =']) or
|
||||
any(p in s for p in ['D0[', 'D1[', 'S0[', 'S1[', 'S2[']) or
|
||||
s.startswith(('if ', 'else', 'elsif', 'endif', 'declare ', 'for ', 'endfor', '//')) or
|
||||
re.match(r'^[a-z_]+\s*=', s) or re.match(r'^[a-z_]+\[', s) or (depth > 0 and '=' in s)
|
||||
)
|
||||
if is_code: result.append(s)
|
||||
return '\n'.join(result) if result else None
|
||||
|
||||
def _get_op_enums(arch: str) -> list:
|
||||
"""Dynamically load op enums from the arch-specific autogen module."""
|
||||
import importlib
|
||||
autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}")
|
||||
# Deterministic order: common enums first, then arch-specific
|
||||
enums = []
|
||||
for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp']:
|
||||
if hasattr(autogen, name): enums.append(getattr(autogen, name))
|
||||
return enums
|
||||
|
||||
def _parse_pseudocode_from_single_pdf(url: str, defined_ops: dict, OP_ENUMS: list) -> dict:
|
||||
"""Parse pseudocode from a single PDF."""
|
||||
import pdfplumber
|
||||
from tinygrad.helpers import fetch
|
||||
|
||||
pdf = pdfplumber.open(fetch(url))
|
||||
total_pages = len(pdf.pages)
|
||||
|
||||
page_cache = {}
|
||||
def get_page_text(i):
|
||||
if i not in page_cache: page_cache[i] = pdf.pages[i].extract_text() or ''
|
||||
return page_cache[i]
|
||||
|
||||
# Find the "Instructions" chapter - typically 10-40% through the document
|
||||
instr_start = None
|
||||
for i in range(int(total_pages * 0.1), int(total_pages * 0.5)):
|
||||
if re.search(r'Chapter \d+\.\s+Instructions\b', get_page_text(i)):
|
||||
instr_start = i
|
||||
break
|
||||
if instr_start is None: instr_start = total_pages // 3 # fallback
|
||||
|
||||
# Find end - stop at "Microcode Formats" chapter (typically 60-70% through)
|
||||
instr_end = total_pages
|
||||
search_starts = [int(total_pages * 0.6), int(total_pages * 0.5), instr_start]
|
||||
for start in search_starts:
|
||||
for i in range(start, min(start + 100, total_pages)):
|
||||
if re.search(r'Chapter \d+\.\s+Microcode Formats', get_page_text(i)):
|
||||
instr_end = i
|
||||
break
|
||||
if instr_end < total_pages: break
|
||||
|
||||
# Extract remaining pages (some already cached from chapter search)
|
||||
all_text = '\n'.join(get_page_text(i) for i in range(instr_start, instr_end))
|
||||
matches = list(INST_PATTERN.finditer(all_text))
|
||||
instructions: dict = {cls: {} for cls in OP_ENUMS}
|
||||
|
||||
for i, match in enumerate(matches):
|
||||
name, opcode = match.group(1), int(match.group(2))
|
||||
key = (name, opcode)
|
||||
if key not in defined_ops: continue
|
||||
start = match.end()
|
||||
end = matches[i + 1].start() if i + 1 < len(matches) else start + 2000
|
||||
snippet = all_text[start:end].strip()
|
||||
if (pseudocode := extract_pseudocode(snippet)):
|
||||
# Assign to all enums that have this op (e.g., both VOPCOp and VOP3AOp)
|
||||
for enum_cls, enum_val in defined_ops[key]:
|
||||
instructions[enum_cls][enum_val] = pseudocode
|
||||
|
||||
return instructions
|
||||
|
||||
def parse_pseudocode_from_pdf(arch: str = "rdna3") -> dict:
|
||||
"""Parse pseudocode from PDF(s) for all ops. Returns {enum_cls: {op: pseudocode}}."""
|
||||
OP_ENUMS = _get_op_enums(arch)
|
||||
# Build a dict from (name, opcode) -> list of (enum_cls, op) tuples
|
||||
# Multiple enums can have the same op (e.g., VOPCOp and VOP3AOp both have V_CMP_* ops)
|
||||
defined_ops: dict[tuple, list] = {}
|
||||
for enum_cls in OP_ENUMS:
|
||||
for op in enum_cls:
|
||||
if op.name.startswith(('S_', 'V_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
|
||||
|
||||
urls = PDF_URLS[arch]
|
||||
if isinstance(urls, str): urls = [urls]
|
||||
|
||||
# Parse all PDFs and merge (union of pseudocode)
|
||||
# Reverse order so newer PDFs (RDNA3.5, CDNA4) take priority
|
||||
instructions: dict = {cls: {} for cls in OP_ENUMS}
|
||||
for url in reversed(urls):
|
||||
result = _parse_pseudocode_from_single_pdf(url, defined_ops, OP_ENUMS)
|
||||
for cls, ops in result.items():
|
||||
for op, pseudocode in ops.items():
|
||||
if op in instructions[cls]:
|
||||
if instructions[cls][op] != pseudocode:
|
||||
print(f" Ignoring {op.name} from older PDF:")
|
||||
print(f" new: {instructions[cls][op]!r}")
|
||||
print(f" old: {pseudocode!r}")
|
||||
else:
|
||||
instructions[cls][op] = pseudocode
|
||||
|
||||
return instructions
|
||||
|
||||
def generate_gen_pcode(output_path: str = "extra/assembly/amd/autogen/rdna3/gen_pcode.py", arch: str = "rdna3"):
|
||||
"""Generate gen_pcode.py - compiled pseudocode functions for the emulator."""
|
||||
from pathlib import Path
|
||||
|
||||
OP_ENUMS = _get_op_enums(arch)
|
||||
|
||||
print("Parsing pseudocode from PDF...")
|
||||
by_cls = parse_pseudocode_from_pdf(arch)
|
||||
|
||||
total_found, total_ops = 0, 0
|
||||
for enum_cls in OP_ENUMS:
|
||||
total = sum(1 for op in enum_cls if op.name.startswith(('S_', 'V_')))
|
||||
found = len(by_cls.get(enum_cls, {}))
|
||||
total_found += found
|
||||
total_ops += total
|
||||
print(f"{enum_cls.__name__}: {found}/{total} ({100*found//total if total else 0}%)")
|
||||
print(f"Total: {total_found}/{total_ops} ({100*total_found//total_ops}%)")
|
||||
|
||||
print("\nCompiling to pseudocode functions...")
|
||||
# Build dynamic import line based on available enums
|
||||
enum_names = [e.__name__ for e in OP_ENUMS]
|
||||
lines = [f'''# autogenerated by pcode.py - do not edit
|
||||
# to regenerate: python -m extra.assembly.amd.pcode --arch {arch}
|
||||
# ruff: noqa: E501,F405,F403
|
||||
# mypy: ignore-errors
|
||||
from extra.assembly.amd.autogen.{arch} import {", ".join(enum_names)}
|
||||
from extra.assembly.amd.pcode import *
|
||||
''']
|
||||
|
||||
compiled_count, skipped_count = 0, 0
|
||||
|
||||
for enum_cls in OP_ENUMS:
|
||||
cls_name = enum_cls.__name__
|
||||
pseudocode_dict = by_cls.get(enum_cls, {})
|
||||
if not pseudocode_dict: continue
|
||||
|
||||
fn_entries = []
|
||||
for op, pc in pseudocode_dict.items():
|
||||
if any(p in pc for p in UNSUPPORTED):
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
code = compile_pseudocode(pc)
|
||||
# NOTE: Do NOT add more code.replace() hacks here. Fix issues properly in the DSL
|
||||
# (compile_pseudocode, helper functions, or Reg/TypedView classes) instead.
|
||||
# V_DIV_FMAS_F32/F64: PDF page 449 says 2^32/2^64 but hardware behavior is more complex.
|
||||
# The scale direction depends on S2 (the addend): if exponent(S2) > 127 (i.e., S2 >= 2.0),
|
||||
# scale by 2^+64 (to unscale a numerator that was scaled). Otherwise scale by 2^-64
|
||||
# (to unscale a denominator that was scaled).
|
||||
if op.name == 'V_DIV_FMAS_F32':
|
||||
code = code.replace(
|
||||
'D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)',
|
||||
'D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)')
|
||||
if op.name == 'V_DIV_FMAS_F64':
|
||||
code = code.replace(
|
||||
'D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
|
||||
'D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64)')
|
||||
# V_DIV_SCALE_F32/F64: PDF page 463-464 has several bugs vs hardware behavior:
|
||||
# 1. Zero case: hardware sets VCC=1 (PDF doesn't)
|
||||
# 2. Denorm denom: hardware returns NaN (PDF says scale). VCC is set independently by exp diff check.
|
||||
# 3. Tiny numer (exp<=23): hardware sets VCC=1 (PDF doesn't)
|
||||
# 4. Result would be denorm: hardware doesn't scale, just sets VCC=1
|
||||
if op.name == 'V_DIV_SCALE_F32':
|
||||
# Fix 1: Set VCC=1 when zero operands produce NaN
|
||||
code = code.replace(
|
||||
'D0.f32 = float("nan")',
|
||||
'VCC = Reg(0x1); D0.f32 = float("nan")')
|
||||
# Fix 2: Denorm denom returns NaN. Must check this AFTER all VCC-setting logic runs.
|
||||
# Insert at end of all branches, before the final result is used
|
||||
code = code.replace(
|
||||
'elif S1.f32 == DENORM.f32:\n D0.f32 = ldexp(S0.f32, 64)',
|
||||
'elif False:\n pass # denorm check moved to end')
|
||||
# Add denorm check at the very end - this overrides D0 but preserves VCC
|
||||
code += '\nif S1.f32 == DENORM.f32:\n D0.f32 = float("nan")'
|
||||
# Fix 3: Tiny numer should set VCC=1
|
||||
code = code.replace(
|
||||
'elif exponent(S2.f32) <= 23:\n D0.f32 = ldexp(S0.f32, 64)',
|
||||
'elif exponent(S2.f32) <= 23:\n VCC = Reg(0x1); D0.f32 = ldexp(S0.f32, 64)')
|
||||
# Fix 4: S2/S1 would be denorm - don't scale, just set VCC
|
||||
code = code.replace(
|
||||
'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)\n if S0.f32 == S2.f32:\n D0.f32 = ldexp(S0.f32, 64)',
|
||||
'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)')
|
||||
if op.name == 'V_DIV_SCALE_F64':
|
||||
# Same fixes for f64 version
|
||||
code = code.replace(
|
||||
'D0.f64 = float("nan")',
|
||||
'VCC = Reg(0x1); D0.f64 = float("nan")')
|
||||
code = code.replace(
|
||||
'elif S1.f64 == DENORM.f64:\n D0.f64 = ldexp(S0.f64, 128)',
|
||||
'elif False:\n pass # denorm check moved to end')
|
||||
code += '\nif S1.f64 == DENORM.f64:\n D0.f64 = float("nan")'
|
||||
code = code.replace(
|
||||
'elif exponent(S2.f64) <= 52:\n D0.f64 = ldexp(S0.f64, 128)',
|
||||
'elif exponent(S2.f64) <= 52:\n VCC = Reg(0x1); D0.f64 = ldexp(S0.f64, 128)')
|
||||
code = code.replace(
|
||||
'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)\n if S0.f64 == S2.f64:\n D0.f64 = ldexp(S0.f64, 128)',
|
||||
'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)')
|
||||
# V_DIV_FIXUP_F32/F64: PDF doesn't check isNAN(S0), but hardware returns OVERFLOW if S0 is NaN.
|
||||
# When division fails (e.g., due to denorm denom), S0 becomes NaN, and fixup should return ±inf.
|
||||
if op.name == 'V_DIV_FIXUP_F32':
|
||||
code = code.replace(
|
||||
'D0.f32 = ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))',
|
||||
'D0.f32 = ((-OVERFLOW_F32) if (sign_out) else (OVERFLOW_F32)) if isNAN(S0.f32) else ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))')
|
||||
if op.name == 'V_DIV_FIXUP_F64':
|
||||
code = code.replace(
|
||||
'D0.f64 = ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))',
|
||||
'D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) if isNAN(S0.f64) else ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))')
|
||||
# V_TRIG_PREOP_F64: AMD pseudocode uses (x << shift) & mask but mask needs to extract TOP bits.
|
||||
# The PDF shows: result = 64'F((1201'B(2.0/PI)[1200:0] << shift) & 1201'0x1fffffffffffff)
|
||||
# Issues to fix:
|
||||
# 1. After left shift, the interesting bits are at the top, not bottom - need >> (1201-53)
|
||||
# 2. shift.u32 fails because shift is a plain int after * 53 - use int(shift)
|
||||
# 3. 64'F(...) means convert int to float (not interpret as bit pattern) - use float()
|
||||
if op.name == 'V_TRIG_PREOP_F64':
|
||||
code = code.replace(
|
||||
'result = F((TWO_OVER_PI_1201[1200 : 0] << shift.u32) & 0x1fffffffffffff)',
|
||||
'result = float(((TWO_OVER_PI_1201[1200 : 0] << int(shift)) >> (1201 - 53)) & 0x1fffffffffffff)')
|
||||
# Detect flags for result handling
|
||||
is_64 = any(p in pc for p in ['D0.u64', 'D0.b64', 'D0.f64', 'D0.i64', 'D1.u64', 'D1.b64', 'D1.f64', 'D1.i64'])
|
||||
has_d1 = '{ D1' in pc
|
||||
if has_d1: is_64 = True
|
||||
is_cmp = (cls_name == 'VOPCOp' or cls_name == 'VOP3Op') and 'D0.u64[laneId]' in pc
|
||||
is_cmpx = (cls_name == 'VOPCOp' or cls_name == 'VOP3Op') and 'EXEC.u64[laneId]' in pc # V_CMPX writes to EXEC per-lane
|
||||
# V_DIV_SCALE passes through S0 if no branch taken
|
||||
is_div_scale = 'DIV_SCALE' in op.name
|
||||
# VOP3SD instructions that write VCC per-lane (either via VCC.u64[laneId] or by setting VCC = 0/1)
|
||||
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
|
||||
# Instructions that use/modify PC
|
||||
has_pc = 'PC' in pc
|
||||
|
||||
# Generate function with indented body
|
||||
fn_name = f"_{cls_name}_{op.name}"
|
||||
lines.append(f"def {fn_name}(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0, pc=0):")
|
||||
# Add original pseudocode as comment
|
||||
for pc_line in pc.split('\n'):
|
||||
lines.append(f" # {pc_line}")
|
||||
# Only create Reg objects for registers actually used in the pseudocode
|
||||
combined = code + pc
|
||||
regs = [('S0', 'Reg(s0)'), ('S1', 'Reg(s1)'), ('S2', 'Reg(s2)'),
|
||||
('D0', 'Reg(s0)' if is_div_scale else 'Reg(d0)'), ('D1', 'Reg(0)'),
|
||||
('SCC', 'Reg(scc)'), ('VCC', 'Reg(vcc)'), ('EXEC', 'Reg(exec_mask)'),
|
||||
('tmp', 'Reg(0)'), ('saveexec', 'Reg(exec_mask)'), ('laneId', 'lane'),
|
||||
('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
|
||||
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)'),
|
||||
('PC', 'Reg(pc)')] # PC is passed in as byte address
|
||||
used = {name for name, _ in regs if name in combined}
|
||||
# EXEC_LO/EXEC_HI need EXEC
|
||||
if 'EXEC_LO' in combined or 'EXEC_HI' in combined: used.add('EXEC')
|
||||
# VCCZ/EXECZ need VCC/EXEC
|
||||
if 'VCCZ' in combined: used.add('VCC')
|
||||
if 'EXECZ' in combined: used.add('EXEC')
|
||||
for name, init in regs:
|
||||
if name in used: lines.append(f" {name} = {init}")
|
||||
if 'EXEC_LO' in combined: lines.append(" EXEC_LO = SliceProxy(EXEC, 31, 0)")
|
||||
if 'EXEC_HI' in combined: lines.append(" EXEC_HI = SliceProxy(EXEC, 63, 32)")
|
||||
# VCCZ = 1 if VCC == 0, EXECZ = 1 if EXEC == 0
|
||||
if 'VCCZ' in combined: lines.append(" VCCZ = Reg(1 if VCC._val == 0 else 0)")
|
||||
if 'EXECZ' in combined: lines.append(" EXECZ = Reg(1 if EXEC._val == 0 else 0)")
|
||||
# Add compiled pseudocode with markers
|
||||
lines.append(" # --- compiled pseudocode ---")
|
||||
for line in code.split('\n'):
|
||||
lines.append(f" {line}")
|
||||
lines.append(" # --- end pseudocode ---")
|
||||
# Generate result dict - use raw params if Reg wasn't created
|
||||
d0_val = "D0._val" if 'D0' in used else "d0"
|
||||
scc_val = "SCC._val & 1" if 'SCC' in used else "scc & 1"
|
||||
lines.append(f" result = {{'d0': {d0_val}, 'scc': {scc_val}}}")
|
||||
if has_sdst:
|
||||
lines.append(" result['vcc_lane'] = (VCC._val >> lane) & 1")
|
||||
elif 'VCC' in used:
|
||||
lines.append(" if VCC._val != vcc: result['vcc_lane'] = (VCC._val >> lane) & 1")
|
||||
if is_cmpx:
|
||||
lines.append(" result['exec_lane'] = (EXEC._val >> lane) & 1")
|
||||
elif 'EXEC' in used:
|
||||
lines.append(" if EXEC._val != exec_mask: result['exec'] = EXEC._val")
|
||||
if is_cmp:
|
||||
lines.append(" result['vcc_lane'] = (D0._val >> lane) & 1")
|
||||
if is_64:
|
||||
lines.append(" result['d0_64'] = True")
|
||||
if has_d1:
|
||||
lines.append(" result['d1'] = D1._val & 1")
|
||||
if has_pc:
|
||||
# Return new PC as absolute byte address, emulator will compute delta
|
||||
# Handle negative values (backward jumps): PC._val is stored as unsigned, convert to signed
|
||||
lines.append(" _pc = PC._val if PC._val < 0x8000000000000000 else PC._val - 0x10000000000000000")
|
||||
lines.append(" result['new_pc'] = _pc # absolute byte address")
|
||||
lines.append(" return result")
|
||||
lines.append("")
|
||||
|
||||
fn_entries.append((op, fn_name))
|
||||
compiled_count += 1
|
||||
except Exception as e:
|
||||
print(f" Warning: Failed to compile {op.name}: {e}")
|
||||
skipped_count += 1
|
||||
|
||||
if fn_entries:
|
||||
lines.append(f'{cls_name}_FUNCTIONS = {{')
|
||||
for op, fn_name in fn_entries:
|
||||
lines.append(f" {cls_name}.{op.name}: {fn_name},")
|
||||
lines.append('}')
|
||||
lines.append('')
|
||||
|
||||
# Add manually implemented V_WRITELANE_B32 (not in PDF pseudocode, requires special vgpr_write handling)
|
||||
# Only add for architectures that have VOP3Op (RDNA) not VOP3AOp/VOP3BOp (CDNA)
|
||||
if 'VOP3Op' in enum_names:
|
||||
lines.append('''
|
||||
# V_WRITELANE_B32: Write scalar to specific lane's VGPR (not in PDF pseudocode)
|
||||
def _VOP3Op_V_WRITELANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
|
||||
wr_lane = s1 & 0x1f # lane select (5 bits for wave32)
|
||||
return {'d0': d0, 'scc': scc, 'vgpr_write': (wr_lane, vdst_idx, s0 & 0xffffffff)}
|
||||
VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32
|
||||
''')
|
||||
|
||||
lines.append('COMPILED_FUNCTIONS = {')
|
||||
for enum_cls in OP_ENUMS:
|
||||
cls_name = enum_cls.__name__
|
||||
if by_cls.get(enum_cls): lines.append(f' {cls_name}: {cls_name}_FUNCTIONS,')
|
||||
lines.append('}')
|
||||
lines.append('')
|
||||
lines.append('def get_compiled_functions(): return COMPILED_FUNCTIONS')
|
||||
|
||||
Path(output_path).write_text('\n'.join(lines))
|
||||
print(f"\nGenerated {output_path}: {compiled_count} compiled, {skipped_count} skipped")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="Generate pseudocode functions from AMD ISA PDF")
|
||||
parser.add_argument("--arch", choices=list(PDF_URLS.keys()) + ["all"], default="rdna3", help="Target architecture (default: rdna3)")
|
||||
args = parser.parse_args()
|
||||
if args.arch == "all":
|
||||
for arch in PDF_URLS.keys():
|
||||
generate_gen_pcode(output_path=f"extra/assembly/amd/autogen/{arch}/gen_pcode.py", arch=arch)
|
||||
else:
|
||||
generate_gen_pcode(output_path=f"extra/assembly/amd/autogen/{args.arch}/gen_pcode.py", arch=args.arch)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,54 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Tests for the RDNA3 pseudocode DSL."""
|
||||
import unittest
|
||||
from extra.assembly.amd.pcode import (Reg, TypedView, SliceProxy, ExecContext, compile_pseudocode, _expr, MASK32, MASK64,
|
||||
from extra.assembly.amd.pcode import (Reg, TypedView, SliceProxy, MASK32, MASK64,
|
||||
_f32, _i32, _f16, _i16, f32_to_f16, _isnan, _bf16, _ibf16, bf16_to_f32, f32_to_bf16,
|
||||
BYTE_PERMUTE, v_sad_u8, v_msad_u8)
|
||||
from extra.assembly.amd.generate import compile_pseudocode, _expr
|
||||
from extra.assembly.amd.autogen.rdna3.gen_pcode import _VOP3SDOp_V_DIV_SCALE_F32, _VOPCOp_V_CMP_CLASS_F32
|
||||
|
||||
class ExecContext:
|
||||
"""Context for running compiled pseudocode (test-only)."""
|
||||
def __init__(self, s0=0, s1=0, s2=0, d0=0, scc=0, vcc=0, lane=0, exec_mask=MASK32, literal=0, vgprs=None, src0_idx=0, vdst_idx=0):
|
||||
self.S0, self.S1, self.S2 = Reg(s0), Reg(s1), Reg(s2)
|
||||
self.D0, self.D1 = Reg(d0), Reg(0)
|
||||
self.SCC, self.VCC, self.EXEC = Reg(scc), Reg(vcc), Reg(exec_mask)
|
||||
self.tmp, self.saveexec = Reg(0), Reg(exec_mask)
|
||||
self.lane, self.laneId, self.literal = lane, lane, literal
|
||||
self.SIMM16, self.SIMM32 = Reg(literal), Reg(literal)
|
||||
self.VGPR = vgprs if vgprs is not None else {}
|
||||
self.SRC0, self.VDST = Reg(src0_idx), Reg(vdst_idx)
|
||||
|
||||
def run(self, code: str):
|
||||
"""Execute compiled code."""
|
||||
import extra.assembly.amd.pcode as pcode_mod
|
||||
ns = {k: v for k, v in vars(pcode_mod).items() if not k.startswith('_') or k in ('_f32', '_i32', '_f16', '_i16', '_f64', '_i64', '_bf16', '_ibf16',
|
||||
'_div', '_sext', '_isnan', '_isquietnan', '_issignalnan', '_fma',
|
||||
'_gt_neg_zero', '_lt_neg_zero', '_signext', '_to_f16_bits', '_pack')}
|
||||
ns.update({
|
||||
'S0': self.S0, 'S1': self.S1, 'S2': self.S2, 'D0': self.D0, 'D1': self.D1,
|
||||
'SCC': self.SCC, 'VCC': self.VCC, 'EXEC': self.EXEC,
|
||||
'EXEC_LO': SliceProxy(self.EXEC, 31, 0), 'EXEC_HI': SliceProxy(self.EXEC, 63, 32),
|
||||
'tmp': self.tmp, 'saveexec': self.saveexec,
|
||||
'lane': self.lane, 'laneId': self.laneId, 'literal': self.literal,
|
||||
'SIMM16': self.SIMM16, 'SIMM32': self.SIMM32,
|
||||
'VGPR': self.VGPR, 'SRC0': self.SRC0, 'VDST': self.VDST,
|
||||
})
|
||||
exec(code, ns)
|
||||
def _sync(ctx_reg, ns_val):
|
||||
if isinstance(ns_val, Reg): ctx_reg._val = ns_val._val
|
||||
else: ctx_reg._val = int(ns_val) & MASK64
|
||||
if ns.get('SCC') is not self.SCC: _sync(self.SCC, ns['SCC'])
|
||||
if ns.get('VCC') is not self.VCC: _sync(self.VCC, ns['VCC'])
|
||||
if ns.get('EXEC') is not self.EXEC: _sync(self.EXEC, ns['EXEC'])
|
||||
if ns.get('D0') is not self.D0: _sync(self.D0, ns['D0'])
|
||||
if ns.get('D1') is not self.D1: _sync(self.D1, ns['D1'])
|
||||
if ns.get('tmp') is not self.tmp: _sync(self.tmp, ns['tmp'])
|
||||
if ns.get('saveexec') is not self.saveexec: _sync(self.saveexec, ns['saveexec'])
|
||||
|
||||
def result(self) -> dict:
|
||||
return {"d0": self.D0._val, "scc": self.SCC._val & 1}
|
||||
|
||||
class TestReg(unittest.TestCase):
|
||||
def test_u32_read(self):
|
||||
r = Reg(0xDEADBEEF)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue