assembly/amd: factor out pdf generation

This commit is contained in:
George Hotz 2025-12-30 14:44:45 -05:00
commit ef5ee0f723
10 changed files with 826 additions and 927 deletions

View file

@ -1,4 +1,4 @@
# autogenerated from AMD CDNA3+CDNA4 ISA PDF by dsl.py - do not edit
# autogenerated from AMD CDNA3+CDNA4 ISA PDF by generate.py - do not edit
from enum import IntEnum
from typing import Annotated
from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField

View file

@ -1,5 +1,5 @@
# autogenerated by pcode.py - do not edit
# to regenerate: python -m extra.assembly.amd.pcode --arch cdna
# autogenerated by generate.py - do not edit
# to regenerate: python -m extra.assembly.amd.generate --arch cdna
# ruff: noqa: E501,F405,F403
# mypy: ignore-errors
from extra.assembly.amd.autogen.cdna import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3POp, VOPCOp, VOP3AOp, VOP3BOp
@ -9678,32 +9678,13 @@ def _VOPCOp_V_CMPX_T_U64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGP
# Set the per-lane condition code to 1. Store the result into the EXEC mask and to VCC or a scalar register.
# EXEC.u64[laneId] = D0.u64[laneId] = 1'1U;
# // D0 = VCC in VOPC encoding.
# addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32);
# tmp = MEM[addr].u32;
# addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32);
# tmp = MEM[addr].u32;
# addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32);
# tmp = MEM[addr].u32;
# addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32);
# tmp = MEM[addr].u32;
# src = DATA.u32;
D0 = Reg(d0)
VCC = Reg(vcc)
EXEC = Reg(exec_mask)
tmp = Reg(0)
laneId = lane
PC = Reg(pc)
# --- compiled pseudocode ---
EXEC.u64[laneId] = D0.u64[laneId] = 1
addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32)
tmp = Reg(MEM[addr].u32)
addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32)
tmp = Reg(MEM[addr].u32)
addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32)
tmp = Reg(MEM[addr].u32)
addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32)
tmp = Reg(MEM[addr].u32)
src = DATA.u32
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
if VCC._val != vcc: result['vcc_lane'] = (VCC._val >> lane) & 1
@ -14177,32 +14158,13 @@ def _VOP3AOp_V_CMPX_T_U64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
# Set the per-lane condition code to 1. Store the result into the EXEC mask and to VCC or a scalar register.
# EXEC.u64[laneId] = D0.u64[laneId] = 1'1U;
# // D0 = VCC in VOPC encoding.
# addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32);
# tmp = MEM[addr].u32;
# addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32);
# tmp = MEM[addr].u32;
# addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32);
# tmp = MEM[addr].u32;
# addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32);
# tmp = MEM[addr].u32;
# src = DATA.u32;
D0 = Reg(d0)
VCC = Reg(vcc)
EXEC = Reg(exec_mask)
tmp = Reg(0)
laneId = lane
PC = Reg(pc)
# --- compiled pseudocode ---
EXEC.u64[laneId] = D0.u64[laneId] = 1
addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32)
tmp = Reg(MEM[addr].u32)
addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32)
tmp = Reg(MEM[addr].u32)
addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32)
tmp = Reg(MEM[addr].u32)
addr = CalcDsAddr(ADDR.b32, OFFSET0.b32, OFFSET1.b32)
tmp = Reg(MEM[addr].u32)
src = DATA.u32
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
if VCC._val != vcc: result['vcc_lane'] = (VCC._val >> lane) & 1

View file

@ -1,4 +1,4 @@
# autogenerated from AMD RDNA3.5 ISA PDF by dsl.py - do not edit
# autogenerated from AMD RDNA3.5 ISA PDF by generate.py - do not edit
from enum import IntEnum
from typing import Annotated
from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField

View file

@ -1,5 +1,5 @@
# autogenerated by pcode.py - do not edit
# to regenerate: python -m extra.assembly.amd.pcode --arch rdna3
# autogenerated by generate.py - do not edit
# to regenerate: python -m extra.assembly.amd.generate --arch rdna3
# ruff: noqa: E501,F405,F403
# mypy: ignore-errors
from extra.assembly.amd.autogen.rdna3 import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp

View file

@ -1,4 +1,4 @@
# autogenerated from AMD RDNA4 ISA PDF by dsl.py - do not edit
# autogenerated from AMD RDNA4 ISA PDF by generate.py - do not edit
from enum import IntEnum
from typing import Annotated
from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField

View file

@ -1,5 +1,5 @@
# autogenerated by pcode.py - do not edit
# to regenerate: python -m extra.assembly.amd.pcode --arch rdna4
# autogenerated by generate.py - do not edit
# to regenerate: python -m extra.assembly.amd.generate --arch rdna4
# ruff: noqa: E501,F405,F403
# mypy: ignore-errors
from extra.assembly.amd.autogen.rdna4 import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp

View file

@ -352,296 +352,3 @@ class Inst:
class Inst32(Inst): pass
class Inst64(Inst): pass
# ═══════════════════════════════════════════════════════════════════════════════
# CODE GENERATION: generates autogen/__init__.py by parsing AMD ISA PDFs
# Supports both RDNA3.5 and CDNA4 instruction set PDFs - auto-detects format
# ═══════════════════════════════════════════════════════════════════════════════
PDF_URLS = {
"rdna3": "https://docs.amd.com/api/khub/documents/UVVZM22UN7tMUeiW_4ShTQ/content", # RDNA3.5
"rdna4": "https://docs.amd.com/api/khub/documents/uQpkEvk3pv~kfAb2x~j4uw/content",
"cdna": ["https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/amd-instinct-mi300-cdna3-instruction-set-architecture.pdf",
"https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/amd-instinct-cdna4-instruction-set-architecture.pdf"],
}
FIELD_TYPES = {'SSRC0': 'SSrc', 'SSRC1': 'SSrc', 'SOFFSET': 'SSrc', 'SADDR': 'SSrc', 'SRC0': 'Src', 'SRC1': 'Src', 'SRC2': 'Src',
'SDST': 'SGPRField', 'SBASE': 'SGPRField', 'SDATA': 'SGPRField', 'SRSRC': 'SGPRField', 'VDST': 'VGPRField', 'VSRC1': 'VGPRField', 'VDATA': 'VGPRField',
'VADDR': 'VGPRField', 'ADDR': 'VGPRField', 'DATA': 'VGPRField', 'DATA0': 'VGPRField', 'DATA1': 'VGPRField', 'SIMM16': 'SImm', 'OFFSET': 'Imm',
'OPX': 'VOPDOp', 'OPY': 'VOPDOp', 'SRCX0': 'Src', 'SRCY0': 'Src', 'VSRCX1': 'VGPRField', 'VSRCY1': 'VGPRField', 'VDSTX': 'VGPRField', 'VDSTY': 'VDSTYEnc'}
FIELD_ORDER = {
'SOP2': ['op', 'sdst', 'ssrc0', 'ssrc1'], 'SOP1': ['op', 'sdst', 'ssrc0'], 'SOPC': ['op', 'ssrc0', 'ssrc1'],
'SOPK': ['op', 'sdst', 'simm16'], 'SOPP': ['op', 'simm16'], 'VOP1': ['op', 'vdst', 'src0'], 'VOPC': ['op', 'src0', 'vsrc1'],
'VOP2': ['op', 'vdst', 'src0', 'vsrc1'], 'VOP3SD': ['op', 'vdst', 'sdst', 'src0', 'src1', 'src2', 'clmp'],
'SMEM': ['op', 'sdata', 'sbase', 'soffset', 'offset', 'glc', 'dlc'], 'DS': ['op', 'vdst', 'addr', 'data0', 'data1'],
'VOP3': ['op', 'vdst', 'src0', 'src1', 'src2', 'omod', 'neg', 'abs', 'clmp', 'opsel'],
'VOP3P': ['op', 'vdst', 'src0', 'src1', 'src2', 'neg', 'neg_hi', 'opsel', 'opsel_hi', 'clmp'],
'FLAT': ['op', 'vdst', 'addr', 'data', 'saddr', 'offset', 'seg', 'dlc', 'glc', 'slc'],
'MUBUF': ['op', 'vdata', 'vaddr', 'srsrc', 'soffset', 'offset', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe'],
'MTBUF': ['op', 'vdata', 'vaddr', 'srsrc', 'soffset', 'offset', 'format', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe'],
'MIMG': ['op', 'vdata', 'vaddr', 'srsrc', 'ssamp', 'dmask', 'dim', 'unrm', 'dlc', 'glc', 'slc'],
'EXP': ['en', 'target', 'vsrc0', 'vsrc1', 'vsrc2', 'vsrc3', 'done', 'row'],
'VINTERP': ['op', 'vdst', 'src0', 'src1', 'src2', 'waitexp', 'clmp', 'opsel', 'neg'],
'VOPD': ['opx', 'opy', 'vdstx', 'vdsty', 'srcx0', 'vsrcx1', 'srcy0', 'vsrcy1'],
'LDSDIR': ['op', 'vdst', 'attr', 'attr_chan', 'wait_va']}
SRC_EXTRAS = {233: 'DPP8', 234: 'DPP8FI', 250: 'DPP16', 251: 'VCCZ', 252: 'EXECZ', 254: 'LDS_DIRECT'}
FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'NEG_ONE', '2.0': 'POS_TWO', '-2.0': 'NEG_TWO',
'4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
def _parse_bits(s: str) -> tuple[int, int] | None:
import re
return (int(m.group(1)), int(m.group(2) or m.group(1))) if (m := re.match(r'\[(\d+)(?::(\d+))?\]', s)) else None
def _parse_fields_table(table: list, fmt: str, enums: set[str]) -> list[tuple]:
import re
fields = []
for row in table[1:]:
if not row or not row[0]: continue
name, bits_str = row[0].split('\n')[0].strip(), (row[1] or '').split('\n')[0].strip()
if not (bits := _parse_bits(bits_str)): continue
enc_val, hi, lo = None, bits[0], bits[1]
if name == 'ENCODING' and row[2]:
# Handle both RDNA3 ('bXX) and CDNA4 (Must be: XX) encoding formats
if m := re.search(r"(?:'b|Must be:\s*)([01_]+)", row[2]):
enc_bits = m.group(1).replace('_', '')
enc_val = int(enc_bits, 2)
declared_width, actual_width = hi - lo + 1, len(enc_bits)
if actual_width > declared_width: lo = hi - actual_width + 1
ftype = f"{fmt}Op" if name == 'OP' and f"{fmt}Op" in enums else FIELD_TYPES.get(name.upper())
fields.append((name, hi, lo, enc_val, ftype))
return fields
def _parse_single_pdf(url: str) -> dict:
"""Parse a single PDF and return raw data (formats, enums, src_enum, doc_name, is_cdna)."""
import re, pdfplumber
from tinygrad.helpers import fetch
pdf = pdfplumber.open(fetch(url))
# Auto-detect document type from first page
first_page_text = pdf.pages[0].extract_text() or ''
is_cdna4 = 'CDNA4' in first_page_text or 'CDNA 4' in first_page_text
is_cdna3 = 'CDNA3' in first_page_text or 'CDNA 3' in first_page_text or 'MI300' in first_page_text
is_cdna = is_cdna3 or is_cdna4
is_rdna4 = 'RDNA4' in first_page_text or 'RDNA 4' in first_page_text
is_rdna35 = 'RDNA3.5' in first_page_text or 'RDNA 3.5' in first_page_text # Check 3.5 before 3
is_rdna3 = not is_rdna35 and ('RDNA3' in first_page_text or 'RDNA 3' in first_page_text)
doc_name = "CDNA4" if is_cdna4 else "CDNA3" if is_cdna3 else "RDNA4" if is_rdna4 else "RDNA3.5" if is_rdna35 else "RDNA3" if is_rdna3 else "Unknown"
# Find the "Microcode Formats" section - search for SOP2 format definition
microcode_start = None
total_pages = len(pdf.pages)
# Search from likely locations (formats are typically 20-95% through the document - RDNA3 has them at ~25%)
for i in range(int(total_pages * 0.2), total_pages):
text = pdf.pages[i].extract_text() or ''
# Look for "X.Y.Z. SOP2" section header or "Chapter X. Microcode Formats"
if re.search(r'\d+\.\d+\.\d+\.\s+SOP2\b', text) or re.search(r'Chapter \d+\.\s+Microcode Formats', text):
microcode_start = i
break
if microcode_start is None: microcode_start = int(total_pages * 0.9)
pages = pdf.pages[microcode_start:microcode_start + 50]
page_texts = [p.extract_text() or '' for p in pages]
page_tables = [[t.extract() for t in p.find_tables()] for p in pages]
full_text = '\n'.join(page_texts)
# parse SSRC encoding from first page with VCC_LO
src_enum = dict(SRC_EXTRAS)
for text in page_texts[:10]:
if 'SSRC0' in text and 'VCC_LO' in text:
for m in re.finditer(r'^(\d+)\s+(\S+)', text, re.M):
val, name = int(m.group(1)), m.group(2).rstrip('.:')
if name in FLOAT_MAP: src_enum[val] = FLOAT_MAP[name]
elif re.match(r'^[A-Z][A-Z0-9_]*$', name): src_enum[val] = name
break
# parse opcode tables
enums: dict[str, dict[int, str]] = {}
for m in re.finditer(r'Table \d+\. (\w+) Opcodes(.*?)(?=Table \d+\.|\n\d+\.\d+\.\d+\.\s+\w+\s*\nDescription|$)', full_text, re.S):
if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+([A-Z][A-Z0-9_]+)', m.group(2))}:
enums[m.group(1) + "Op"] = ops
if vopd_m := re.search(r'Table \d+\. VOPD Y-Opcodes\n(.*?)(?=Table \d+\.|15\.\d)', full_text, re.S):
if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+(V_DUAL_\w+)', vopd_m.group(1))}:
enums["VOPDOp"] = ops
enum_names = set(enums.keys())
def is_fields_table(t) -> bool: return t and len(t) > 1 and t[0] and 'Field' in str(t[0][0] or '')
def has_encoding(fields) -> bool: return any(f[0] == 'ENCODING' for f in fields)
def has_header_before_fields(text) -> bool:
return (pos := text.find('Field Name')) != -1 and bool(re.search(r'\d+\.\d+\.\d+\.\s+\w+\s*\n', text[:pos]))
# find format headers with their page indices
format_headers = []
for i, text in enumerate(page_texts):
for m in re.finditer(r'\d+\.\d+\.\d+\.\s+(\w+)\s*\n?Description', text): format_headers.append((m.group(1), i, m.start()))
for m in re.finditer(r'\d+\.\d+\.\d+\.\s+(\w+)\s*\n', text):
fmt_name = m.group(1)
if is_cdna and fmt_name.isupper() and len(fmt_name) >= 2:
format_headers.append((fmt_name, i, m.start()))
elif m.start() > len(text) - 200 and 'Description' not in text[m.end():] and i + 1 < len(page_texts):
next_text = page_texts[i + 1].lstrip()
if next_text.startswith('Description') or (next_text.startswith('"RDNA') and 'Description' in next_text[:200]):
format_headers.append((fmt_name, i, m.start()))
# parse instruction formats
formats: dict[str, list] = {}
for fmt_name, page_idx, header_pos in format_headers:
if fmt_name in formats: continue
text, tables = page_texts[page_idx], page_tables[page_idx]
field_pos = text.find('Field Name', header_pos)
fields = None
for offset in range(3):
if page_idx + offset >= len(pages): break
if offset > 0 and has_header_before_fields(page_texts[page_idx + offset]): break
for t in page_tables[page_idx + offset] if offset > 0 or field_pos > header_pos else []:
if is_fields_table(t) and (f := _parse_fields_table(t, fmt_name, enum_names)) and has_encoding(f):
fields = f
break
if fields: break
if not fields and field_pos > header_pos:
for t in tables:
if is_fields_table(t) and (f := _parse_fields_table(t, fmt_name, enum_names)):
fields = f
break
if not fields: continue
field_names = {f[0] for f in fields}
for pg_offset in range(1, 3):
if page_idx + pg_offset >= len(pages) or has_header_before_fields(page_texts[page_idx + pg_offset]): break
for t in page_tables[page_idx + pg_offset]:
if is_fields_table(t) and (extra := _parse_fields_table(t, fmt_name, enum_names)) and not has_encoding(extra):
for ef in extra:
if ef[0] not in field_names:
fields.append(ef)
field_names.add(ef[0])
break
formats[fmt_name] = fields
# fix known PDF errors
if 'SMEM' in formats:
formats['SMEM'] = [(n, 13 if n == 'DLC' else 14 if n == 'GLC' else h, 13 if n == 'DLC' else 14 if n == 'GLC' else l, e, t)
for n, h, l, e, t in formats['SMEM']]
return {"formats": formats, "enums": enums, "src_enum": src_enum, "doc_name": doc_name, "is_cdna": is_cdna}
def _merge_results(results: list[dict]) -> dict:
"""Merge multiple PDF parse results into a superset. Asserts if any conflicts."""
merged = {"formats": {}, "enums": {}, "src_enum": dict(SRC_EXTRAS), "doc_names": [], "is_cdna": False}
for r in results:
merged["doc_names"].append(r["doc_name"])
merged["is_cdna"] = merged["is_cdna"] or r["is_cdna"]
# Merge src_enum (union, assert no conflicts)
for val, name in r["src_enum"].items():
if val in merged["src_enum"]:
assert merged["src_enum"][val] == name, f"SrcEnum conflict: {val} = {merged['src_enum'][val]} vs {name}"
else:
merged["src_enum"][val] = name
# Merge enums (union of ops per enum, assert no conflicts)
for enum_name, ops in r["enums"].items():
if enum_name not in merged["enums"]: merged["enums"][enum_name] = {}
for val, name in ops.items():
if val in merged["enums"][enum_name]:
assert merged["enums"][enum_name][val] == name, f"{enum_name} conflict: {val} = {merged['enums'][enum_name][val]} vs {name}"
else:
merged["enums"][enum_name][val] = name
# Merge formats (union of fields, assert no bit position conflicts for same field name)
for fmt_name, fields in r["formats"].items():
if fmt_name not in merged["formats"]:
merged["formats"][fmt_name] = list(fields)
else:
existing = {f[0]: (f[1], f[2]) for f in merged["formats"][fmt_name]} # name -> (hi, lo)
for f in fields:
name, hi, lo = f[0], f[1], f[2]
if name in existing:
assert existing[name] == (hi, lo), f"Format {fmt_name} field {name} conflict: bits {existing[name]} vs ({hi}, {lo})"
else:
merged["formats"][fmt_name].append(f)
return merged
def generate(output_path: str | None = None, arch: str = "rdna3") -> dict:
"""Generate instruction definitions from AMD ISA PDF(s). Returns dict with formats for testing."""
urls = PDF_URLS[arch]
if isinstance(urls, str): urls = [urls]
# Parse all PDFs and merge
results = [_parse_single_pdf(url) for url in urls]
if len(results) == 1:
merged = results[0]
doc_name = merged["doc_name"]
else:
merged = _merge_results(results)
doc_name = "+".join(merged["doc_names"])
formats, enums, src_enum = merged["formats"], merged["enums"], merged["src_enum"]
# generate output
def enum_lines(name, items):
return [f"class {name}(IntEnum):"] + [f" {n} = {v}" for v, n in sorted(items.items())] + [""]
def field_key(f): return order.index(f[0].lower()) if f[0].lower() in order else 1000
lines = [f"# autogenerated from AMD {doc_name} ISA PDF by dsl.py - do not edit", "from enum import IntEnum",
"from typing import Annotated",
"from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField",
"import functools", ""]
lines += enum_lines("SrcEnum", src_enum) + sum([enum_lines(n, ops) for n, ops in sorted(enums.items())], [])
# Format-specific field defaults (verified against LLVM test vectors)
format_defaults = {'VOP3P': {'opsel_hi': 3, 'opsel_hi2': 1}}
lines.append("# instruction formats")
for fmt_name, fields in sorted(formats.items()):
base = "Inst64" if max(f[1] for f in fields) > 31 or fmt_name == 'VOP3SD' else "Inst32"
order = FIELD_ORDER.get(fmt_name, [])
lines.append(f"class {fmt_name}({base}):")
if enc := next((f for f in fields if f[0] == 'ENCODING'), None):
enc_str = f"bits[{enc[1]}:{enc[2]}] == 0b{enc[3]:b}" if enc[1] != enc[2] else f"bits[{enc[1]}] == {enc[3]}"
lines.append(f" encoding = {enc_str}")
if defaults := format_defaults.get(fmt_name):
lines.append(f" _defaults = {defaults}")
for name, hi, lo, _, ftype in sorted([f for f in fields if f[0] != 'ENCODING'], key=field_key):
if ftype and ftype.endswith('Op'):
ann = f":Annotated[BitField, {ftype}]"
else:
ann = f":{ftype}" if ftype else ""
lines.append(f" {name.lower()}{ann} = bits[{hi}]" if hi == lo else f" {name.lower()}{ann} = bits[{hi}:{lo}]")
lines.append("")
lines.append("# instruction helpers")
for cls_name, ops in sorted(enums.items()):
fmt = cls_name[:-2]
for op_val, name in sorted(ops.items()):
seg = {"GLOBAL": ", seg=2", "SCRATCH": ", seg=2"}.get(fmt, "")
tgt = {"GLOBAL": "FLAT, GLOBALOp", "SCRATCH": "FLAT, SCRATCHOp"}.get(fmt, f"{fmt}, {cls_name}")
if fmt in formats or fmt in ("GLOBAL", "SCRATCH"):
if fmt in ("VOP1", "VOP2", "VOPC"):
suffix = "_e32"
elif fmt == "VOP3" and op_val < 512:
suffix = "_e64"
else:
suffix = ""
if name in ('V_FMAMK_F32', 'V_FMAMK_F16'):
lines.append(f"def {name.lower()}{suffix}(vdst, src0, K, vsrc1): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
elif name in ('V_FMAAK_F32', 'V_FMAAK_F16'):
lines.append(f"def {name.lower()}{suffix}(vdst, src0, vsrc1, K): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
else:
lines.append(f"{name.lower()}{suffix} = functools.partial({tgt}.{name}{seg})")
skip_exports = {'DPP8', 'DPP16'}
src_names = {name for _, name in src_enum.items()}
lines += [""] + [f"{name} = SrcEnum.{name}" for _, name in sorted(src_enum.items()) if name not in skip_exports]
if "NULL" in src_names: lines.append("OFF = NULL\n")
if output_path is not None:
import pathlib
pathlib.Path(output_path).write_text('\n'.join(lines))
return {"formats": formats, "enums": enums, "src_enum": src_enum}
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Generate instruction definitions from AMD ISA PDF")
parser.add_argument("--arch", choices=list(PDF_URLS.keys()) + ["all"], default="rdna3", help="Target architecture (default: rdna3)")
args = parser.parse_args()
if args.arch == "all":
for arch in PDF_URLS.keys():
result = generate(f"extra/assembly/amd/autogen/{arch}/__init__.py", arch=arch)
print(f"{arch}: generated SrcEnum ({len(result['src_enum'])}) + {len(result['enums'])} opcode enums + {len(result['formats'])} format classes")
else:
result = generate(f"extra/assembly/amd/autogen/{args.arch}/__init__.py", arch=args.arch)
print(f"generated SrcEnum ({len(result['src_enum'])}) + {len(result['enums'])} opcode enums + {len(result['formats'])} format classes")

View file

@ -0,0 +1,772 @@
# PDF parsing and code generation for AMD ISA
# Generates both autogen/__init__.py (instruction formats) and gen_pcode.py (pseudocode functions)
# Usage: python -m extra.assembly.amd.pdf --arch rdna3
import re
PDF_URLS = {
"rdna3": "https://docs.amd.com/api/khub/documents/UVVZM22UN7tMUeiW_4ShTQ/content", # RDNA3.5
"rdna4": "https://docs.amd.com/api/khub/documents/uQpkEvk3pv~kfAb2x~j4uw/content",
"cdna": ["https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/amd-instinct-mi300-cdna3-instruction-set-architecture.pdf",
"https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/amd-instinct-cdna4-instruction-set-architecture.pdf"],
}
FIELD_TYPES = {'SSRC0': 'SSrc', 'SSRC1': 'SSrc', 'SOFFSET': 'SSrc', 'SADDR': 'SSrc', 'SRC0': 'Src', 'SRC1': 'Src', 'SRC2': 'Src',
'SDST': 'SGPRField', 'SBASE': 'SGPRField', 'SDATA': 'SGPRField', 'SRSRC': 'SGPRField', 'VDST': 'VGPRField', 'VSRC1': 'VGPRField', 'VDATA': 'VGPRField',
'VADDR': 'VGPRField', 'ADDR': 'VGPRField', 'DATA': 'VGPRField', 'DATA0': 'VGPRField', 'DATA1': 'VGPRField', 'SIMM16': 'SImm', 'OFFSET': 'Imm',
'OPX': 'VOPDOp', 'OPY': 'VOPDOp', 'SRCX0': 'Src', 'SRCY0': 'Src', 'VSRCX1': 'VGPRField', 'VSRCY1': 'VGPRField', 'VDSTX': 'VGPRField', 'VDSTY': 'VDSTYEnc'}
FIELD_ORDER = {
'SOP2': ['op', 'sdst', 'ssrc0', 'ssrc1'], 'SOP1': ['op', 'sdst', 'ssrc0'], 'SOPC': ['op', 'ssrc0', 'ssrc1'],
'SOPK': ['op', 'sdst', 'simm16'], 'SOPP': ['op', 'simm16'], 'VOP1': ['op', 'vdst', 'src0'], 'VOPC': ['op', 'src0', 'vsrc1'],
'VOP2': ['op', 'vdst', 'src0', 'vsrc1'], 'VOP3SD': ['op', 'vdst', 'sdst', 'src0', 'src1', 'src2', 'clmp'],
'SMEM': ['op', 'sdata', 'sbase', 'soffset', 'offset', 'glc', 'dlc'], 'DS': ['op', 'vdst', 'addr', 'data0', 'data1'],
'VOP3': ['op', 'vdst', 'src0', 'src1', 'src2', 'omod', 'neg', 'abs', 'clmp', 'opsel'],
'VOP3P': ['op', 'vdst', 'src0', 'src1', 'src2', 'neg', 'neg_hi', 'opsel', 'opsel_hi', 'clmp'],
'FLAT': ['op', 'vdst', 'addr', 'data', 'saddr', 'offset', 'seg', 'dlc', 'glc', 'slc'],
'MUBUF': ['op', 'vdata', 'vaddr', 'srsrc', 'soffset', 'offset', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe'],
'MTBUF': ['op', 'vdata', 'vaddr', 'srsrc', 'soffset', 'offset', 'format', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe'],
'MIMG': ['op', 'vdata', 'vaddr', 'srsrc', 'ssamp', 'dmask', 'dim', 'unrm', 'dlc', 'glc', 'slc'],
'EXP': ['en', 'target', 'vsrc0', 'vsrc1', 'vsrc2', 'vsrc3', 'done', 'row'],
'VINTERP': ['op', 'vdst', 'src0', 'src1', 'src2', 'waitexp', 'clmp', 'opsel', 'neg'],
'VOPD': ['opx', 'opy', 'vdstx', 'vdsty', 'srcx0', 'vsrcx1', 'srcy0', 'vsrcy1'],
'LDSDIR': ['op', 'vdst', 'attr', 'attr_chan', 'wait_va']}
SRC_EXTRAS = {233: 'DPP8', 234: 'DPP8FI', 250: 'DPP16', 251: 'VCCZ', 252: 'EXECZ', 254: 'LDS_DIRECT'}
FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'NEG_ONE', '2.0': 'POS_TWO', '-2.0': 'NEG_TWO',
'4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
# ═══════════════════════════════════════════════════════════════════════════════
# SHARED PDF PARSING INFRASTRUCTURE
# ═══════════════════════════════════════════════════════════════════════════════
class _LazyPageCache:
"""Lazy page text/table extractor with caching to avoid redundant PDF parsing."""
__slots__ = ('_pdf', '_offset', '_text_cache', '_table_cache')
def __init__(self, pdf, offset: int = 0):
self._pdf, self._offset, self._text_cache, self._table_cache = pdf, offset, {}, {}
def text(self, idx: int) -> str:
if idx not in self._text_cache: self._text_cache[idx] = self._pdf.pages[self._offset + idx].extract_text() or ''
return self._text_cache[idx]
def tables(self, idx: int) -> list:
if idx not in self._table_cache: self._table_cache[idx] = [t.extract() for t in self._pdf.pages[self._offset + idx].find_tables()]
return self._table_cache[idx]
def texts_range(self, start: int, end: int) -> list[str]: return [self.text(i) for i in range(start, end)]
def _detect_doc_type(first_page_text: str) -> tuple[bool, str]:
"""Detect document type from first page text. Returns (is_cdna, doc_name)."""
is_cdna4 = 'CDNA4' in first_page_text or 'CDNA 4' in first_page_text
is_cdna3 = 'CDNA3' in first_page_text or 'CDNA 3' in first_page_text or 'MI300' in first_page_text
is_cdna = is_cdna3 or is_cdna4
is_rdna4 = 'RDNA4' in first_page_text or 'RDNA 4' in first_page_text
is_rdna35 = 'RDNA3.5' in first_page_text or 'RDNA 3.5' in first_page_text
is_rdna3 = not is_rdna35 and ('RDNA3' in first_page_text or 'RDNA 3' in first_page_text)
doc_name = "CDNA4" if is_cdna4 else "CDNA3" if is_cdna3 else "RDNA4" if is_rdna4 else "RDNA3.5" if is_rdna35 else "RDNA3" if is_rdna3 else "Unknown"
return is_cdna, doc_name
def _find_chapter(cache: _LazyPageCache, total_pages: int, pattern: str, sample_pcts: list[float]) -> int | None:
"""Find chapter page by sampling at likely positions then searching nearby. Returns page index or None."""
for pct in sample_pcts:
idx = int(total_pages * pct)
if 0 <= idx < total_pages and re.search(pattern, cache.text(idx)): return idx
for pct in sample_pcts:
base = int(total_pages * pct)
for offset in range(-10, 11):
idx = base + offset
if 0 <= idx < total_pages and re.search(pattern, cache.text(idx)): return idx
return None
def _parse_bits(s: str) -> tuple[int, int] | None:
return (int(m.group(1)), int(m.group(2) or m.group(1))) if (m := re.match(r'\[(\d+)(?::(\d+))?\]', s)) else None
def _parse_fields_table(table: list, fmt: str, enums: set[str]) -> list[tuple]:
fields = []
for row in table[1:]:
if not row or not row[0]: continue
name, bits_str = row[0].split('\n')[0].strip(), (row[1] or '').split('\n')[0].strip()
if not (bits := _parse_bits(bits_str)): continue
enc_val, hi, lo = None, bits[0], bits[1]
if name == 'ENCODING' and row[2]:
if m := re.search(r"(?:'b|Must be:\s*)([01_]+)", row[2]):
enc_bits = m.group(1).replace('_', '')
enc_val = int(enc_bits, 2)
declared_width, actual_width = hi - lo + 1, len(enc_bits)
if actual_width > declared_width: lo = hi - actual_width + 1
ftype = f"{fmt}Op" if name == 'OP' and f"{fmt}Op" in enums else FIELD_TYPES.get(name.upper())
fields.append((name, hi, lo, enc_val, ftype))
return fields
# ═══════════════════════════════════════════════════════════════════════════════
# PARSE SINGLE PDF - extracts both format definitions AND pseudocode
# ═══════════════════════════════════════════════════════════════════════════════
def _parse_single_pdf(url: str) -> dict:
"""Parse a single PDF and return raw data for both dsl and pcode generation."""
import pdfplumber
from tinygrad.helpers import fetch
pdf = pdfplumber.open(fetch(url))
total_pages = len(pdf.pages)
cache = _LazyPageCache(pdf)
# Auto-detect document type from first page
is_cdna, doc_name = _detect_doc_type(cache.text(0))
# Find chapter positions using sampling
instr_pattern = r'Chapter \d+\.\s+Instructions\b'
microcode_pattern = r'Chapter \d+\.\s+Microcode Formats'
def is_microcode_page(text):
return (re.search(r'\d+\.\d+\.\d+\.\s+SOP2\s*\n.*Description', text) or
re.search(r'Chapter \d+\.\s+Microcode Formats\s*\n.*This section', text))
# Find microcode section with sampling (RDNA ~23%, CDNA ~93%)
microcode_start = None
for pct in [0.233, 0.93, 0.24, 0.92, 0.25, 0.22]:
sample = int(total_pages * pct)
if is_microcode_page(cache.text(sample)):
microcode_start = sample
while microcode_start > 0 and is_microcode_page(cache.text(microcode_start - 1)):
microcode_start -= 1
break
if microcode_start is None:
for i in range(int(total_pages * 0.15), total_pages):
if is_microcode_page(cache.text(i)):
microcode_start = i
break
if microcode_start is None: microcode_start = int(total_pages * 0.9)
# Find instructions chapter for pseudocode extraction
if is_cdna:
instr_start = _find_chapter(cache, total_pages, instr_pattern, [0.17, 0.18, 0.16, 0.19, 0.15])
instr_end = _find_chapter(cache, total_pages, microcode_pattern, [0.93, 0.92, 0.94, 0.91, 0.95])
else:
instr_start = _find_chapter(cache, total_pages, instr_pattern, [0.30, 0.31, 0.29, 0.32, 0.28])
instr_end = None # RDNA: Instructions goes to end
if instr_start is None: instr_start = int(total_pages * (0.17 if is_cdna else 0.30))
if instr_end is None: instr_end = total_pages
# ─── Parse format definitions from Microcode Formats chapter ───
fmt_cache = _LazyPageCache(pdf, microcode_start)
page_count = min(45, total_pages - microcode_start)
for idx, text in cache._text_cache.items():
if microcode_start <= idx < microcode_start + page_count: fmt_cache._text_cache[idx - microcode_start] = text
# Parse SSRC encoding
src_enum = dict(SRC_EXTRAS)
for i in range(2, 12):
text = fmt_cache.text(i)
if 'SSRC0' in text and 'VCC_LO' in text:
for m in re.finditer(r'^(\d+)\s+(\S+)', text, re.M):
val, name = int(m.group(1)), m.group(2).rstrip('.:')
if name in FLOAT_MAP: src_enum[val] = FLOAT_MAP[name]
elif re.match(r'^[A-Z][A-Z0-9_]*$', name): src_enum[val] = name
break
# Parse opcode tables
full_text = '\n'.join(fmt_cache.texts_range(2, page_count))
enums: dict[str, dict[int, str]] = {}
for m in re.finditer(r'Table \d+\. (\w+) Opcodes(.*?)(?=Table \d+\.|\n\d+\.\d+\.\d+\.\s+\w+\s*\nDescription|$)', full_text, re.S):
if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+([A-Z][A-Z0-9_]+)', m.group(2))}:
enums[m.group(1) + "Op"] = ops
if vopd_m := re.search(r'Table \d+\. VOPD Y-Opcodes\n(.*?)(?=Table \d+\.|15\.\d)', full_text, re.S):
if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+(V_DUAL_\w+)', vopd_m.group(1))}:
enums["VOPDOp"] = ops
enum_names = set(enums.keys())
# Parse format field tables
def is_fields_table(t) -> bool: return t and len(t) > 1 and t[0] and 'Field' in str(t[0][0] or '')
def has_encoding(fields) -> bool: return any(f[0] == 'ENCODING' for f in fields)
def has_header_before_fields(text) -> bool:
return (pos := text.find('Field Name')) != -1 and bool(re.search(r'\d+\.\d+\.\d+\.\s+\w+\s*\n', text[:pos]))
# Find format headers with their page indices (CDNA has formats from page 1, RDNA from page 2)
format_headers = []
for i in range(1 if is_cdna else 2, page_count):
text = fmt_cache.text(i)
for m in re.finditer(r'\d+\.\d+\.\d+\.\s+(\w+)\s*\n?Description', text): format_headers.append((m.group(1), i, m.start()))
for m in re.finditer(r'\d+\.\d+\.\d+\.\s+(\w+)\s*\n', text):
fmt_name = m.group(1)
if is_cdna and fmt_name.isupper() and len(fmt_name) >= 2:
format_headers.append((fmt_name, i, m.start()))
elif m.start() > len(text) - 200 and 'Description' not in text[m.end():] and i + 1 < page_count:
next_text = fmt_cache.text(i + 1).lstrip()
if next_text.startswith('Description') or (next_text.startswith('"RDNA') and 'Description' in next_text[:200]):
format_headers.append((fmt_name, i, m.start()))
# Parse instruction formats
formats: dict[str, list[tuple]] = {}
for fmt_name, page_idx, header_pos in format_headers:
if fmt_name in formats: continue
text, tables = fmt_cache.text(page_idx), fmt_cache.tables(page_idx)
field_pos = text.find('Field Name', header_pos)
fields = None
for offset in range(3):
if page_idx + offset >= page_count: break
if offset > 0 and has_header_before_fields(fmt_cache.text(page_idx + offset)): break
for t in fmt_cache.tables(page_idx + offset) if offset > 0 or field_pos > header_pos else []:
if is_fields_table(t) and (f := _parse_fields_table(t, fmt_name, enum_names)) and has_encoding(f):
fields = f
break
if fields: break
if not fields and field_pos > header_pos:
for t in tables:
if is_fields_table(t) and (f := _parse_fields_table(t, fmt_name, enum_names)):
fields = f
break
if not fields: continue
field_names = {f[0] for f in fields}
# Look for continuation tables on subsequent pages
for pg_offset in range(1, 3):
if page_idx + pg_offset >= page_count or has_header_before_fields(fmt_cache.text(page_idx + pg_offset)): break
for t in fmt_cache.tables(page_idx + pg_offset):
if is_fields_table(t) and (extra := _parse_fields_table(t, fmt_name, enum_names)) and not has_encoding(extra):
for ef in extra:
if ef[0] not in field_names:
fields.append(ef)
field_names.add(ef[0])
break
formats[fmt_name] = fields
# Fix known PDF errors
if 'SMEM' in formats:
formats['SMEM'] = [(n, 13 if n == 'DLC' else 14 if n == 'GLC' else h, 13 if n == 'DLC' else 14 if n == 'GLC' else l, e, t)
for n, h, l, e, t in formats['SMEM']]
# ─── Parse pseudocode from Instructions chapter ───
instr_text = '\n'.join(cache.text(i) for i in range(instr_start, instr_end))
return {"formats": formats, "enums": enums, "src_enum": src_enum, "doc_name": doc_name, "is_cdna": is_cdna,
"instr_text": instr_text}
def _merge_results(results: list[dict]) -> dict:
"""Merge results from multiple PDFs (e.g., CDNA3 + CDNA4)."""
merged = {"formats": {}, "enums": {}, "src_enum": {}, "doc_names": [], "instr_texts": []}
for r in results:
merged["doc_names"].append(r["doc_name"])
merged["instr_texts"].append(r["instr_text"])
for val, name in r["src_enum"].items():
if val in merged["src_enum"]:
assert merged["src_enum"][val] == name, f"SrcEnum conflict: {val} = {merged['src_enum'][val]} vs {name}"
else: merged["src_enum"][val] = name
for enum_name, ops in r["enums"].items():
if enum_name not in merged["enums"]: merged["enums"][enum_name] = {}
for val, name in ops.items():
if val in merged["enums"][enum_name]:
assert merged["enums"][enum_name][val] == name, f"{enum_name} conflict: {val} = {merged['enums'][enum_name][val]} vs {name}"
else: merged["enums"][enum_name][val] = name
for fmt_name, fields in r["formats"].items():
if fmt_name not in merged["formats"]: merged["formats"][fmt_name] = list(fields)
else:
existing = {f[0]: (f[1], f[2]) for f in merged["formats"][fmt_name]}
for f in fields:
name, hi, lo = f[0], f[1], f[2]
if name in existing:
assert existing[name] == (hi, lo), f"Format {fmt_name} field {name} conflict: bits {existing[name]} vs ({hi}, {lo})"
else: merged["formats"][fmt_name].append(f)
return merged
# ═══════════════════════════════════════════════════════════════════════════════
# GENERATE __init__.py (instruction formats and enums)
# ═══════════════════════════════════════════════════════════════════════════════
def _generate_dsl(merged: dict, doc_name: str, output_path: str | None = None) -> str:
"""Generate instruction definitions code."""
formats, enums, src_enum = merged["formats"], merged["enums"], merged["src_enum"]
def enum_lines(name, items):
return [f"class {name}(IntEnum):"] + [f" {n} = {v}" for v, n in sorted(items.items())] + [""]
def field_key(f):
order = FIELD_ORDER.get(fmt_name, [])
return order.index(f[0].lower()) if f[0].lower() in order else 1000
lines = [f"# autogenerated from AMD {doc_name} ISA PDF by generate.py - do not edit", "from enum import IntEnum",
"from typing import Annotated",
"from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField",
"import functools", ""]
lines += enum_lines("SrcEnum", src_enum) + sum([enum_lines(n, ops) for n, ops in sorted(enums.items())], [])
format_defaults = {'VOP3P': {'opsel_hi': 3, 'opsel_hi2': 1}}
lines.append("# instruction formats")
for fmt_name, fields in sorted(formats.items()):
base = "Inst64" if max(f[1] for f in fields) > 31 or fmt_name == 'VOP3SD' else "Inst32"
lines.append(f"class {fmt_name}({base}):")
if enc := next((f for f in fields if f[0] == 'ENCODING'), None):
enc_str = f"bits[{enc[1]}:{enc[2]}] == 0b{enc[3]:b}" if enc[1] != enc[2] else f"bits[{enc[1]}] == {enc[3]}"
lines.append(f" encoding = {enc_str}")
if defaults := format_defaults.get(fmt_name):
lines.append(f" _defaults = {defaults}")
for name, hi, lo, _, ftype in sorted([f for f in fields if f[0] != 'ENCODING'], key=field_key):
ann = f":Annotated[BitField, {ftype}]" if ftype and ftype.endswith('Op') else f":{ftype}" if ftype else ""
lines.append(f" {name.lower()}{ann} = bits[{hi}]" if hi == lo else f" {name.lower()}{ann} = bits[{hi}:{lo}]")
lines.append("")
lines.append("# instruction helpers")
for cls_name, ops in sorted(enums.items()):
fmt = cls_name[:-2]
for op_val, name in sorted(ops.items()):
seg = {"GLOBAL": ", seg=2", "SCRATCH": ", seg=2"}.get(fmt, "")
tgt = {"GLOBAL": "FLAT, GLOBALOp", "SCRATCH": "FLAT, SCRATCHOp"}.get(fmt, f"{fmt}, {cls_name}")
if fmt in formats or fmt in ("GLOBAL", "SCRATCH"):
suffix = "_e32" if fmt in ("VOP1", "VOP2", "VOPC") else "_e64" if fmt == "VOP3" and op_val < 512 else ""
if name in ('V_FMAMK_F32', 'V_FMAMK_F16'):
lines.append(f"def {name.lower()}{suffix}(vdst, src0, K, vsrc1): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
elif name in ('V_FMAAK_F32', 'V_FMAAK_F16'):
lines.append(f"def {name.lower()}{suffix}(vdst, src0, vsrc1, K): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
else:
lines.append(f"{name.lower()}{suffix} = functools.partial({tgt}.{name}{seg})")
skip_exports = {'DPP8', 'DPP16'}
src_names = {name for _, name in src_enum.items()}
lines += [""] + [f"{name} = SrcEnum.{name}" for _, name in sorted(src_enum.items()) if name not in skip_exports]
if "NULL" in src_names: lines.append("OFF = NULL\n")
content = '\n'.join(lines)
if output_path is not None:
import pathlib
pathlib.Path(output_path).write_text(content)
return content
# ═══════════════════════════════════════════════════════════════════════════════
# PSEUDOCODE COMPILER: pseudocode -> Python
# ═══════════════════════════════════════════════════════════════════════════════
def _expr(e: str) -> str:
"""Expression transform: minimal - just fix syntax differences."""
e = e.strip()
e = e.replace('&&', ' and ').replace('||', ' or ').replace('<>', ' != ')
e = re.sub(r'!([^=])', r' not \1', e)
# Pack: { hi, lo } -> _pack(hi, lo)
e = re.sub(r'\{\s*(\w+\.u32)\s*,\s*(\w+\.u32)\s*\}', r'_pack32(\1, \2)', e)
def pack(m):
hi, lo = _expr(m[1].strip()), _expr(m[2].strip())
return f'_pack({hi}, {lo})'
e = re.sub(r'\{\s*([^,{}]+)\s*,\s*([^,{}]+)\s*\}', pack, e)
# Special constant: 1201'B(2.0 / PI) -> TWO_OVER_PI_1201 (precomputed 1201-bit 2/pi)
e = re.sub(r"1201'B\(2\.0\s*/\s*PI\)", "TWO_OVER_PI_1201", e)
# Literals: 1'0U -> 0, 32'I(x) -> (x), B(x) -> (x)
e = re.sub(r"\d+'([0-9a-fA-Fx]+)[UuFf]*", r'\1', e)
e = re.sub(r"\d+'[FIBU]\(", "(", e)
e = re.sub(r'\bB\(', '(', e)
e = re.sub(r'([0-9a-fA-Fx])ULL\b', r'\1', e)
e = re.sub(r'([0-9a-fA-Fx])LL\b', r'\1', e)
e = re.sub(r'([0-9a-fA-Fx])U\b', r'\1', e)
e = re.sub(r'(\d\.?\d*)F\b', r'\1', e)
e = re.sub(r'(\[laneId\])\.[uib]\d+', r'\1', e)
# Constants
e = e.replace('+INF', 'INF').replace('-INF', '(-INF)')
e = re.sub(r'NAN\.f\d+', 'float("nan")', e)
# Verilog bit slice: [start +: width] -> [start + width - 1 : start]
def convert_verilog_slice(m):
start, width = m.group(1).strip(), m.group(2).strip()
return f'[({start}) + ({width}) - 1 : ({start})]'
e = re.sub(r'\[([^:\[\]]+)\s*\+:\s*([^:\[\]]+)\]', convert_verilog_slice, e)
# Recursively process bracket contents
def process_brackets(s):
result, i = [], 0
while i < len(s):
if s[i] == '[':
depth, start = 1, i + 1
j = start
while j < len(s) and depth > 0:
if s[j] == '[': depth += 1
elif s[j] == ']': depth -= 1
j += 1
inner = _expr(s[start:j-1])
result.append('[' + inner + ']')
i = j
else:
result.append(s[i])
i += 1
return ''.join(result)
e = process_brackets(e)
# Ternary: a ? b : c -> (b if a else c)
while '?' in e:
depth, bracket, q = 0, 0, -1
for i, c in enumerate(e):
if c == '(': depth += 1
elif c == ')': depth -= 1
elif c == '[': bracket += 1
elif c == ']': bracket -= 1
elif c == '?' and depth == 0 and bracket == 0: q = i; break
if q < 0: break
depth, bracket, col = 0, 0, -1
for i in range(q + 1, len(e)):
if e[i] == '(': depth += 1
elif e[i] == ')': depth -= 1
elif e[i] == '[': bracket += 1
elif e[i] == ']': bracket -= 1
elif e[i] == ':' and depth == 0 and bracket == 0: col = i; break
if col < 0: break
cond, t, f = e[:q].strip(), e[q+1:col].strip(), e[col+1:].strip()
e = f'(({t}) if ({cond}) else ({f}))'
return e
def _assign(lhs: str, rhs: str) -> str:
"""Generate assignment. Bare tmp/SCC/etc get wrapped in Reg()."""
if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec', 'PC'):
return f"{lhs} = Reg({rhs})"
return f"{lhs} = {rhs}"
def compile_pseudocode(pseudocode: str) -> str:
"""Compile pseudocode to Python. Transforms are minimal - most syntax just works."""
raw_lines = pseudocode.strip().split('\n')
joined_lines: list[str] = []
for line in raw_lines:
line = line.strip()
if joined_lines and (joined_lines[-1].rstrip().endswith(('||', '&&', '(', ',')) or
(joined_lines[-1].count('(') > joined_lines[-1].count(')'))):
joined_lines[-1] = joined_lines[-1].rstrip() + ' ' + line
else:
joined_lines.append(line)
lines = []
indent, need_pass, in_first_match_loop = 0, False, False
for line in joined_lines:
line = line.strip()
if not line or line.startswith('//'): continue
if line.startswith('if '):
lines.append(' ' * indent + f"if {_expr(line[3:].rstrip(' then'))}:")
indent += 1
need_pass = True
elif line.startswith('elsif '):
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
lines.append(' ' * indent + f"elif {_expr(line[6:].rstrip(' then'))}:")
indent += 1
need_pass = True
elif line == 'else':
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
lines.append(' ' * indent + "else:")
indent += 1
need_pass = True
elif line.startswith('endif'):
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
need_pass = False
elif line.startswith('endfor'):
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
need_pass, in_first_match_loop = False, False
elif line.startswith('declare '):
pass
elif m := re.match(r'for (\w+) in (.+?)\s*:\s*(.+?) do', line):
start, end = _expr(m[2].strip()), _expr(m[3].strip())
lines.append(' ' * indent + f"for {m[1]} in range({start}, int({end})+1):")
indent += 1
need_pass, in_first_match_loop = True, True
elif '=' in line and not line.startswith('=='):
need_pass = False
line = line.rstrip(';')
if m := re.match(r'\{\s*D1\.[ui]1\s*,\s*D0\.[ui]64\s*\}\s*=\s*(.+)', line):
rhs = _expr(m[1])
lines.append(' ' * indent + f"_full = {rhs}")
lines.append(' ' * indent + f"D0.u64 = int(_full) & 0xffffffffffffffff")
lines.append(' ' * indent + f"D1 = Reg((int(_full) >> 64) & 1)")
elif any(op in line for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^=')):
for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^='):
if op in line:
lhs, rhs = line.split(op, 1)
lines.append(' ' * indent + f"{lhs.strip()} {op} {_expr(rhs.strip())}")
break
else:
lhs, rhs = line.split('=', 1)
lhs_s, rhs_s = lhs.strip(), rhs.strip()
stmt = _assign(lhs_s, _expr(rhs_s))
if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'):
stmt += "; break"
lines.append(' ' * indent + stmt)
if need_pass: lines.append(' ' * indent + "pass")
return '\n'.join(lines)
# ═══════════════════════════════════════════════════════════════════════════════
# GENERATE gen_pcode.py (compiled pseudocode functions)
# ═══════════════════════════════════════════════════════════════════════════════
def _extract_pseudocode(text: str) -> str | None:
"""Extract pseudocode from an instruction description snippet."""
lines, result, depth, in_lambda = text.split('\n'), [], 0, 0
for line in lines:
s = line.strip()
if not s: continue
if re.match(r'^\d+ of \d+$', s): continue
if re.match(r'^\d+\.\d+\..*Instructions', s): continue
if s.startswith('"RDNA') or s.startswith('AMD ') or s.startswith('CDNA'): continue
if s.startswith('Notes') or s.startswith('Functional examples'): break
if '= lambda(' in s: in_lambda += 1; continue
if in_lambda > 0:
if s.endswith(');'): in_lambda -= 1
continue
if s.startswith('if '): depth += 1
elif s.startswith('endif'): depth = max(0, depth - 1)
if s.endswith('.') and not any(p in s for p in ['D0', 'D1', 'S0', 'S1', 'S2', 'SCC', 'VCC', 'tmp', '=']): continue
if re.match(r'^[a-z].*\.$', s) and '=' not in s: continue
is_code = (
any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =', 'PC =']) or
any(p in s for p in ['D0[', 'D1[', 'S0[', 'S1[', 'S2[']) or
s.startswith(('if ', 'else', 'elsif', 'endif', 'declare ', 'for ', 'endfor', '//')) or
re.match(r'^[a-z_]+\s*=', s) or re.match(r'^[a-z_]+\[', s) or (depth > 0 and '=' in s)
)
if is_code: result.append(s)
return '\n'.join(result) if result else None
def _parse_pseudocode(instr_texts: list[str], defined_ops: dict, OP_ENUMS: list) -> dict:
"""Parse pseudocode from instruction text(s). Returns {enum_cls: {op: pseudocode}}."""
instructions: dict = {cls: {} for cls in OP_ENUMS}
# Process in reverse order so newer PDFs take priority
for instr_text in reversed(instr_texts):
matches = list(INST_PATTERN.finditer(instr_text))
for i, match in enumerate(matches):
name, opcode = match.group(1), int(match.group(2))
key = (name, opcode)
if key not in defined_ops: continue
start = match.end()
end = matches[i + 1].start() if i + 1 < len(matches) else start + 2000
snippet = instr_text[start:end].strip()
if (pseudocode := _extract_pseudocode(snippet)):
for enum_cls, enum_val in defined_ops[key]:
if enum_val not in instructions[enum_cls]:
instructions[enum_cls][enum_val] = pseudocode
return instructions
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS', 'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
'CVT_OFF_TABLE', 'ThreadMask', 'S1[i', 'C.i32', 'S[i]', 'in[', 'if n.', 'DST.u32', 'addrd = DST', 'addr = DST']
def _generate_pcode(instr_texts: list[str], arch: str, output_path: str | None = None) -> tuple[int, int]:
"""Generate compiled pseudocode functions. Returns (compiled_count, skipped_count)."""
import importlib
# Load op enums from autogen module
autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}")
OP_ENUMS = []
for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp']:
if hasattr(autogen, name): OP_ENUMS.append(getattr(autogen, name))
# Build defined_ops mapping
defined_ops: dict[tuple, list] = {}
for enum_cls in OP_ENUMS:
for op in enum_cls:
if op.name.startswith(('S_', 'V_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
by_cls = _parse_pseudocode(instr_texts, defined_ops, OP_ENUMS)
# Report coverage
total_found, total_ops = 0, 0
for enum_cls in OP_ENUMS:
total = sum(1 for op in enum_cls if op.name.startswith(('S_', 'V_')))
found = len(by_cls.get(enum_cls, {}))
total_found += found
total_ops += total
print(f"{enum_cls.__name__}: {found}/{total} ({100*found//total if total else 0}%)")
print(f"Total: {total_found}/{total_ops} ({100*total_found//total_ops}%)")
# Generate code
enum_names = [e.__name__ for e in OP_ENUMS]
lines = [f'''# autogenerated by generate.py - do not edit
# to regenerate: python -m extra.assembly.amd.generate --arch {arch}
# ruff: noqa: E501,F405,F403
# mypy: ignore-errors
from extra.assembly.amd.autogen.{arch} import {", ".join(enum_names)}
from extra.assembly.amd.pcode import *
''']
compiled_count, skipped_count = 0, 0
for enum_cls in OP_ENUMS:
cls_name = enum_cls.__name__
pseudocode_dict = by_cls.get(enum_cls, {})
if not pseudocode_dict: continue
fn_entries = []
for op, pc in pseudocode_dict.items():
if any(p in pc for p in UNSUPPORTED):
skipped_count += 1
continue
try:
code = compile_pseudocode(pc)
# Hardware behavior fixes (see pcode.py for detailed comments)
if op.name == 'V_DIV_FMAS_F32':
code = code.replace('D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)',
'D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)')
if op.name == 'V_DIV_FMAS_F64':
code = code.replace('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
'D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64)')
if op.name == 'V_DIV_SCALE_F32':
code = code.replace('D0.f32 = float("nan")', 'VCC = Reg(0x1); D0.f32 = float("nan")')
code = code.replace('elif S1.f32 == DENORM.f32:\n D0.f32 = ldexp(S0.f32, 64)', 'elif False:\n pass # denorm check moved to end')
code += '\nif S1.f32 == DENORM.f32:\n D0.f32 = float("nan")'
code = code.replace('elif exponent(S2.f32) <= 23:\n D0.f32 = ldexp(S0.f32, 64)',
'elif exponent(S2.f32) <= 23:\n VCC = Reg(0x1); D0.f32 = ldexp(S0.f32, 64)')
code = code.replace('elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)\n if S0.f32 == S2.f32:\n D0.f32 = ldexp(S0.f32, 64)',
'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)')
if op.name == 'V_DIV_SCALE_F64':
code = code.replace('D0.f64 = float("nan")', 'VCC = Reg(0x1); D0.f64 = float("nan")')
code = code.replace('elif S1.f64 == DENORM.f64:\n D0.f64 = ldexp(S0.f64, 128)', 'elif False:\n pass # denorm check moved to end')
code += '\nif S1.f64 == DENORM.f64:\n D0.f64 = float("nan")'
code = code.replace('elif exponent(S2.f64) <= 52:\n D0.f64 = ldexp(S0.f64, 128)',
'elif exponent(S2.f64) <= 52:\n VCC = Reg(0x1); D0.f64 = ldexp(S0.f64, 128)')
code = code.replace('elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)\n if S0.f64 == S2.f64:\n D0.f64 = ldexp(S0.f64, 128)',
'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)')
if op.name == 'V_DIV_FIXUP_F32':
code = code.replace('D0.f32 = ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))',
'D0.f32 = ((-OVERFLOW_F32) if (sign_out) else (OVERFLOW_F32)) if isNAN(S0.f32) else ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))')
if op.name == 'V_DIV_FIXUP_F64':
code = code.replace('D0.f64 = ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))',
'D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) if isNAN(S0.f64) else ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))')
if op.name == 'V_TRIG_PREOP_F64':
code = code.replace('result = F((TWO_OVER_PI_1201[1200 : 0] << shift.u32) & 0x1fffffffffffff)',
'result = float(((TWO_OVER_PI_1201[1200 : 0] << int(shift)) >> (1201 - 53)) & 0x1fffffffffffff)')
# Detect flags for result handling
is_64 = any(p in pc for p in ['D0.u64', 'D0.b64', 'D0.f64', 'D0.i64', 'D1.u64', 'D1.b64', 'D1.f64', 'D1.i64'])
has_d1 = '{ D1' in pc
if has_d1: is_64 = True
is_cmp = (cls_name == 'VOPCOp' or cls_name == 'VOP3Op') and 'D0.u64[laneId]' in pc
is_cmpx = (cls_name == 'VOPCOp' or cls_name == 'VOP3Op') and 'EXEC.u64[laneId]' in pc
is_div_scale = 'DIV_SCALE' in op.name
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
has_pc = 'PC' in pc
fn_name = f"_{cls_name}_{op.name}"
lines.append(f"def {fn_name}(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0, pc=0):")
for pc_line in pc.split('\n'):
lines.append(f" # {pc_line}")
combined = code + pc
regs = [('S0', 'Reg(s0)'), ('S1', 'Reg(s1)'), ('S2', 'Reg(s2)'),
('D0', 'Reg(s0)' if is_div_scale else 'Reg(d0)'), ('D1', 'Reg(0)'),
('SCC', 'Reg(scc)'), ('VCC', 'Reg(vcc)'), ('EXEC', 'Reg(exec_mask)'),
('tmp', 'Reg(0)'), ('saveexec', 'Reg(exec_mask)'), ('laneId', 'lane'),
('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)'), ('PC', 'Reg(pc)')]
used = {name for name, _ in regs if name in combined}
if 'EXEC_LO' in combined or 'EXEC_HI' in combined: used.add('EXEC')
if 'VCCZ' in combined: used.add('VCC')
if 'EXECZ' in combined: used.add('EXEC')
for name, init in regs:
if name in used: lines.append(f" {name} = {init}")
if 'EXEC_LO' in combined: lines.append(" EXEC_LO = SliceProxy(EXEC, 31, 0)")
if 'EXEC_HI' in combined: lines.append(" EXEC_HI = SliceProxy(EXEC, 63, 32)")
if 'VCCZ' in combined: lines.append(" VCCZ = Reg(1 if VCC._val == 0 else 0)")
if 'EXECZ' in combined: lines.append(" EXECZ = Reg(1 if EXEC._val == 0 else 0)")
lines.append(" # --- compiled pseudocode ---")
for line in code.split('\n'):
lines.append(f" {line}")
lines.append(" # --- end pseudocode ---")
d0_val = "D0._val" if 'D0' in used else "d0"
scc_val = "SCC._val & 1" if 'SCC' in used else "scc & 1"
lines.append(f" result = {{'d0': {d0_val}, 'scc': {scc_val}}}")
if has_sdst:
lines.append(" result['vcc_lane'] = (VCC._val >> lane) & 1")
elif 'VCC' in used:
lines.append(" if VCC._val != vcc: result['vcc_lane'] = (VCC._val >> lane) & 1")
if is_cmpx:
lines.append(" result['exec_lane'] = (EXEC._val >> lane) & 1")
elif 'EXEC' in used:
lines.append(" if EXEC._val != exec_mask: result['exec'] = EXEC._val")
if is_cmp:
lines.append(" result['vcc_lane'] = (D0._val >> lane) & 1")
if is_64:
lines.append(" result['d0_64'] = True")
if has_d1:
lines.append(" result['d1'] = D1._val & 1")
if has_pc:
lines.append(" _pc = PC._val if PC._val < 0x8000000000000000 else PC._val - 0x10000000000000000")
lines.append(" result['new_pc'] = _pc # absolute byte address")
lines.append(" return result")
lines.append("")
fn_entries.append((op, fn_name))
compiled_count += 1
except Exception as e:
print(f" Warning: Failed to compile {op.name}: {e}")
skipped_count += 1
if fn_entries:
lines.append(f'{cls_name}_FUNCTIONS = {{')
for op, fn_name in fn_entries:
lines.append(f" {cls_name}.{op.name}: {fn_name},")
lines.append('}')
lines.append('')
# Add V_WRITELANE_B32 for RDNA
if 'VOP3Op' in enum_names:
lines.append('''
# V_WRITELANE_B32: Write scalar to specific lane's VGPR (not in PDF pseudocode)
def _VOP3Op_V_WRITELANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
wr_lane = s1 & 0x1f # lane select (5 bits for wave32)
return {'d0': d0, 'scc': scc, 'vgpr_write': (wr_lane, vdst_idx, s0 & 0xffffffff)}
VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32
''')
lines.append('COMPILED_FUNCTIONS = {')
for enum_cls in OP_ENUMS:
cls_name = enum_cls.__name__
if by_cls.get(enum_cls): lines.append(f' {cls_name}: {cls_name}_FUNCTIONS,')
lines.append('}')
lines.append('')
lines.append('def get_compiled_functions(): return COMPILED_FUNCTIONS')
if output_path is not None:
from pathlib import Path
Path(output_path).write_text('\n'.join(lines))
return compiled_count, skipped_count
# ═══════════════════════════════════════════════════════════════════════════════
# MAIN ENTRY POINT
# ═══════════════════════════════════════════════════════════════════════════════
def generate(arch: str = "rdna3", output_dir: str = "extra/assembly/amd/autogen"):
"""Generate both __init__.py and gen_pcode.py for the given architecture."""
urls = PDF_URLS[arch]
if isinstance(urls, str): urls = [urls]
print(f"Parsing PDF(s) for {arch}...")
results = [_parse_single_pdf(url) for url in urls]
if len(results) == 1:
merged = results[0]
doc_name = merged["doc_name"]
instr_texts = [merged["instr_text"]]
else:
merged = _merge_results(results)
doc_name = "+".join(merged["doc_names"])
instr_texts = merged["instr_texts"]
# Generate __init__.py first (needed for pcode generation)
init_path = f"{output_dir}/{arch}/__init__.py"
_generate_dsl(merged, doc_name, init_path)
print(f"Generated {init_path}: SrcEnum ({len(merged['src_enum'])}) + {len(merged['enums'])} opcode enums + {len(merged['formats'])} format classes")
# Generate gen_pcode.py
print("\nCompiling pseudocode functions...")
pcode_path = f"{output_dir}/{arch}/gen_pcode.py"
compiled, skipped = _generate_pcode(instr_texts, arch, pcode_path)
print(f"Generated {pcode_path}: {compiled} compiled, {skipped} skipped")
if __name__ == "__main__":
import argparse, subprocess, sys
parser = argparse.ArgumentParser(description="Generate AMD ISA definitions from PDF")
parser.add_argument("--arch", choices=list(PDF_URLS.keys()) + ["all"], default="rdna3", help="Target architecture")
args = parser.parse_args()
if args.arch == "all":
procs = [subprocess.Popen([sys.executable, "-m", "extra.assembly.amd.generate", "--arch", arch]) for arch in PDF_URLS.keys()]
for p in procs: p.wait()
else:
generate(arch=args.arch)

View file

@ -295,7 +295,7 @@ def signext_from_bit(val, bit):
__all__ = [
# Classes
'Reg', 'SliceProxy', 'TypedView', 'ExecContext', 'compile_pseudocode',
'Reg', 'SliceProxy', 'TypedView',
# Pack functions
'_pack', '_pack32', 'pack', 'pack32',
# Constants
@ -623,588 +623,3 @@ class Reg:
def __ge__(s, o): return s._val >= int(o)
def __eq__(s, o): return s._val == int(o)
def __ne__(s, o): return s._val != int(o)
# ═══════════════════════════════════════════════════════════════════════════════
# COMPILER: pseudocode -> Python (minimal transforms)
# ═══════════════════════════════════════════════════════════════════════════════
def compile_pseudocode(pseudocode: str) -> str:
"""Compile pseudocode to Python. Transforms are minimal - most syntax just works."""
# Join continuation lines (lines ending with || or && or open paren)
raw_lines = pseudocode.strip().split('\n')
joined_lines: list[str] = []
for line in raw_lines:
line = line.strip()
if joined_lines and (joined_lines[-1].rstrip().endswith(('||', '&&', '(', ',')) or
(joined_lines[-1].count('(') > joined_lines[-1].count(')'))):
joined_lines[-1] = joined_lines[-1].rstrip() + ' ' + line
else:
joined_lines.append(line)
lines = []
indent, need_pass, in_first_match_loop = 0, False, False
for line in joined_lines:
line = line.strip()
if not line or line.startswith('//'): continue
# Control flow - only need pass before outdent (endif/endfor/else/elsif)
if line.startswith('if '):
lines.append(' ' * indent + f"if {_expr(line[3:].rstrip(' then'))}:")
indent += 1
need_pass = True
elif line.startswith('elsif '):
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
lines.append(' ' * indent + f"elif {_expr(line[6:].rstrip(' then'))}:")
indent += 1
need_pass = True
elif line == 'else':
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
lines.append(' ' * indent + "else:")
indent += 1
need_pass = True
elif line.startswith('endif'):
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
need_pass = False
elif line.startswith('endfor'):
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
need_pass, in_first_match_loop = False, False
elif line.startswith('declare '):
pass
elif m := re.match(r'for (\w+) in (.+?)\s*:\s*(.+?) do', line):
start, end = _expr(m[2].strip()), _expr(m[3].strip())
lines.append(' ' * indent + f"for {m[1]} in range({start}, int({end})+1):")
indent += 1
need_pass, in_first_match_loop = True, True
elif '=' in line and not line.startswith('=='):
need_pass = False
line = line.rstrip(';')
# Handle tuple unpacking: { D1.u1, D0.u64 } = expr
if m := re.match(r'\{\s*D1\.[ui]1\s*,\s*D0\.[ui]64\s*\}\s*=\s*(.+)', line):
rhs = _expr(m[1])
lines.append(' ' * indent + f"_full = {rhs}")
lines.append(' ' * indent + f"D0.u64 = int(_full) & 0xffffffffffffffff")
lines.append(' ' * indent + f"D1 = Reg((int(_full) >> 64) & 1)")
# Compound assignment
elif any(op in line for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^=')):
for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^='):
if op in line:
lhs, rhs = line.split(op, 1)
lines.append(' ' * indent + f"{lhs.strip()} {op} {_expr(rhs.strip())}")
break
else:
lhs, rhs = line.split('=', 1)
lhs_s, rhs_s = lhs.strip(), rhs.strip()
stmt = _assign(lhs_s, _expr(rhs_s))
# CLZ/CTZ pattern: assignment of loop var to tmp/D0.i32 in first-match loop needs break
if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'):
stmt += "; break"
lines.append(' ' * indent + stmt)
# If we ended with a control statement that needs a body, add pass
if need_pass: lines.append(' ' * indent + "pass")
return '\n'.join(lines)
def _assign(lhs: str, rhs: str) -> str:
"""Generate assignment. Bare tmp/SCC/etc get wrapped in Reg()."""
if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec', 'PC'):
return f"{lhs} = Reg({rhs})"
return f"{lhs} = {rhs}"
def _expr(e: str) -> str:
"""Expression transform: minimal - just fix syntax differences."""
e = e.strip()
e = e.replace('&&', ' and ').replace('||', ' or ').replace('<>', ' != ')
e = re.sub(r'!([^=])', r' not \1', e)
# Pack: { hi, lo } -> _pack(hi, lo)
e = re.sub(r'\{\s*(\w+\.u32)\s*,\s*(\w+\.u32)\s*\}', r'_pack32(\1, \2)', e)
def pack(m):
hi, lo = _expr(m[1].strip()), _expr(m[2].strip())
return f'_pack({hi}, {lo})'
e = re.sub(r'\{\s*([^,{}]+)\s*,\s*([^,{}]+)\s*\}', pack, e)
# Special constant: 1201'B(2.0 / PI) -> TWO_OVER_PI_1201 (precomputed 1201-bit 2/pi)
e = re.sub(r"1201'B\(2\.0\s*/\s*PI\)", "TWO_OVER_PI_1201", e)
# Literals: 1'0U -> 0, 32'I(x) -> (x), B(x) -> (x)
e = re.sub(r"\d+'([0-9a-fA-Fx]+)[UuFf]*", r'\1', e)
e = re.sub(r"\d+'[FIBU]\(", "(", e)
e = re.sub(r'\bB\(', '(', e) # Bare B( without digit prefix
e = re.sub(r'([0-9a-fA-Fx])ULL\b', r'\1', e)
e = re.sub(r'([0-9a-fA-Fx])LL\b', r'\1', e)
e = re.sub(r'([0-9a-fA-Fx])U\b', r'\1', e)
e = re.sub(r'(\d\.?\d*)F\b', r'\1', e)
# Remove redundant type suffix after lane access: VCC.u64[laneId].u64 -> VCC.u64[laneId]
e = re.sub(r'(\[laneId\])\.[uib]\d+', r'\1', e)
# Constants - INF is defined as an object supporting .f32/.f64 access
e = e.replace('+INF', 'INF').replace('-INF', '(-INF)')
e = re.sub(r'NAN\.f\d+', 'float("nan")', e)
# Verilog bit slice syntax: [start +: width] -> extract width bits starting at start
# Convert to Python slice: [start + width - 1 : start]
def convert_verilog_slice(m):
start, width = m.group(1).strip(), m.group(2).strip()
# Convert to high:low slice format
return f'[({start}) + ({width}) - 1 : ({start})]'
e = re.sub(r'\[([^:\[\]]+)\s*\+:\s*([^:\[\]]+)\]', convert_verilog_slice, e)
# Recursively process bracket contents to handle nested ternaries like S1.u32[x ? a : b]
def process_brackets(s):
result, i = [], 0
while i < len(s):
if s[i] == '[':
# Find matching ]
depth, start = 1, i + 1
j = start
while j < len(s) and depth > 0:
if s[j] == '[': depth += 1
elif s[j] == ']': depth -= 1
j += 1
inner = _expr(s[start:j-1]) # Recursively process bracket content
result.append('[' + inner + ']')
i = j
else:
result.append(s[i])
i += 1
return ''.join(result)
e = process_brackets(e)
# Ternary: a ? b : c -> (b if a else c)
while '?' in e:
depth, bracket, q = 0, 0, -1
for i, c in enumerate(e):
if c == '(': depth += 1
elif c == ')': depth -= 1
elif c == '[': bracket += 1
elif c == ']': bracket -= 1
elif c == '?' and depth == 0 and bracket == 0: q = i; break
if q < 0: break
depth, bracket, col = 0, 0, -1
for i in range(q + 1, len(e)):
if e[i] == '(': depth += 1
elif e[i] == ')': depth -= 1
elif e[i] == '[': bracket += 1
elif e[i] == ']': bracket -= 1
elif e[i] == ':' and depth == 0 and bracket == 0: col = i; break
if col < 0: break
cond, t, f = e[:q].strip(), e[q+1:col].strip(), e[col+1:].strip()
e = f'(({t}) if ({cond}) else ({f}))'
return e
# ═══════════════════════════════════════════════════════════════════════════════
# EXECUTION CONTEXT
# ═══════════════════════════════════════════════════════════════════════════════
class ExecContext:
"""Context for running compiled pseudocode."""
def __init__(self, s0=0, s1=0, s2=0, d0=0, scc=0, vcc=0, lane=0, exec_mask=MASK32, literal=0, vgprs=None, src0_idx=0, vdst_idx=0):
self.S0, self.S1, self.S2 = Reg(s0), Reg(s1), Reg(s2)
self.D0, self.D1 = Reg(d0), Reg(0)
self.SCC, self.VCC, self.EXEC = Reg(scc), Reg(vcc), Reg(exec_mask)
self.tmp, self.saveexec = Reg(0), Reg(exec_mask)
self.lane, self.laneId, self.literal = lane, lane, literal
self.SIMM16, self.SIMM32 = Reg(literal), Reg(literal)
self.VGPR = vgprs if vgprs is not None else {}
self.SRC0, self.VDST = Reg(src0_idx), Reg(vdst_idx)
def run(self, code: str):
"""Execute compiled code."""
# Start with module globals (helpers, aliases), then add instance-specific bindings
ns = dict(globals())
ns.update({
'S0': self.S0, 'S1': self.S1, 'S2': self.S2, 'D0': self.D0, 'D1': self.D1,
'SCC': self.SCC, 'VCC': self.VCC, 'EXEC': self.EXEC,
'EXEC_LO': SliceProxy(self.EXEC, 31, 0), 'EXEC_HI': SliceProxy(self.EXEC, 63, 32),
'tmp': self.tmp, 'saveexec': self.saveexec,
'lane': self.lane, 'laneId': self.laneId, 'literal': self.literal,
'SIMM16': self.SIMM16, 'SIMM32': self.SIMM32,
'VGPR': self.VGPR, 'SRC0': self.SRC0, 'VDST': self.VDST,
})
exec(code, ns)
# Sync rebinds: if register was reassigned to new Reg or value, copy it back
def _sync(ctx_reg, ns_val):
if isinstance(ns_val, Reg): ctx_reg._val = ns_val._val
else: ctx_reg._val = int(ns_val) & MASK64
if ns.get('SCC') is not self.SCC: _sync(self.SCC, ns['SCC'])
if ns.get('VCC') is not self.VCC: _sync(self.VCC, ns['VCC'])
if ns.get('EXEC') is not self.EXEC: _sync(self.EXEC, ns['EXEC'])
if ns.get('D0') is not self.D0: _sync(self.D0, ns['D0'])
if ns.get('D1') is not self.D1: _sync(self.D1, ns['D1'])
if ns.get('tmp') is not self.tmp: _sync(self.tmp, ns['tmp'])
if ns.get('saveexec') is not self.saveexec: _sync(self.saveexec, ns['saveexec'])
def result(self) -> dict:
return {"d0": self.D0._val, "scc": self.SCC._val & 1}
# ═══════════════════════════════════════════════════════════════════════════════
# PDF EXTRACTION AND CODE GENERATION
# ═══════════════════════════════════════════════════════════════════════════════
from extra.assembly.amd.dsl import PDF_URLS
INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
# Patterns that can't be handled by the DSL (require special handling in emu.py)
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
'CVT_OFF_TABLE', 'ThreadMask',
'S1[i', 'C.i32', 'S[i]', 'in[',
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST'] # Malformed pseudocode from PDF
def extract_pseudocode(text: str) -> str | None:
"""Extract pseudocode from an instruction description snippet."""
lines, result, depth, in_lambda = text.split('\n'), [], 0, 0
for line in lines:
s = line.strip()
if not s: continue
if re.match(r'^\d+ of \d+$', s): continue
if re.match(r'^\d+\.\d+\..*Instructions', s): continue
# Skip document headers (RDNA or CDNA)
if s.startswith('"RDNA') or s.startswith('AMD ') or s.startswith('CDNA'): continue
if s.startswith('Notes') or s.startswith('Functional examples'): break
# Track lambda definitions (e.g., BYTE_PERMUTE = lambda(data, sel) (...))
if '= lambda(' in s: in_lambda += 1; continue
if in_lambda > 0:
if s.endswith(');'): in_lambda -= 1
continue
if s.startswith('if '): depth += 1
elif s.startswith('endif'): depth = max(0, depth - 1)
if s.endswith('.') and not any(p in s for p in ['D0', 'D1', 'S0', 'S1', 'S2', 'SCC', 'VCC', 'tmp', '=']): continue
if re.match(r'^[a-z].*\.$', s) and '=' not in s: continue
is_code = (
any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =', 'PC =']) or
any(p in s for p in ['D0[', 'D1[', 'S0[', 'S1[', 'S2[']) or
s.startswith(('if ', 'else', 'elsif', 'endif', 'declare ', 'for ', 'endfor', '//')) or
re.match(r'^[a-z_]+\s*=', s) or re.match(r'^[a-z_]+\[', s) or (depth > 0 and '=' in s)
)
if is_code: result.append(s)
return '\n'.join(result) if result else None
def _get_op_enums(arch: str) -> list:
"""Dynamically load op enums from the arch-specific autogen module."""
import importlib
autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}")
# Deterministic order: common enums first, then arch-specific
enums = []
for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp']:
if hasattr(autogen, name): enums.append(getattr(autogen, name))
return enums
def _parse_pseudocode_from_single_pdf(url: str, defined_ops: dict, OP_ENUMS: list) -> dict:
"""Parse pseudocode from a single PDF."""
import pdfplumber
from tinygrad.helpers import fetch
pdf = pdfplumber.open(fetch(url))
total_pages = len(pdf.pages)
page_cache = {}
def get_page_text(i):
if i not in page_cache: page_cache[i] = pdf.pages[i].extract_text() or ''
return page_cache[i]
# Find the "Instructions" chapter - typically 10-40% through the document
instr_start = None
for i in range(int(total_pages * 0.1), int(total_pages * 0.5)):
if re.search(r'Chapter \d+\.\s+Instructions\b', get_page_text(i)):
instr_start = i
break
if instr_start is None: instr_start = total_pages // 3 # fallback
# Find end - stop at "Microcode Formats" chapter (typically 60-70% through)
instr_end = total_pages
search_starts = [int(total_pages * 0.6), int(total_pages * 0.5), instr_start]
for start in search_starts:
for i in range(start, min(start + 100, total_pages)):
if re.search(r'Chapter \d+\.\s+Microcode Formats', get_page_text(i)):
instr_end = i
break
if instr_end < total_pages: break
# Extract remaining pages (some already cached from chapter search)
all_text = '\n'.join(get_page_text(i) for i in range(instr_start, instr_end))
matches = list(INST_PATTERN.finditer(all_text))
instructions: dict = {cls: {} for cls in OP_ENUMS}
for i, match in enumerate(matches):
name, opcode = match.group(1), int(match.group(2))
key = (name, opcode)
if key not in defined_ops: continue
start = match.end()
end = matches[i + 1].start() if i + 1 < len(matches) else start + 2000
snippet = all_text[start:end].strip()
if (pseudocode := extract_pseudocode(snippet)):
# Assign to all enums that have this op (e.g., both VOPCOp and VOP3AOp)
for enum_cls, enum_val in defined_ops[key]:
instructions[enum_cls][enum_val] = pseudocode
return instructions
def parse_pseudocode_from_pdf(arch: str = "rdna3") -> dict:
"""Parse pseudocode from PDF(s) for all ops. Returns {enum_cls: {op: pseudocode}}."""
OP_ENUMS = _get_op_enums(arch)
# Build a dict from (name, opcode) -> list of (enum_cls, op) tuples
# Multiple enums can have the same op (e.g., VOPCOp and VOP3AOp both have V_CMP_* ops)
defined_ops: dict[tuple, list] = {}
for enum_cls in OP_ENUMS:
for op in enum_cls:
if op.name.startswith(('S_', 'V_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
urls = PDF_URLS[arch]
if isinstance(urls, str): urls = [urls]
# Parse all PDFs and merge (union of pseudocode)
# Reverse order so newer PDFs (RDNA3.5, CDNA4) take priority
instructions: dict = {cls: {} for cls in OP_ENUMS}
for url in reversed(urls):
result = _parse_pseudocode_from_single_pdf(url, defined_ops, OP_ENUMS)
for cls, ops in result.items():
for op, pseudocode in ops.items():
if op in instructions[cls]:
if instructions[cls][op] != pseudocode:
print(f" Ignoring {op.name} from older PDF:")
print(f" new: {instructions[cls][op]!r}")
print(f" old: {pseudocode!r}")
else:
instructions[cls][op] = pseudocode
return instructions
def generate_gen_pcode(output_path: str = "extra/assembly/amd/autogen/rdna3/gen_pcode.py", arch: str = "rdna3"):
"""Generate gen_pcode.py - compiled pseudocode functions for the emulator."""
from pathlib import Path
OP_ENUMS = _get_op_enums(arch)
print("Parsing pseudocode from PDF...")
by_cls = parse_pseudocode_from_pdf(arch)
total_found, total_ops = 0, 0
for enum_cls in OP_ENUMS:
total = sum(1 for op in enum_cls if op.name.startswith(('S_', 'V_')))
found = len(by_cls.get(enum_cls, {}))
total_found += found
total_ops += total
print(f"{enum_cls.__name__}: {found}/{total} ({100*found//total if total else 0}%)")
print(f"Total: {total_found}/{total_ops} ({100*total_found//total_ops}%)")
print("\nCompiling to pseudocode functions...")
# Build dynamic import line based on available enums
enum_names = [e.__name__ for e in OP_ENUMS]
lines = [f'''# autogenerated by pcode.py - do not edit
# to regenerate: python -m extra.assembly.amd.pcode --arch {arch}
# ruff: noqa: E501,F405,F403
# mypy: ignore-errors
from extra.assembly.amd.autogen.{arch} import {", ".join(enum_names)}
from extra.assembly.amd.pcode import *
''']
compiled_count, skipped_count = 0, 0
for enum_cls in OP_ENUMS:
cls_name = enum_cls.__name__
pseudocode_dict = by_cls.get(enum_cls, {})
if not pseudocode_dict: continue
fn_entries = []
for op, pc in pseudocode_dict.items():
if any(p in pc for p in UNSUPPORTED):
skipped_count += 1
continue
try:
code = compile_pseudocode(pc)
# NOTE: Do NOT add more code.replace() hacks here. Fix issues properly in the DSL
# (compile_pseudocode, helper functions, or Reg/TypedView classes) instead.
# V_DIV_FMAS_F32/F64: PDF page 449 says 2^32/2^64 but hardware behavior is more complex.
# The scale direction depends on S2 (the addend): if exponent(S2) > 127 (i.e., S2 >= 2.0),
# scale by 2^+64 (to unscale a numerator that was scaled). Otherwise scale by 2^-64
# (to unscale a denominator that was scaled).
if op.name == 'V_DIV_FMAS_F32':
code = code.replace(
'D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)',
'D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)')
if op.name == 'V_DIV_FMAS_F64':
code = code.replace(
'D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
'D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64)')
# V_DIV_SCALE_F32/F64: PDF page 463-464 has several bugs vs hardware behavior:
# 1. Zero case: hardware sets VCC=1 (PDF doesn't)
# 2. Denorm denom: hardware returns NaN (PDF says scale). VCC is set independently by exp diff check.
# 3. Tiny numer (exp<=23): hardware sets VCC=1 (PDF doesn't)
# 4. Result would be denorm: hardware doesn't scale, just sets VCC=1
if op.name == 'V_DIV_SCALE_F32':
# Fix 1: Set VCC=1 when zero operands produce NaN
code = code.replace(
'D0.f32 = float("nan")',
'VCC = Reg(0x1); D0.f32 = float("nan")')
# Fix 2: Denorm denom returns NaN. Must check this AFTER all VCC-setting logic runs.
# Insert at end of all branches, before the final result is used
code = code.replace(
'elif S1.f32 == DENORM.f32:\n D0.f32 = ldexp(S0.f32, 64)',
'elif False:\n pass # denorm check moved to end')
# Add denorm check at the very end - this overrides D0 but preserves VCC
code += '\nif S1.f32 == DENORM.f32:\n D0.f32 = float("nan")'
# Fix 3: Tiny numer should set VCC=1
code = code.replace(
'elif exponent(S2.f32) <= 23:\n D0.f32 = ldexp(S0.f32, 64)',
'elif exponent(S2.f32) <= 23:\n VCC = Reg(0x1); D0.f32 = ldexp(S0.f32, 64)')
# Fix 4: S2/S1 would be denorm - don't scale, just set VCC
code = code.replace(
'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)\n if S0.f32 == S2.f32:\n D0.f32 = ldexp(S0.f32, 64)',
'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)')
if op.name == 'V_DIV_SCALE_F64':
# Same fixes for f64 version
code = code.replace(
'D0.f64 = float("nan")',
'VCC = Reg(0x1); D0.f64 = float("nan")')
code = code.replace(
'elif S1.f64 == DENORM.f64:\n D0.f64 = ldexp(S0.f64, 128)',
'elif False:\n pass # denorm check moved to end')
code += '\nif S1.f64 == DENORM.f64:\n D0.f64 = float("nan")'
code = code.replace(
'elif exponent(S2.f64) <= 52:\n D0.f64 = ldexp(S0.f64, 128)',
'elif exponent(S2.f64) <= 52:\n VCC = Reg(0x1); D0.f64 = ldexp(S0.f64, 128)')
code = code.replace(
'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)\n if S0.f64 == S2.f64:\n D0.f64 = ldexp(S0.f64, 128)',
'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)')
# V_DIV_FIXUP_F32/F64: PDF doesn't check isNAN(S0), but hardware returns OVERFLOW if S0 is NaN.
# When division fails (e.g., due to denorm denom), S0 becomes NaN, and fixup should return ±inf.
if op.name == 'V_DIV_FIXUP_F32':
code = code.replace(
'D0.f32 = ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))',
'D0.f32 = ((-OVERFLOW_F32) if (sign_out) else (OVERFLOW_F32)) if isNAN(S0.f32) else ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))')
if op.name == 'V_DIV_FIXUP_F64':
code = code.replace(
'D0.f64 = ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))',
'D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) if isNAN(S0.f64) else ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))')
# V_TRIG_PREOP_F64: AMD pseudocode uses (x << shift) & mask but mask needs to extract TOP bits.
# The PDF shows: result = 64'F((1201'B(2.0/PI)[1200:0] << shift) & 1201'0x1fffffffffffff)
# Issues to fix:
# 1. After left shift, the interesting bits are at the top, not bottom - need >> (1201-53)
# 2. shift.u32 fails because shift is a plain int after * 53 - use int(shift)
# 3. 64'F(...) means convert int to float (not interpret as bit pattern) - use float()
if op.name == 'V_TRIG_PREOP_F64':
code = code.replace(
'result = F((TWO_OVER_PI_1201[1200 : 0] << shift.u32) & 0x1fffffffffffff)',
'result = float(((TWO_OVER_PI_1201[1200 : 0] << int(shift)) >> (1201 - 53)) & 0x1fffffffffffff)')
# Detect flags for result handling
is_64 = any(p in pc for p in ['D0.u64', 'D0.b64', 'D0.f64', 'D0.i64', 'D1.u64', 'D1.b64', 'D1.f64', 'D1.i64'])
has_d1 = '{ D1' in pc
if has_d1: is_64 = True
is_cmp = (cls_name == 'VOPCOp' or cls_name == 'VOP3Op') and 'D0.u64[laneId]' in pc
is_cmpx = (cls_name == 'VOPCOp' or cls_name == 'VOP3Op') and 'EXEC.u64[laneId]' in pc # V_CMPX writes to EXEC per-lane
# V_DIV_SCALE passes through S0 if no branch taken
is_div_scale = 'DIV_SCALE' in op.name
# VOP3SD instructions that write VCC per-lane (either via VCC.u64[laneId] or by setting VCC = 0/1)
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
# Instructions that use/modify PC
has_pc = 'PC' in pc
# Generate function with indented body
fn_name = f"_{cls_name}_{op.name}"
lines.append(f"def {fn_name}(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0, pc=0):")
# Add original pseudocode as comment
for pc_line in pc.split('\n'):
lines.append(f" # {pc_line}")
# Only create Reg objects for registers actually used in the pseudocode
combined = code + pc
regs = [('S0', 'Reg(s0)'), ('S1', 'Reg(s1)'), ('S2', 'Reg(s2)'),
('D0', 'Reg(s0)' if is_div_scale else 'Reg(d0)'), ('D1', 'Reg(0)'),
('SCC', 'Reg(scc)'), ('VCC', 'Reg(vcc)'), ('EXEC', 'Reg(exec_mask)'),
('tmp', 'Reg(0)'), ('saveexec', 'Reg(exec_mask)'), ('laneId', 'lane'),
('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)'),
('PC', 'Reg(pc)')] # PC is passed in as byte address
used = {name for name, _ in regs if name in combined}
# EXEC_LO/EXEC_HI need EXEC
if 'EXEC_LO' in combined or 'EXEC_HI' in combined: used.add('EXEC')
# VCCZ/EXECZ need VCC/EXEC
if 'VCCZ' in combined: used.add('VCC')
if 'EXECZ' in combined: used.add('EXEC')
for name, init in regs:
if name in used: lines.append(f" {name} = {init}")
if 'EXEC_LO' in combined: lines.append(" EXEC_LO = SliceProxy(EXEC, 31, 0)")
if 'EXEC_HI' in combined: lines.append(" EXEC_HI = SliceProxy(EXEC, 63, 32)")
# VCCZ = 1 if VCC == 0, EXECZ = 1 if EXEC == 0
if 'VCCZ' in combined: lines.append(" VCCZ = Reg(1 if VCC._val == 0 else 0)")
if 'EXECZ' in combined: lines.append(" EXECZ = Reg(1 if EXEC._val == 0 else 0)")
# Add compiled pseudocode with markers
lines.append(" # --- compiled pseudocode ---")
for line in code.split('\n'):
lines.append(f" {line}")
lines.append(" # --- end pseudocode ---")
# Generate result dict - use raw params if Reg wasn't created
d0_val = "D0._val" if 'D0' in used else "d0"
scc_val = "SCC._val & 1" if 'SCC' in used else "scc & 1"
lines.append(f" result = {{'d0': {d0_val}, 'scc': {scc_val}}}")
if has_sdst:
lines.append(" result['vcc_lane'] = (VCC._val >> lane) & 1")
elif 'VCC' in used:
lines.append(" if VCC._val != vcc: result['vcc_lane'] = (VCC._val >> lane) & 1")
if is_cmpx:
lines.append(" result['exec_lane'] = (EXEC._val >> lane) & 1")
elif 'EXEC' in used:
lines.append(" if EXEC._val != exec_mask: result['exec'] = EXEC._val")
if is_cmp:
lines.append(" result['vcc_lane'] = (D0._val >> lane) & 1")
if is_64:
lines.append(" result['d0_64'] = True")
if has_d1:
lines.append(" result['d1'] = D1._val & 1")
if has_pc:
# Return new PC as absolute byte address, emulator will compute delta
# Handle negative values (backward jumps): PC._val is stored as unsigned, convert to signed
lines.append(" _pc = PC._val if PC._val < 0x8000000000000000 else PC._val - 0x10000000000000000")
lines.append(" result['new_pc'] = _pc # absolute byte address")
lines.append(" return result")
lines.append("")
fn_entries.append((op, fn_name))
compiled_count += 1
except Exception as e:
print(f" Warning: Failed to compile {op.name}: {e}")
skipped_count += 1
if fn_entries:
lines.append(f'{cls_name}_FUNCTIONS = {{')
for op, fn_name in fn_entries:
lines.append(f" {cls_name}.{op.name}: {fn_name},")
lines.append('}')
lines.append('')
# Add manually implemented V_WRITELANE_B32 (not in PDF pseudocode, requires special vgpr_write handling)
# Only add for architectures that have VOP3Op (RDNA) not VOP3AOp/VOP3BOp (CDNA)
if 'VOP3Op' in enum_names:
lines.append('''
# V_WRITELANE_B32: Write scalar to specific lane's VGPR (not in PDF pseudocode)
def _VOP3Op_V_WRITELANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
wr_lane = s1 & 0x1f # lane select (5 bits for wave32)
return {'d0': d0, 'scc': scc, 'vgpr_write': (wr_lane, vdst_idx, s0 & 0xffffffff)}
VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32
''')
lines.append('COMPILED_FUNCTIONS = {')
for enum_cls in OP_ENUMS:
cls_name = enum_cls.__name__
if by_cls.get(enum_cls): lines.append(f' {cls_name}: {cls_name}_FUNCTIONS,')
lines.append('}')
lines.append('')
lines.append('def get_compiled_functions(): return COMPILED_FUNCTIONS')
Path(output_path).write_text('\n'.join(lines))
print(f"\nGenerated {output_path}: {compiled_count} compiled, {skipped_count} skipped")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Generate pseudocode functions from AMD ISA PDF")
parser.add_argument("--arch", choices=list(PDF_URLS.keys()) + ["all"], default="rdna3", help="Target architecture (default: rdna3)")
args = parser.parse_args()
if args.arch == "all":
for arch in PDF_URLS.keys():
generate_gen_pcode(output_path=f"extra/assembly/amd/autogen/{arch}/gen_pcode.py", arch=arch)
else:
generate_gen_pcode(output_path=f"extra/assembly/amd/autogen/{args.arch}/gen_pcode.py", arch=args.arch)

View file

@ -1,11 +1,54 @@
#!/usr/bin/env python3
"""Tests for the RDNA3 pseudocode DSL."""
import unittest
from extra.assembly.amd.pcode import (Reg, TypedView, SliceProxy, ExecContext, compile_pseudocode, _expr, MASK32, MASK64,
from extra.assembly.amd.pcode import (Reg, TypedView, SliceProxy, MASK32, MASK64,
_f32, _i32, _f16, _i16, f32_to_f16, _isnan, _bf16, _ibf16, bf16_to_f32, f32_to_bf16,
BYTE_PERMUTE, v_sad_u8, v_msad_u8)
from extra.assembly.amd.generate import compile_pseudocode, _expr
from extra.assembly.amd.autogen.rdna3.gen_pcode import _VOP3SDOp_V_DIV_SCALE_F32, _VOPCOp_V_CMP_CLASS_F32
class ExecContext:
"""Context for running compiled pseudocode (test-only)."""
def __init__(self, s0=0, s1=0, s2=0, d0=0, scc=0, vcc=0, lane=0, exec_mask=MASK32, literal=0, vgprs=None, src0_idx=0, vdst_idx=0):
self.S0, self.S1, self.S2 = Reg(s0), Reg(s1), Reg(s2)
self.D0, self.D1 = Reg(d0), Reg(0)
self.SCC, self.VCC, self.EXEC = Reg(scc), Reg(vcc), Reg(exec_mask)
self.tmp, self.saveexec = Reg(0), Reg(exec_mask)
self.lane, self.laneId, self.literal = lane, lane, literal
self.SIMM16, self.SIMM32 = Reg(literal), Reg(literal)
self.VGPR = vgprs if vgprs is not None else {}
self.SRC0, self.VDST = Reg(src0_idx), Reg(vdst_idx)
def run(self, code: str):
"""Execute compiled code."""
import extra.assembly.amd.pcode as pcode_mod
ns = {k: v for k, v in vars(pcode_mod).items() if not k.startswith('_') or k in ('_f32', '_i32', '_f16', '_i16', '_f64', '_i64', '_bf16', '_ibf16',
'_div', '_sext', '_isnan', '_isquietnan', '_issignalnan', '_fma',
'_gt_neg_zero', '_lt_neg_zero', '_signext', '_to_f16_bits', '_pack')}
ns.update({
'S0': self.S0, 'S1': self.S1, 'S2': self.S2, 'D0': self.D0, 'D1': self.D1,
'SCC': self.SCC, 'VCC': self.VCC, 'EXEC': self.EXEC,
'EXEC_LO': SliceProxy(self.EXEC, 31, 0), 'EXEC_HI': SliceProxy(self.EXEC, 63, 32),
'tmp': self.tmp, 'saveexec': self.saveexec,
'lane': self.lane, 'laneId': self.laneId, 'literal': self.literal,
'SIMM16': self.SIMM16, 'SIMM32': self.SIMM32,
'VGPR': self.VGPR, 'SRC0': self.SRC0, 'VDST': self.VDST,
})
exec(code, ns)
def _sync(ctx_reg, ns_val):
if isinstance(ns_val, Reg): ctx_reg._val = ns_val._val
else: ctx_reg._val = int(ns_val) & MASK64
if ns.get('SCC') is not self.SCC: _sync(self.SCC, ns['SCC'])
if ns.get('VCC') is not self.VCC: _sync(self.VCC, ns['VCC'])
if ns.get('EXEC') is not self.EXEC: _sync(self.EXEC, ns['EXEC'])
if ns.get('D0') is not self.D0: _sync(self.D0, ns['D0'])
if ns.get('D1') is not self.D1: _sync(self.D1, ns['D1'])
if ns.get('tmp') is not self.tmp: _sync(self.tmp, ns['tmp'])
if ns.get('saveexec') is not self.saveexec: _sync(self.saveexec, ns['saveexec'])
def result(self) -> dict:
return {"d0": self.D0._val, "scc": self.SCC._val & 1}
class TestReg(unittest.TestCase):
def test_u32_read(self):
r = Reg(0xDEADBEEF)