Revert "hotfix: skip test/amd in macpytest" (#14704)

* Revert "hotfix: skip test/amd in macpytest"

This reverts commit b7dade2adf.

* no llvm subprocess

* simpler

* sys.exec

* cleanup

* process safe

* diag

* arm ftz support

* 5 sec

* this one
This commit is contained in:
George Hotz 2026-02-13 08:00:24 +08:00 committed by GitHub
commit d3adb8428e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 215 additions and 339 deletions

View file

@ -56,7 +56,7 @@ jobs:
- name: Run pytest -nauto
run: |
source /tmp/tinygrad_pytest_ci/bin/activate
pytest -nauto --ignore test/amd/ --durations=20
pytest -nauto --durations=20
testmacbenchmark:
name: Mac Benchmark

View file

@ -1,6 +1,9 @@
"""Shared test helpers for RDNA3 tests."""
import shutil
"""Shared test helpers for AMD tests."""
import ctypes
from dataclasses import dataclass
from tinygrad.helpers import unwrap
from tinygrad.runtime.autogen import llvm
from tinygrad.runtime.support.elf import elf_loader
@dataclass
class KernelInfo:
@ -11,19 +14,6 @@ class KernelInfo:
buf_idxs: list[int] # indices into shared buffer pool
buf_sizes: list[int] # sizes for each buffer index
# LLVM tool detection (shared across test files)
def get_llvm_mc():
"""Find llvm-mc executable, preferring newer versions."""
for p in ['llvm-mc', 'llvm-mc-21', 'llvm-mc-20']:
if shutil.which(p): return p
raise FileNotFoundError("llvm-mc not found")
def get_llvm_objdump():
"""Find llvm-objdump executable, preferring newer versions."""
for p in ['llvm-objdump', 'llvm-objdump-21', 'llvm-objdump-20']:
if shutil.which(p): return p
raise FileNotFoundError("llvm-objdump not found")
ARCH_TO_TARGET:dict[str, list[str]] = {
"rdna3":["gfx1100"],
"rdna4":["gfx1200"],
@ -35,4 +25,107 @@ TARGET_TO_ARCH:dict[str, str] = {t:arch for arch,targets in ARCH_TO_TARGET.items
def get_target(arch:str) -> str: return ARCH_TO_TARGET[arch][0]
def get_mattr(arch:str) -> str:
return {"rdna3":"+real-true16,+wavefrontsize32", "rdna4":"+real-true16,+wavefrontsize32", "cdna":"+wavefrontsize64"}[arch]
return {"rdna3":"+real-true16,+wavefrontsize32", "rdna4":"+real-true16,+wavefrontsize32", "cdna":"+wavefrontsize64"}[arch]
# LLVM in-process assembler/disassembler (replaces llvm-mc and llvm-objdump subprocesses)
_SENTINEL = b'\xde\xad\xbe\xef'
_SENTINEL_ASM = '.byte 0xde, 0xad, 0xbe, 0xef'
def _cerr(): return ctypes.pointer(ctypes.pointer(ctypes.c_char()))
def _expect(x, err, ret=None):
if x: raise RuntimeError(unwrap(ctypes.cast(err.contents, ctypes.c_char_p).value).decode() if not isinstance(err, str) else err)
return ret
def _init_llvm():
for component in ['Target', 'TargetInfo', 'TargetMC', 'AsmParser', 'AsmPrinter', 'Disassembler']:
getattr(llvm, f'LLVMInitializeAMDGPU{component}')()
def _create_target_machine(mcpu:str, mattr:str) -> llvm.LLVMTargetMachineRef:
target = _expect(llvm.LLVMGetTargetFromTriple(b'amdgcn-amd-amdhsa', ctypes.pointer(tgt:=llvm.LLVMTargetRef()), err:=_cerr()), err, tgt)
return llvm.LLVMCreateTargetMachine(target, b'amdgcn-amd-amdhsa', mcpu.encode(), mattr.encode(),
llvm.LLVMCodeGenLevelDefault, llvm.LLVMRelocDefault, llvm.LLVMCodeModelDefault)
def _emit_obj(asm_text:str, mcpu:str, mattr:str, diag_errors:list[str]|None=None) -> bytes:
"""Assemble raw asm text into an ELF object using LLVM in-process."""
_init_llvm()
tm = _create_target_machine(mcpu, mattr)
ctx = llvm.LLVMContextCreate()
try:
errors = diag_errors if diag_errors is not None else []
@llvm.LLVMDiagnosticHandler
def handle_diag(diag_ref, _arg):
if llvm.LLVMGetDiagInfoSeverity(diag_ref) == llvm.LLVMDSError:
errors.append(ctypes.string_at(llvm.LLVMGetDiagInfoDescription(diag_ref)).decode())
llvm.LLVMContextSetDiagnosticHandler(ctx, handle_diag, None)
mod = llvm.LLVMModuleCreateWithNameInContext(b'asm', ctx)
llvm.LLVMSetTarget(mod, b'amdgcn-amd-amdhsa')
asm_bytes = asm_text.encode()
llvm.LLVMSetModuleInlineAsm2(mod, asm_bytes, len(asm_bytes))
buf = llvm.LLVMMemoryBufferRef()
_expect(llvm.LLVMTargetMachineEmitToMemoryBuffer(tm, mod, llvm.LLVMObjectFile, err:=_cerr(), ctypes.pointer(buf)), err)
obj = ctypes.string_at(llvm.LLVMGetBufferStart(buf), llvm.LLVMGetBufferSize(buf))
llvm.LLVMDisposeMemoryBuffer(buf)
llvm.LLVMDisposeModule(mod)
return obj
finally:
llvm.LLVMContextDispose(ctx)
llvm.LLVMDisposeTargetMachine(tm)
def _extract_text(obj:bytes) -> bytes:
"""Extract .text section from ELF object bytes."""
return next(s.content for s in elf_loader(obj)[1] if s.name == ".text")
def llvm_assemble(instrs:list[str], mcpu:str, mattr:str) -> list[bytes]:
"""Assemble instructions in one LLVM emission, return per-instruction bytes."""
if not instrs: return []
parts = []
for instr in instrs:
parts.append(instr)
parts.append(_SENTINEL_ASM)
text = _extract_text(_emit_obj('.text\n' + '\n'.join(parts) + '\n', mcpu, mattr))
results, start = [], 0
for _ in instrs:
idx = text.find(_SENTINEL, start)
assert idx != -1, "sentinel not found in .text section"
results.append(bytes(text[start:idx]))
start = idx + len(_SENTINEL)
return results
def llvm_disasm(code:bytes, mcpu:str, mattr:str) -> list[str]:
"""Disassemble raw bytes into instruction strings using LLVM."""
_init_llvm()
dc = llvm.LLVMCreateDisasmCPUFeatures(b'amdgcn-amd-amdhsa', mcpu.encode(), mattr.encode(), None, 0,
llvm.LLVMOpInfoCallback(0), llvm.LLVMSymbolLookupCallback(0))
if not dc: raise RuntimeError(f"failed to create disasm context for {mcpu}")
llvm.LLVMSetDisasmOptions(dc, 2 | 4) # PrintImmHex | AsmPrinterVariant
try:
buf = ctypes.create_string_buffer(256)
arr = (ctypes.c_uint8 * len(code)).from_buffer_copy(code)
results, offset = [], 0
while offset < len(code):
size = llvm.LLVMDisasmInstruction(dc, ctypes.cast(ctypes.addressof(arr) + offset, ctypes.POINTER(ctypes.c_uint8)),
len(code) - offset, 0, buf, 256)
if size == 0: break
results.append(buf.value.decode().strip())
offset += size
return results
finally:
llvm.LLVMDisasmDispose(dc)
def llvm_filter_valid_asm(tests:list[tuple[str, bytes]], mcpu:str, mattr:str) -> list[tuple[str, bytes]]:
"""Filter out tests where original ASM isn't valid on target, and where LLVM roundtrip doesn't match."""
if not tests: return []
# Assemble all instructions at once with sentinels and diagnostic handler to detect failures
parts, diag_errors = [], [] # type: ignore[var-annotated]
for asm, _ in tests:
parts.append(asm)
parts.append(_SENTINEL_ASM)
text = _extract_text(_emit_obj('.text\n' + '\n'.join(parts) + '\n', mcpu, mattr, diag_errors))
results, start = [], 0
for _ in tests:
idx = text.find(_SENTINEL, start)
assert idx != -1, "sentinel not found in .text section"
results.append(bytes(text[start:idx]))
start = idx + len(_SENTINEL)
# Invalid instructions produce 0 bytes; also filter where LLVM roundtrip doesn't match original
return [(asm, data) for (asm, data), chunk in zip(tests, results) if len(chunk) > 0 and chunk == data]

View file

@ -1,55 +1,23 @@
#!/usr/bin/env python3
"""Integration test: round-trip RDNA3 assembly through AMD toolchain."""
import unittest, io, sys
"""Integration test: round-trip RDNA3 assembly through LLVM toolchain."""
import unittest
from tinygrad.runtime.autogen.amd.rdna3.ins import *
from test.amd.helpers import llvm_assemble, llvm_disasm
def waitcnt(vmcnt: int = 0x3f, expcnt: int = 0x7, lgkmcnt: int = 0x3f) -> int:
return (expcnt & 0x7) | ((lgkmcnt & 0x3f) << 4) | ((vmcnt & 0x3f) << 10)
def disassemble(lib: bytes, arch: str = "gfx1100") -> str:
"""Disassemble ELF binary using tinygrad's compiler, return raw output."""
from tinygrad.runtime.support.compiler_amd import HIPCompiler
old_stdout = sys.stdout
sys.stdout = io.StringIO()
HIPCompiler(arch).disassemble(lib)
output = sys.stdout.getvalue()
sys.stdout = old_stdout
return output
def parse_disassembly(raw: str) -> list[str]:
"""Parse disassembly output to list of instruction mnemonics."""
lines = []
for line in raw.splitlines():
if line.startswith('\t'):
instr = line.split('//')[0].strip()
if instr: lines.append(instr)
return lines
def assemble_and_disassemble(instructions: list, arch: str = "gfx1100") -> list[str]:
"""Assemble instructions with our DSL, then disassemble with AMD toolchain."""
from tinygrad.runtime.support.compiler_amd import HIPCompiler
# Generate bytes from our DSL
def assemble_and_disassemble(instructions: list, mcpu: str = "gfx1100", mattr: str = "+real-true16,+wavefrontsize32") -> list[str]:
"""Assemble instructions with our DSL, then disassemble with LLVM."""
code_bytes = b''.join(inst.to_bytes() for inst in instructions)
# Wrap in minimal ELF-compatible assembly with .byte directives
byte_str = ', '.join(f'0x{b:02x}' for b in code_bytes)
asm_src = f".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n.byte {byte_str}\n"
# Assemble with AMD COMGR and disassemble
lib = HIPCompiler(arch).compile(asm_src)
return parse_disassembly(disassemble(lib, arch))
return llvm_disasm(code_bytes, mcpu, mattr)
class TestIntegration(unittest.TestCase):
"""Test our DSL output matches LLVM disassembly."""
def test_simple_sop1(self):
"""Test SOP1 instructions round-trip."""
instructions = [
s_mov_b32(s[0], s[1]),
s_mov_b32(s[2], 0),
s_not_b32(s[3], s[4]),
]
instructions = [s_mov_b32(s[0], s[1]), s_mov_b32(s[2], 0), s_not_b32(s[3], s[4])]
disasm = assemble_and_disassemble(instructions)
self.assertIn('s_mov_b32', disasm[0])
self.assertIn('s_mov_b32', disasm[1])
@ -57,11 +25,7 @@ class TestIntegration(unittest.TestCase):
def test_simple_sop2(self):
"""Test SOP2 instructions round-trip."""
instructions = [
s_add_u32(s[0], s[1], s[2]),
s_sub_u32(s[3], s[4], 10),
s_and_b32(s[5], s[6], s[7]),
]
instructions = [s_add_u32(s[0], s[1], s[2]), s_sub_u32(s[3], s[4], 10), s_and_b32(s[5], s[6], s[7])]
disasm = assemble_and_disassemble(instructions)
self.assertIn('s_add_u32', disasm[0])
self.assertIn('s_sub_u32', disasm[1])
@ -69,33 +33,22 @@ class TestIntegration(unittest.TestCase):
def test_simple_vop2(self):
"""Test VOP2 instructions round-trip."""
instructions = [
v_add_f32_e32(v[0], v[1], v[2]),
v_mul_f32_e32(v[3], 1.0, v[4]), # 1.0 is inline constant
v_and_b32_e32(v[5], 10, v[6]), # small inline constant
]
instructions = [v_add_f32_e32(v[0], v[1], v[2]), v_mul_f32_e32(v[3], 1.0, v[4]), v_and_b32_e32(v[5], 10, v[6])]
disasm = assemble_and_disassemble(instructions)
self.assertIn('v_add_f32', disasm[0])
self.assertIn('v_mul_f32', disasm[1])
def test_control_flow(self):
"""Test control flow instructions."""
instructions = [
s_waitcnt(simm16=waitcnt(lgkmcnt=0)),
s_endpgm(),
]
instructions = [s_waitcnt(simm16=waitcnt(lgkmcnt=0)), s_endpgm()]
disasm = assemble_and_disassemble(instructions)
self.assertIn('s_waitcnt', disasm[0])
self.assertIn('s_endpgm', disasm[1])
def test_memory_ops(self):
"""Test memory instructions."""
instructions = [
s_load_b32(s[0], s[0:1], NULL),
s_waitcnt(simm16=waitcnt(lgkmcnt=0)),
global_store_b32(addr=v[0:1], data=v[2], saddr=OFF),
s_endpgm(),
]
instructions = [s_load_b32(s[0], s[0:1], NULL), s_waitcnt(simm16=waitcnt(lgkmcnt=0)), global_store_b32(addr=v[0:1], data=v[2], saddr=OFF),
s_endpgm()]
disasm = assemble_and_disassemble(instructions)
self.assertIn('s_load_b32', disasm[0])
self.assertIn('s_waitcnt', disasm[1])
@ -103,156 +56,62 @@ class TestIntegration(unittest.TestCase):
def test_full_kernel(self):
"""Test a complete kernel similar to tinygrad output."""
# Simple kernel: load value, add 1, store back
instructions = [
# Get thread ID
v_mov_b32_e32(v[0], s[0]), # base addr low
v_mov_b32_e32(v[1], s[1]), # base addr high
# Load value
global_load_b32(vdst=v[2], addr=v[0:1], saddr=OFF),
s_waitcnt(simm16=waitcnt(vmcnt=0)),
# Add 1.0
v_add_f32_e32(v[2], 1.0, v[2]),
# Store result
global_store_b32(addr=v[0:1], data=v[2], saddr=OFF),
s_endpgm(),
]
instructions = [v_mov_b32_e32(v[0], s[0]), v_mov_b32_e32(v[1], s[1]), global_load_b32(vdst=v[2], addr=v[0:1], saddr=OFF),
s_waitcnt(simm16=waitcnt(vmcnt=0)), v_add_f32_e32(v[2], 1.0, v[2]), global_store_b32(addr=v[0:1], data=v[2], saddr=OFF),
s_endpgm()]
disasm = assemble_and_disassemble(instructions)
# Verify key instructions are present
self.assertTrue(any('global_load' in d for d in disasm))
self.assertTrue(any('v_add_f32' in d for d in disasm))
self.assertTrue(any('global_store' in d for d in disasm))
self.assertTrue(any('s_endpgm' in d for d in disasm))
def test_bytes_roundtrip(self):
"""Test that our bytes match what AMD assembler produces."""
from tinygrad.runtime.support.compiler_amd import HIPCompiler
# Simple instruction
"""Test that our bytes match what LLVM assembler produces."""
inst = s_mov_b32(s[0], s[1])
our_bytes = inst.to_bytes()
# Assemble same instruction with AMD toolchain
asm_src = ".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\ns_mov_b32 s0, s1\n"
compiler = HIPCompiler("gfx1100")
lib = compiler.compile(asm_src)
raw = disassemble(lib)
for line in raw.splitlines():
if 's_mov_b32' in line and '//' in line:
# Extract hex bytes from comment: "// 000000001300: BE800001"
comment = line.split('//')[1].strip()
hex_str = comment.split(':')[1].strip()
# Convert big-endian hex string to little-endian bytes
amd_bytes = bytes.fromhex(hex_str)[::-1] # reverse for little-endian
self.assertEqual(our_bytes, amd_bytes, f"Bytes mismatch: ours={our_bytes.hex()} AMD={amd_bytes.hex()}")
return
self.fail("Could not find s_mov_b32 in disassembly")
llvm_bytes = llvm_assemble(["s_mov_b32 s0, s1"], "gfx1100", "+real-true16,+wavefrontsize32")[0]
self.assertEqual(our_bytes, llvm_bytes, f"Bytes mismatch: ours={our_bytes.hex()} LLVM={llvm_bytes.hex()}")
class TestTinygradIntegration(unittest.TestCase):
"""Test that we can parse disassembled tinygrad kernels."""
"""Test that we can parse tinygrad kernel disassembly."""
def _get_kernel_code(self, op_fn) -> bytes:
from tinygrad import Tensor
from tinygrad.codegen import get_program
from tinygrad.renderer.llvmir import AMDLLVMRenderer
from tinygrad.runtime.support.compiler_amd import AMDLLVMCompiler
from tinygrad.runtime.support.elf import elf_loader
from tinygrad.uop.ops import Ops
result = op_fn(Tensor)
schedule = result.schedule()
sink_items = [si for si in schedule if si.ast.op == Ops.SINK]
assert len(sink_items) > 0, "No SINK in schedule"
renderer = AMDLLVMRenderer('gfx1100')
prg = get_program(sink_items[0].ast, renderer)
lib = AMDLLVMCompiler('gfx1100').compile(prg.src)
return next(s.content for s in elf_loader(lib)[1] if s.name == ".text")
def test_simple_add_kernel(self):
"""Generate a simple add kernel from tinygrad and verify disassembly."""
from tinygrad import Tensor
from tinygrad.codegen import get_program
from tinygrad.renderer.cstyle import AMDHIPRenderer
from tinygrad.runtime.support.compiler_amd import HIPCompiler
from tinygrad.uop.ops import Ops
# Create a computation that generates a real kernel
a = Tensor([1.0, 2.0, 3.0, 4.0]).realize()
b = Tensor([5.0, 6.0, 7.0, 8.0]).realize()
c = a + b
# Get schedule and find SINK
schedule = c.schedule()
sink_items = [si for si in schedule if si.ast.op == Ops.SINK]
self.assertTrue(len(sink_items) > 0, "No SINK in schedule")
# Generate program
renderer = AMDHIPRenderer('gfx1100')
prg = get_program(sink_items[0].ast, renderer)
self.assertIsNotNone(prg.src)
# Compile and disassemble
compiler = HIPCompiler('gfx1100')
lib = compiler.compile(prg.src)
raw_disasm = disassemble(lib)
instrs = parse_disassembly(raw_disasm)
# Verify we got some instructions
code = self._get_kernel_code(lambda T: T([1.0, 2.0, 3.0, 4.0]).realize() + T([5.0, 6.0, 7.0, 8.0]).realize())
instrs = llvm_disasm(code, "gfx1100", "+real-true16,+wavefrontsize32")
self.assertTrue(len(instrs) > 0, "No instructions in disassembly")
# Should have an endpgm
self.assertTrue(any('s_endpgm' in i for i in instrs), "Missing s_endpgm")
def test_matmul_kernel(self):
"""Generate a matmul kernel and verify disassembly has expected patterns."""
from tinygrad import Tensor
from tinygrad.codegen import get_program
from tinygrad.renderer.cstyle import AMDHIPRenderer
from tinygrad.runtime.support.compiler_amd import HIPCompiler
from tinygrad.uop.ops import Ops
# Create a small matmul
a = Tensor.rand(4, 4).realize()
b = Tensor.rand(4, 4).realize()
c = a @ b
# Get schedule
schedule = c.schedule()
sink_items = [si for si in schedule if si.ast.op == Ops.SINK]
self.assertTrue(len(sink_items) > 0)
# Generate and compile
renderer = AMDHIPRenderer('gfx1100')
prg = get_program(sink_items[0].ast, renderer)
compiler = HIPCompiler('gfx1100')
lib = compiler.compile(prg.src)
raw_disasm = disassemble(lib)
instrs = parse_disassembly(raw_disasm)
# Matmul should have multiply and add instructions
code = self._get_kernel_code(lambda T: T.rand(4, 4).realize() @ T.rand(4, 4).realize())
instrs = llvm_disasm(code, "gfx1100", "+real-true16,+wavefrontsize32")
has_mul = any('mul' in i.lower() for i in instrs)
has_add = any('add' in i.lower() for i in instrs)
self.assertTrue(has_mul or has_add, "Matmul should have mul/add ops")
def test_disasm_to_bytes_roundtrip(self):
"""Parse disassembled instructions and verify we can re-encode some of them."""
from tinygrad import Tensor
from tinygrad.codegen import get_program
from tinygrad.renderer.cstyle import AMDHIPRenderer
from tinygrad.runtime.support.compiler_amd import HIPCompiler
from tinygrad.uop.ops import Ops
# Simple kernel
a = Tensor([1.0, 2.0, 3.0, 4.0]).realize()
b = (a * 2.0)
schedule = b.schedule()
sink_items = [si for si in schedule if si.ast.op == Ops.SINK]
if not sink_items: return # skip if no kernel
renderer = AMDHIPRenderer('gfx1100')
prg = get_program(sink_items[0].ast, renderer)
compiler = HIPCompiler('gfx1100')
lib = compiler.compile(prg.src)
raw_disasm = disassemble(lib)
# Find s_endpgm and verify we can encode it
for line in raw_disasm.splitlines():
if 's_endpgm' in line and '//' in line:
# Extract bytes from comment
comment = line.split('//')[1].strip()
hex_str = comment.split(':')[1].strip()
amd_bytes = bytes.fromhex(hex_str)[::-1]
# Our encoding
our_inst = s_endpgm()
our_bytes = our_inst.to_bytes()
self.assertEqual(our_bytes, amd_bytes, f"s_endpgm mismatch: ours={our_bytes.hex()} AMD={amd_bytes.hex()}")
return
"""Verify s_endpgm encoding matches between our DSL and LLVM."""
our_bytes = s_endpgm().to_bytes()
llvm_bytes = llvm_assemble(["s_endpgm"], "gfx1100", "+real-true16,+wavefrontsize32")[0]
self.assertEqual(our_bytes, llvm_bytes, f"s_endpgm mismatch: ours={our_bytes.hex()} LLVM={llvm_bytes.hex()}")
if __name__ == "__main__":
unittest.main()

View file

@ -8,11 +8,11 @@ Only compute-relevant instruction formats are tested. Graphics-only formats not
- VIMAGE/VSAMPLE: image sampling instructions (RDNA4)
- VBUFFER: buffer instructions (RDNA4)
"""
import unittest, re, subprocess, functools
import unittest, re, functools
from tinygrad.helpers import fetch
from test.amd.disasm import disasm
from tinygrad.renderer.amd import decode_inst, detect_format
from test.amd.helpers import get_llvm_mc, get_target, get_mattr
from test.amd.helpers import llvm_assemble, llvm_filter_valid_asm, get_target, get_mattr
LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/llvmorg-21.1.0/llvm/test/MC/AMDGPU"
@ -74,42 +74,13 @@ def _get_tests_uncached(f: str, arch: str) -> list[tuple[str, bytes]]:
# Exclude v_interp_* (graphics-only, not on CDNA)
if arch == "cdna": tests = [(asm, data) for asm, data in tests if not asm.startswith('v_interp_')]
# Filter out tests where original ASM isn't valid on target (e.g., gfx9 tests with gfx942/gfx950 constraints)
if arch == "cdna" and not ('gfx942' in f or 'gfx950' in f or 'gfx90a' in f): tests = _filter_valid_asm(tests, arch)
if arch == "cdna" and not ('gfx942' in f or 'gfx950' in f or 'gfx90a' in f):
tests = llvm_filter_valid_asm(tests, get_target(arch), get_mattr(arch))
return tests
@functools.cache
def _get_tests(f: str, arch: str) -> list[tuple[str, bytes]]: return _get_tests_uncached(f, arch)
def _compile_asm_batch(instrs: list[str], arch: str = "rdna3", mcpu: str|None = None) -> list[bytes]:
if not instrs: return []
mcpu, mattr = mcpu or get_target(arch), get_mattr(arch)
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', f'-mattr={mattr}', '-show-encoding'],
input=".text\n" + "\n".join(instrs) + "\n", capture_output=True, text=True, timeout=30)
if result.returncode != 0: raise RuntimeError(f"llvm-mc failed: {result.stderr.strip()}")
return [bytes.fromhex(line.split('encoding:')[1].strip()[1:-1].replace('0x', '').replace(',', '').replace(' ', ''))
for line in result.stdout.split('\n') if 'encoding:' in line]
def _filter_valid_asm(tests: list[tuple[str, bytes]], arch: str) -> list[tuple[str, bytes]]:
"""Filter out tests where the original ASM isn't valid on the target (e.g., gfx9 tests with gfx942/gfx950 constraints)."""
if not tests: return []
mcpu = get_target(arch)
# Batch assemble all instructions, parse stderr to find which lines failed
instrs = [asm for asm, _ in tests]
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', '-show-encoding'],
input=".text\n" + "\n".join(instrs) + "\n", capture_output=True, text=True, timeout=30)
# Parse error lines from stderr (format: "<stdin>:N:..." where N is 1-indexed, line 1 is ".text")
failed_lines = set()
for line in result.stderr.split('\n'):
if m := re.match(r'<stdin>:(\d+):', line): failed_lines.add(int(m.group(1)) - 1) # -1 for .text, so line 2 -> index 1 -> tests[0]
# Also filter out tests where LLVM roundtrip doesn't match original (reserved bits set in original)
valid = [(asm, data) for i, (asm, data) in enumerate(tests) if (i + 1) not in failed_lines]
if not valid: return []
llvm_result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', '-show-encoding'],
input=".text\n" + "\n".join(asm for asm, _ in valid) + "\n", capture_output=True, text=True, timeout=30)
llvm_bytes = [bytes.fromhex(line.split('encoding:')[1].strip()[1:-1].replace('0x', '').replace(',', '').replace(' ', ''))
for line in llvm_result.stdout.split('\n') if 'encoding:' in line]
return [(asm, data) for (asm, data), lb in zip(valid, llvm_bytes) if lb == data]
def _make_test(f: str, arch: str, test_type: str):
def test(self):
tests = _get_tests(f, arch)
@ -160,7 +131,7 @@ def _make_test(f: str, arch: str, test_type: str):
print(f"{name}: {len(to_test)} passed, {skipped} skipped")
self.assertEqual(skipped, 0, f"{name}: {skipped} tests skipped, expected 0")
# Compare disasm->reassemble with original encoding (filter reserved bit cases where LLVM can't reproduce)
llvm_bytes = _compile_asm_batch([t[1] for t in to_test], arch, mcpu)
llvm_bytes = llvm_assemble([t[1] for t in to_test], mcpu, get_mattr(arch))
valid = [(enc, d, llvm) for (enc, d), llvm in zip(to_test, llvm_bytes) if llvm == enc]
print(f"{name}: {len(valid)}/{len(to_test)} matched LLVM encoding")
for enc, _, llvm in valid: self.assertEqual(llvm, enc)

View file

@ -1,6 +1,6 @@
#!/usr/bin/env python3
"""Test that invalid instructions raise exceptions through the mock GPU stack."""
import unittest, subprocess, os, time
import unittest, subprocess, os, sys, time
class TestMockGPUInvalidInstruction(unittest.TestCase):
def test_unsupported_instruction_raises(self):
@ -43,7 +43,7 @@ dev.synchronize()
env["HCQDEV_WAIT_TIMEOUT_MS"] = "10000"
st = time.perf_counter()
result = subprocess.run(["python", "-c", test_code], env=env, capture_output=True, text=True, timeout=60)
result = subprocess.run([sys.executable, "-c", test_code], env=env, capture_output=True, text=True, timeout=60)
elapsed = time.perf_counter() - st
self.assertNotEqual(result.returncode, 0, "should have raised")

View file

@ -1,27 +1,14 @@
#!/usr/bin/env python3
import unittest, subprocess
import unittest
from tinygrad.runtime.autogen.amd.rdna3.ins import *
from test.amd.helpers import get_llvm_mc
from test.amd.helpers import llvm_assemble
from test.amd.disasm import disasm
def llvm_assemble(asm: str) -> bytes:
"""Assemble using llvm-mc and return bytes."""
result = subprocess.run(
[get_llvm_mc(), "-triple=amdgcn", "-mcpu=gfx1100", "-show-encoding"],
input=asm, capture_output=True, text=True
)
out = b''
for line in result.stdout.split('\n'):
if 'encoding:' in line:
enc = line.split('encoding:')[1].strip()
enc = enc.strip('[]').replace('0x', '').replace(',', '')
out += bytes.fromhex(enc)
if not out: raise ValueError(f"no encoding found: {result.stdout} {result.stderr}")
return out
def _asm(asm: str) -> bytes: return llvm_assemble([asm], 'gfx1100', '+real-true16,+wavefrontsize32')[0]
class TestRDNA3Asm(unittest.TestCase):
def test_full_program(self):
"""Test the full program from rdna3fun.py matches llvm-mc output."""
"""Test the full program from rdna3fun.py matches LLVM output."""
program = [
v_bfe_u32(v[1], v[0], 10, 10),
s_load_b128(s[4:7], s[0:1], NULL),
@ -45,52 +32,35 @@ class TestRDNA3Asm(unittest.TestCase):
s_endpgm(),
]
asm = """
v_bfe_u32 v1, v0, 10, 10
s_load_b128 s[4:7], s[0:1], null
v_and_b32_e32 v0, 0x3FF, v0
s_mulk_i32 s3, 0x87
v_mad_u64_u32 v[1:2], null, s2, 3, v[1:2]
v_mul_u32_u24_e32 v0, 45, v0
v_ashrrev_i32_e32 v2, 31, v1
v_add3_u32 v0, v0, s3, v1
v_lshlrev_b64 v[2:3], 2, v[1:2]
v_ashrrev_i32_e32 v1, 31, v0
v_lshlrev_b64 v[0:1], 2, v[0:1]
s_waitcnt lgkmcnt(0)
v_add_co_u32 v2, vcc_lo, s6, v2
v_add_co_ci_u32_e32 v3, vcc_lo, s7, v3, vcc_lo
v_add_co_u32 v0, vcc_lo, s4, v0
global_load_b32 v2, v[2:3], off
v_add_co_ci_u32_e32 v1, vcc_lo, s5, v1, vcc_lo
s_waitcnt vmcnt(0)
global_store_b32 v[0:1], v2, off
s_endpgm
"""
expected = llvm_assemble(asm)
for inst,rt in zip(program, asm.strip().split("\n")): print(f"{disasm(inst):50s} {rt}")
actual = b''.join(inst.to_bytes() for inst in program)
self.assertEqual(actual, expected)
asm_lines = [
"v_bfe_u32 v1, v0, 10, 10", "s_load_b128 s[4:7], s[0:1], null", "v_and_b32_e32 v0, 0x3FF, v0",
"s_mulk_i32 s3, 0x87", "v_mad_u64_u32 v[1:2], null, s2, 3, v[1:2]", "v_mul_u32_u24_e32 v0, 45, v0",
"v_ashrrev_i32_e32 v2, 31, v1", "v_add3_u32 v0, v0, s3, v1", "v_lshlrev_b64 v[2:3], 2, v[1:2]",
"v_ashrrev_i32_e32 v1, 31, v0", "v_lshlrev_b64 v[0:1], 2, v[0:1]", "s_waitcnt lgkmcnt(0)",
"v_add_co_u32 v2, vcc_lo, s6, v2", "v_add_co_ci_u32_e32 v3, vcc_lo, s7, v3, vcc_lo",
"v_add_co_u32 v0, vcc_lo, s4, v0", "global_load_b32 v2, v[2:3], off",
"v_add_co_ci_u32_e32 v1, vcc_lo, s5, v1, vcc_lo", "s_waitcnt vmcnt(0)",
"global_store_b32 v[0:1], v2, off", "s_endpgm",
]
expected = llvm_assemble(asm_lines, 'gfx1100', '+real-true16,+wavefrontsize32')
for inst, rt in zip(program, asm_lines): print(f"{disasm(inst):50s} {rt}")
for inst, exp in zip(program, expected): self.assertEqual(inst.to_bytes(), exp)
def test_sop2_s_add_u32(self):
inst = SOP2(SOP2Op.S_ADD_U32, s[3], s[0], s[1])
expected = llvm_assemble("s_add_u32 s3, s0, s1")
self.assertEqual(inst.to_bytes(), expected)
self.assertEqual(inst.to_bytes(), _asm("s_add_u32 s3, s0, s1"))
def test_vop2_v_and_b32_inline_const(self):
inst = v_and_b32_e32(v[0], 10, v[0])
expected = llvm_assemble("v_and_b32_e32 v0, 10, v0")
self.assertEqual(inst.to_bytes(), expected)
self.assertEqual(inst.to_bytes(), _asm("v_and_b32_e32 v0, 10, v0"))
def test_sopp_s_endpgm(self):
inst = s_endpgm()
expected = llvm_assemble("s_endpgm")
self.assertEqual(inst.to_bytes(), expected)
self.assertEqual(inst.to_bytes(), _asm("s_endpgm"))
def test_sop1_s_mov_b32(self):
inst = s_mov_b32(s[0], s[1])
expected = llvm_assemble("s_mov_b32 s0, s1")
self.assertEqual(inst.to_bytes(), expected)
self.assertEqual(inst.to_bytes(), _asm("s_mov_b32 s0, s1"))
if __name__ == "__main__":
unittest.main()

View file

@ -1,9 +1,9 @@
#!/usr/bin/env python3
"""Roundtrip tests: generate tinygrad kernels, decode instructions, re-encode, verify match."""
import unittest, io, sys, re, subprocess, os
import unittest, io, sys, re
from tinygrad import Device
from tinygrad.renderer.amd import detect_format
from test.amd.helpers import get_llvm_mc, get_llvm_objdump, get_target, get_mattr
from test.amd.helpers import llvm_assemble, llvm_disasm, get_target, get_mattr
from test.amd.disasm import disasm
def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]:
@ -31,45 +31,18 @@ def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]:
def compile_asm(instr: str, arch: str = 'rdna3') -> bytes:
"""Compile a single instruction using LLVM."""
return compile_asm_batch([instr], arch)[0]
return llvm_assemble([instr], get_target(arch), get_mattr(arch))[0]
def compile_asm_batch(instrs: list[str], arch: str = 'rdna3') -> list[bytes]:
"""Compile multiple instructions with a single llvm-mc call."""
if not instrs: return []
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={get_target(arch)}', f'-mattr={get_mattr(arch)}', '-show-encoding'],
input=".text\n" + "\n".join(instrs) + "\n", capture_output=True, text=True)
if result.returncode != 0: raise RuntimeError(f"llvm-mc batch failed: {result.stderr.strip()}")
encodings = []
for line in result.stdout.split('\n'):
if 'encoding:' in line:
enc = line.split('encoding:')[1].strip()
if enc.startswith('[') and enc.endswith(']'):
encodings.append(bytes.fromhex(enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')))
if len(encodings) != len(instrs): raise RuntimeError(f"expected {len(instrs)} encodings, got {len(encodings)}")
return encodings
"""Compile multiple instructions with a single LLVM emission."""
return llvm_assemble(instrs, get_target(arch), get_mattr(arch))
def compile_and_disasm_batch(instrs: list[str], arch: str = 'rdna3') -> list[str]:
"""Compile instructions with LLVM and get LLVM's disassembly."""
import tempfile
if not instrs: return []
mcpu, mattr = get_target(arch), get_mattr(arch)
src = ".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n" + "\n".join(f" {instr}" for instr in instrs) + "\n"
with tempfile.NamedTemporaryFile(suffix='.o', delete=False) as f:
obj_path = f.name
try:
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', f'-mattr={mattr}', '-filetype=obj', '-o', obj_path],
input=src, capture_output=True, text=True)
if result.returncode != 0: raise RuntimeError(f"llvm-mc failed: {result.stderr.strip()}")
result = subprocess.run([get_llvm_objdump(), '-d', f'--mcpu={mcpu}', obj_path], capture_output=True, text=True)
if result.returncode != 0: raise RuntimeError(f"llvm-objdump failed: {result.stderr.strip()}")
results: list[str] = []
for line in result.stdout.splitlines():
if '//' not in line: continue
instr = line.split('//')[0].strip()
if instr: results.append(instr)
return results[:len(instrs)]
finally:
os.unlink(obj_path)
code = b''.join(llvm_assemble(instrs, mcpu, mattr))
return llvm_disasm(code, mcpu, mattr)[:len(instrs)]
@unittest.skipUnless(Device.DEFAULT == "AMD", "requires AMD device")
class TestTinygradKernelRoundtrip(unittest.TestCase):
@ -174,12 +147,12 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
if our_disasm is None:
disasm_skipped += 1
elif idx in disasm_llvm_map:
llvm_disasm = disasm_llvm_map[idx]
if our_disasm == llvm_disasm:
llvm_disasm_str = disasm_llvm_map[idx]
if our_disasm == llvm_disasm_str:
disasm_passed += 1
else:
disasm_failed += 1
disasm_failures.append(f"K{ki}@{offset}: ours='{our_disasm}' llvm='{llvm_disasm}'")
disasm_failures.append(f"K{ki}@{offset}: ours='{our_disasm}' llvm='{llvm_disasm_str}'")
else:
disasm_skipped += 1

View file

@ -89,7 +89,7 @@ def run_rocprof_decoder(blobs: list[bytes], lib: bytes, base: int, target: str):
try: rocprof.rocprof_trace_decoder_parse_data(copy_cb, trace_cb, isa_cb, None)
except Exception as e: exc = e
(t:=threading.Thread(target=worker, daemon=True)).start()
t.join(timeout=1)
t.join(timeout=5)
if exc is not None: raise exc
if t.is_alive(): raise RuntimeError("rocprof decoder timeout")
return occupancy_records, wave_insts

View file

@ -80,6 +80,7 @@ def extract_packet_encodings():
def extract_cdna_packet_sizes():
"""Extract CDNA pkt_fmt -> size mapping by running rocprof decoder to populate its hash table."""
if not _load_lib(): return None
from test.amd.test_sqtt_examples import run_rocprof_decoder
if not (pkl_path := next((EXAMPLES_DIR / "gfx950").glob("*.pkl"), None)): return None
@ -119,8 +120,7 @@ class TestSQTTMatchesBinary(unittest.TestCase):
def test_cdna_packet_sizes(self):
"""Extract and verify CDNA pkt_fmt -> size mapping from rocprof's hash table."""
if not (EXAMPLES_DIR / "gfx950").exists(): self.skipTest("no CDNA examples")
pkt_sizes = extract_cdna_packet_sizes()
self.assertIsNotNone(pkt_sizes, "failed to extract CDNA packet sizes")
if not (pkt_sizes := extract_cdna_packet_sizes()): self.skipTest("rocprof-trace-decoder not installed")
for pkt_fmt, size in CDNA_PKT_SIZES.items():
with self.subTest(pkt_fmt=pkt_fmt): self.assertEqual(pkt_sizes.get(pkt_fmt), size)

View file

@ -9,37 +9,47 @@ from __future__ import annotations
import ctypes, functools, re, platform, subprocess, tempfile
from typing import Any, Callable
# Set/restore DAZ+FTZ (denormals-are-zero + flush-to-zero) in MXCSR to match RDNA3 default float mode
# Set/restore DAZ+FTZ (denormals-are-zero + flush-to-zero) to match RDNA3 default float mode
# x86: MXCSR bits DAZ(6)+FTZ(15), ARM64: FPCR bit FZ(24)
# Only applied during emulator execution, restored afterward to avoid breaking hypothesis tests
@functools.cache
def _get_mxcsr_lib():
if platform.machine() not in ('x86_64', 'AMD64'): return None
try:
def _get_ftz_lib():
machine = platform.machine()
if machine in ('x86_64', 'AMD64'):
src = b'''
unsigned int get_mxcsr(void){unsigned int m;__asm__ __volatile__("stmxcsr %0":"=m"(m));return m;}
void set_mxcsr(unsigned int m){__asm__ __volatile__("ldmxcsr %0"::"m"(m));}
unsigned int get_fpcr(void){unsigned int m;__asm__ __volatile__("stmxcsr %0":"=m"(m));return m;}
void set_fpcr(unsigned int m){__asm__ __volatile__("ldmxcsr %0"::"m"(m));}
'''
ftz_bits = 0x8040 # DAZ (bit 6) + FTZ (bit 15)
elif machine in ('arm64', 'aarch64'):
src = b'''
unsigned int get_fpcr(void){unsigned long long v;__asm__ __volatile__("mrs %0,fpcr":"=r"(v));return(unsigned int)v;}
void set_fpcr(unsigned int m){unsigned long long v=m;__asm__ __volatile__("msr fpcr,%0"::"r"(v));}
'''
ftz_bits = 1 << 24 # FZ (bit 24)
else: return None, 0
try:
with tempfile.NamedTemporaryFile(suffix='.so', delete=False) as f:
subprocess.check_output(['clang', '-shared', '-O2', '-x', 'c', '-', '-o', f.name], input=src)
lib = ctypes.CDLL(f.name)
lib.get_mxcsr.restype = ctypes.c_uint32
lib.set_mxcsr.argtypes = [ctypes.c_uint32]
return lib
except Exception: return None
lib.get_fpcr.restype = ctypes.c_uint32
lib.set_fpcr.argtypes = [ctypes.c_uint32]
return lib, ftz_bits
except Exception: return None, 0
class _MXCSRContext:
"""Context manager to set DAZ+FTZ during emulator execution and restore afterward."""
__slots__ = ('_saved',)
def __enter__(self):
lib = _get_mxcsr_lib()
lib, ftz_bits = _get_ftz_lib()
if lib is None: return self
self._saved = lib.get_mxcsr()
lib.set_mxcsr(self._saved | 0x8040) # DAZ (bit 6) + FTZ (bit 15)
self._saved = lib.get_fpcr()
lib.set_fpcr(self._saved | ftz_bits)
return self
def __exit__(self, *args):
lib = _get_mxcsr_lib()
lib, _ = _get_ftz_lib()
if lib is None or not hasattr(self, '_saved'): return
lib.set_mxcsr(self._saved)
lib.set_fpcr(self._saved)
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
from tinygrad.dtype import dtypes