mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Revert "hotfix: skip test/amd in macpytest" (#14704)
* Revert "hotfix: skip test/amd in macpytest"
This reverts commit b7dade2adf.
* no llvm subprocess
* simpler
* sys.exec
* cleanup
* process safe
* diag
* arm ftz support
* 5 sec
* this one
This commit is contained in:
parent
d4bc5ab609
commit
d3adb8428e
10 changed files with 215 additions and 339 deletions
2
.github/workflows/benchmark.yml
vendored
2
.github/workflows/benchmark.yml
vendored
|
|
@ -56,7 +56,7 @@ jobs:
|
|||
- name: Run pytest -nauto
|
||||
run: |
|
||||
source /tmp/tinygrad_pytest_ci/bin/activate
|
||||
pytest -nauto --ignore test/amd/ --durations=20
|
||||
pytest -nauto --durations=20
|
||||
|
||||
testmacbenchmark:
|
||||
name: Mac Benchmark
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
"""Shared test helpers for RDNA3 tests."""
|
||||
import shutil
|
||||
"""Shared test helpers for AMD tests."""
|
||||
import ctypes
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import unwrap
|
||||
from tinygrad.runtime.autogen import llvm
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
|
||||
@dataclass
|
||||
class KernelInfo:
|
||||
|
|
@ -11,19 +14,6 @@ class KernelInfo:
|
|||
buf_idxs: list[int] # indices into shared buffer pool
|
||||
buf_sizes: list[int] # sizes for each buffer index
|
||||
|
||||
# LLVM tool detection (shared across test files)
|
||||
def get_llvm_mc():
|
||||
"""Find llvm-mc executable, preferring newer versions."""
|
||||
for p in ['llvm-mc', 'llvm-mc-21', 'llvm-mc-20']:
|
||||
if shutil.which(p): return p
|
||||
raise FileNotFoundError("llvm-mc not found")
|
||||
|
||||
def get_llvm_objdump():
|
||||
"""Find llvm-objdump executable, preferring newer versions."""
|
||||
for p in ['llvm-objdump', 'llvm-objdump-21', 'llvm-objdump-20']:
|
||||
if shutil.which(p): return p
|
||||
raise FileNotFoundError("llvm-objdump not found")
|
||||
|
||||
ARCH_TO_TARGET:dict[str, list[str]] = {
|
||||
"rdna3":["gfx1100"],
|
||||
"rdna4":["gfx1200"],
|
||||
|
|
@ -35,4 +25,107 @@ TARGET_TO_ARCH:dict[str, str] = {t:arch for arch,targets in ARCH_TO_TARGET.items
|
|||
def get_target(arch:str) -> str: return ARCH_TO_TARGET[arch][0]
|
||||
|
||||
def get_mattr(arch:str) -> str:
|
||||
return {"rdna3":"+real-true16,+wavefrontsize32", "rdna4":"+real-true16,+wavefrontsize32", "cdna":"+wavefrontsize64"}[arch]
|
||||
return {"rdna3":"+real-true16,+wavefrontsize32", "rdna4":"+real-true16,+wavefrontsize32", "cdna":"+wavefrontsize64"}[arch]
|
||||
|
||||
# LLVM in-process assembler/disassembler (replaces llvm-mc and llvm-objdump subprocesses)
|
||||
_SENTINEL = b'\xde\xad\xbe\xef'
|
||||
_SENTINEL_ASM = '.byte 0xde, 0xad, 0xbe, 0xef'
|
||||
|
||||
def _cerr():
  """Build a fresh pointer-to-pointer-to-c_char, handed to LLVM calls as their error-message out-parameter."""
  message_slot = ctypes.pointer(ctypes.c_char())
  return ctypes.pointer(message_slot)
|
||||
def _expect(x, err, ret=None):
  """Return `ret` when status `x` is falsy; otherwise raise RuntimeError.

  `err` is either a ready-made message string, or an error out-parameter (as produced by `_cerr`)
  whose pointee is reinterpreted as a C string and decoded to build the message.
  """
  if not x: return ret
  if isinstance(err, str): message = err
  else: message = unwrap(ctypes.cast(err.contents, ctypes.c_char_p).value).decode()
  raise RuntimeError(message)
|
||||
|
||||
def _init_llvm():
  """Call each LLVMInitializeAMDGPU<component> function exposed by the llvm bindings."""
  components = ('Target', 'TargetInfo', 'TargetMC', 'AsmParser', 'AsmPrinter', 'Disassembler')
  for name in components:
    initializer = getattr(llvm, f'LLVMInitializeAMDGPU{name}')
    initializer()
|
||||
|
||||
def _create_target_machine(mcpu:str, mattr:str) -> llvm.LLVMTargetMachineRef:
  """Look up the amdgcn-amd-amdhsa target and create a target machine for the given CPU and feature string.

  Raises RuntimeError (via `_expect`) when the target lookup reports failure.
  """
  triple = b'amdgcn-amd-amdhsa'
  target_ref = llvm.LLVMTargetRef()
  err = _cerr()
  failed = llvm.LLVMGetTargetFromTriple(triple, ctypes.pointer(target_ref), err)
  target = _expect(failed, err, target_ref)
  return llvm.LLVMCreateTargetMachine(target, triple, mcpu.encode(), mattr.encode(),
                                      llvm.LLVMCodeGenLevelDefault, llvm.LLVMRelocDefault, llvm.LLVMCodeModelDefault)
|
||||
|
||||
def _emit_obj(asm_text:str, mcpu:str, mattr:str, diag_errors:list[str]|None=None) -> bytes:
  """Assemble raw asm text into an ELF object using LLVM in-process.

  The text is attached to an otherwise empty module as module-level inline asm and the module is
  emitted as an object file into a memory buffer, whose bytes are returned.

  Error-severity diagnostics reported while the context is alive are appended to `diag_errors`
  when a list is supplied; otherwise they are collected into a throwaway list and dropped.

  Raises RuntimeError (via `_expect`) when the emit call reports failure. Every LLVM object
  created here is released on both the success and the failure path.
  """
  _init_llvm()
  tm = _create_target_machine(mcpu, mattr)
  ctx = llvm.LLVMContextCreate()
  try:
    errors = diag_errors if diag_errors is not None else []
    # handle_diag must stay referenced for as long as ctx can call back into it
    @llvm.LLVMDiagnosticHandler
    def handle_diag(diag_ref, _arg):
      if llvm.LLVMGetDiagInfoSeverity(diag_ref) == llvm.LLVMDSError:
        errors.append(ctypes.string_at(llvm.LLVMGetDiagInfoDescription(diag_ref)).decode())
    llvm.LLVMContextSetDiagnosticHandler(ctx, handle_diag, None)
    mod = llvm.LLVMModuleCreateWithNameInContext(b'asm', ctx)
    try:
      llvm.LLVMSetTarget(mod, b'amdgcn-amd-amdhsa')
      asm_bytes = asm_text.encode()
      llvm.LLVMSetModuleInlineAsm2(mod, asm_bytes, len(asm_bytes))
      buf = llvm.LLVMMemoryBufferRef()
      failed = llvm.LLVMTargetMachineEmitToMemoryBuffer(tm, mod, llvm.LLVMObjectFile, err:=_cerr(), ctypes.pointer(buf))
      try:
        _expect(failed, err)
        return ctypes.string_at(llvm.LLVMGetBufferStart(buf), llvm.LLVMGetBufferSize(buf))
      finally:
        # the out-buffer may have been populated even when emission reports failure, so release it on both paths
        llvm.LLVMDisposeMemoryBuffer(buf)
    finally:
      # dispose the module before its owning context, matching the order of the success path
      llvm.LLVMDisposeModule(mod)
  finally:
    llvm.LLVMContextDispose(ctx)
    llvm.LLVMDisposeTargetMachine(tm)
|
||||
|
||||
def _extract_text(obj:bytes) -> bytes:
  """Return the content of the first section named ".text" among the sections `elf_loader` reports for `obj`."""
  sections = elf_loader(obj)[1]
  text_contents = (sec.content for sec in sections if sec.name == ".text")
  return next(text_contents)
|
||||
|
||||
def llvm_assemble(instrs:list[str], mcpu:str, mattr:str) -> list[bytes]:
  """Assemble all instructions in a single LLVM emission and return one bytes object per instruction.

  Each instruction is followed by a 4-byte sentinel emitted with a `.byte` directive; the resulting
  .text section is then cut at successive sentinel occurrences. An instruction that contributes no
  bytes yields an empty chunk. Because the cut uses the first occurrence of the sentinel bytes, an
  encoding that itself contains those bytes would be split early.
  """
  if not instrs: return []
  # interleave: instr0, sentinel, instr1, sentinel, ...
  lines = [line for instr in instrs for line in (instr, _SENTINEL_ASM)]
  text = _extract_text(_emit_obj('.text\n' + '\n'.join(lines) + '\n', mcpu, mattr))
  encodings, cursor = [], 0
  for _ in instrs:
    end = text.find(_SENTINEL, cursor)
    assert end != -1, "sentinel not found in .text section"
    encodings.append(bytes(text[cursor:end]))
    cursor = end + len(_SENTINEL)
  return encodings
|
||||
|
||||
def llvm_disasm(code:bytes, mcpu:str, mattr:str) -> list[str]:
  """Disassemble raw machine-code bytes into a list of instruction strings using LLVM.

  Decoding walks `code` from offset 0 and stops at the end of the input or at the first position
  where LLVM reports a size of 0, so the result may cover only a prefix of `code`.
  Raises RuntimeError when the disassembler context cannot be created.
  """
  _init_llvm()
  dc = llvm.LLVMCreateDisasmCPUFeatures(b'amdgcn-amd-amdhsa', mcpu.encode(), mattr.encode(), None, 0,
                                        llvm.LLVMOpInfoCallback(0), llvm.LLVMSymbolLookupCallback(0))
  if not dc: raise RuntimeError(f"failed to create disasm context for {mcpu}")
  llvm.LLVMSetDisasmOptions(dc, 2 | 4)  # PrintImmHex | AsmPrinterVariant
  try:
    out_len = 256
    out = ctypes.create_string_buffer(out_len)  # reused for the text of every instruction
    raw = (ctypes.c_uint8 * len(code)).from_buffer_copy(code)
    base, total = ctypes.addressof(raw), len(code)
    decoded, pos = [], 0
    while pos < total:
      cursor = ctypes.cast(base + pos, ctypes.POINTER(ctypes.c_uint8))
      consumed = llvm.LLVMDisasmInstruction(dc, cursor, total - pos, 0, out, out_len)
      if consumed == 0: break
      decoded.append(out.value.decode().strip())
      pos += consumed
    return decoded
  finally:
    llvm.LLVMDisasmDispose(dc)
|
||||
|
||||
def llvm_filter_valid_asm(tests:list[tuple[str, bytes]], mcpu:str, mattr:str) -> list[tuple[str, bytes]]:
  """Keep only the (asm, expected_bytes) pairs that LLVM assembles to exactly `expected_bytes` on the target.

  All asm strings are assembled in one batch through `llvm_assemble`. A pair is dropped when its
  instruction produced no bytes (it was not accepted for this target) or when the bytes LLVM
  produced differ from the expected encoding.
  """
  if not tests: return []
  # one batched emission; llvm_assemble returns exactly one chunk per input instruction
  chunks = llvm_assemble([asm for asm, _ in tests], mcpu, mattr)
  # invalid instructions produce 0 bytes; also filter where the LLVM roundtrip doesn't match the original
  return [(asm, data) for (asm, data), chunk in zip(tests, chunks) if len(chunk) > 0 and chunk == data]
|
||||
|
|
|
|||
|
|
@ -1,55 +1,23 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Integration test: round-trip RDNA3 assembly through AMD toolchain."""
|
||||
import unittest, io, sys
|
||||
"""Integration test: round-trip RDNA3 assembly through LLVM toolchain."""
|
||||
import unittest
|
||||
from tinygrad.runtime.autogen.amd.rdna3.ins import *
|
||||
from test.amd.helpers import llvm_assemble, llvm_disasm
|
||||
|
||||
def waitcnt(vmcnt: int = 0x3f, expcnt: int = 0x7, lgkmcnt: int = 0x3f) -> int:
|
||||
return (expcnt & 0x7) | ((lgkmcnt & 0x3f) << 4) | ((vmcnt & 0x3f) << 10)
|
||||
|
||||
def disassemble(lib: bytes, arch: str = "gfx1100") -> str:
|
||||
"""Disassemble ELF binary using tinygrad's compiler, return raw output."""
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
old_stdout = sys.stdout
|
||||
sys.stdout = io.StringIO()
|
||||
HIPCompiler(arch).disassemble(lib)
|
||||
output = sys.stdout.getvalue()
|
||||
sys.stdout = old_stdout
|
||||
return output
|
||||
|
||||
def parse_disassembly(raw: str) -> list[str]:
|
||||
"""Parse disassembly output to list of instruction mnemonics."""
|
||||
lines = []
|
||||
for line in raw.splitlines():
|
||||
if line.startswith('\t'):
|
||||
instr = line.split('//')[0].strip()
|
||||
if instr: lines.append(instr)
|
||||
return lines
|
||||
|
||||
def assemble_and_disassemble(instructions: list, arch: str = "gfx1100") -> list[str]:
|
||||
"""Assemble instructions with our DSL, then disassemble with AMD toolchain."""
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
|
||||
# Generate bytes from our DSL
|
||||
def assemble_and_disassemble(instructions: list, mcpu: str = "gfx1100", mattr: str = "+real-true16,+wavefrontsize32") -> list[str]:
|
||||
"""Assemble instructions with our DSL, then disassemble with LLVM."""
|
||||
code_bytes = b''.join(inst.to_bytes() for inst in instructions)
|
||||
|
||||
# Wrap in minimal ELF-compatible assembly with .byte directives
|
||||
byte_str = ', '.join(f'0x{b:02x}' for b in code_bytes)
|
||||
asm_src = f".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n.byte {byte_str}\n"
|
||||
|
||||
# Assemble with AMD COMGR and disassemble
|
||||
lib = HIPCompiler(arch).compile(asm_src)
|
||||
return parse_disassembly(disassemble(lib, arch))
|
||||
return llvm_disasm(code_bytes, mcpu, mattr)
|
||||
|
||||
class TestIntegration(unittest.TestCase):
|
||||
"""Test our DSL output matches LLVM disassembly."""
|
||||
|
||||
def test_simple_sop1(self):
|
||||
"""Test SOP1 instructions round-trip."""
|
||||
instructions = [
|
||||
s_mov_b32(s[0], s[1]),
|
||||
s_mov_b32(s[2], 0),
|
||||
s_not_b32(s[3], s[4]),
|
||||
]
|
||||
instructions = [s_mov_b32(s[0], s[1]), s_mov_b32(s[2], 0), s_not_b32(s[3], s[4])]
|
||||
disasm = assemble_and_disassemble(instructions)
|
||||
self.assertIn('s_mov_b32', disasm[0])
|
||||
self.assertIn('s_mov_b32', disasm[1])
|
||||
|
|
@ -57,11 +25,7 @@ class TestIntegration(unittest.TestCase):
|
|||
|
||||
def test_simple_sop2(self):
|
||||
"""Test SOP2 instructions round-trip."""
|
||||
instructions = [
|
||||
s_add_u32(s[0], s[1], s[2]),
|
||||
s_sub_u32(s[3], s[4], 10),
|
||||
s_and_b32(s[5], s[6], s[7]),
|
||||
]
|
||||
instructions = [s_add_u32(s[0], s[1], s[2]), s_sub_u32(s[3], s[4], 10), s_and_b32(s[5], s[6], s[7])]
|
||||
disasm = assemble_and_disassemble(instructions)
|
||||
self.assertIn('s_add_u32', disasm[0])
|
||||
self.assertIn('s_sub_u32', disasm[1])
|
||||
|
|
@ -69,33 +33,22 @@ class TestIntegration(unittest.TestCase):
|
|||
|
||||
def test_simple_vop2(self):
|
||||
"""Test VOP2 instructions round-trip."""
|
||||
instructions = [
|
||||
v_add_f32_e32(v[0], v[1], v[2]),
|
||||
v_mul_f32_e32(v[3], 1.0, v[4]), # 1.0 is inline constant
|
||||
v_and_b32_e32(v[5], 10, v[6]), # small inline constant
|
||||
]
|
||||
instructions = [v_add_f32_e32(v[0], v[1], v[2]), v_mul_f32_e32(v[3], 1.0, v[4]), v_and_b32_e32(v[5], 10, v[6])]
|
||||
disasm = assemble_and_disassemble(instructions)
|
||||
self.assertIn('v_add_f32', disasm[0])
|
||||
self.assertIn('v_mul_f32', disasm[1])
|
||||
|
||||
def test_control_flow(self):
|
||||
"""Test control flow instructions."""
|
||||
instructions = [
|
||||
s_waitcnt(simm16=waitcnt(lgkmcnt=0)),
|
||||
s_endpgm(),
|
||||
]
|
||||
instructions = [s_waitcnt(simm16=waitcnt(lgkmcnt=0)), s_endpgm()]
|
||||
disasm = assemble_and_disassemble(instructions)
|
||||
self.assertIn('s_waitcnt', disasm[0])
|
||||
self.assertIn('s_endpgm', disasm[1])
|
||||
|
||||
def test_memory_ops(self):
|
||||
"""Test memory instructions."""
|
||||
instructions = [
|
||||
s_load_b32(s[0], s[0:1], NULL),
|
||||
s_waitcnt(simm16=waitcnt(lgkmcnt=0)),
|
||||
global_store_b32(addr=v[0:1], data=v[2], saddr=OFF),
|
||||
s_endpgm(),
|
||||
]
|
||||
instructions = [s_load_b32(s[0], s[0:1], NULL), s_waitcnt(simm16=waitcnt(lgkmcnt=0)), global_store_b32(addr=v[0:1], data=v[2], saddr=OFF),
|
||||
s_endpgm()]
|
||||
disasm = assemble_and_disassemble(instructions)
|
||||
self.assertIn('s_load_b32', disasm[0])
|
||||
self.assertIn('s_waitcnt', disasm[1])
|
||||
|
|
@ -103,156 +56,62 @@ class TestIntegration(unittest.TestCase):
|
|||
|
||||
def test_full_kernel(self):
|
||||
"""Test a complete kernel similar to tinygrad output."""
|
||||
# Simple kernel: load value, add 1, store back
|
||||
instructions = [
|
||||
# Get thread ID
|
||||
v_mov_b32_e32(v[0], s[0]), # base addr low
|
||||
v_mov_b32_e32(v[1], s[1]), # base addr high
|
||||
# Load value
|
||||
global_load_b32(vdst=v[2], addr=v[0:1], saddr=OFF),
|
||||
s_waitcnt(simm16=waitcnt(vmcnt=0)),
|
||||
# Add 1.0
|
||||
v_add_f32_e32(v[2], 1.0, v[2]),
|
||||
# Store result
|
||||
global_store_b32(addr=v[0:1], data=v[2], saddr=OFF),
|
||||
s_endpgm(),
|
||||
]
|
||||
instructions = [v_mov_b32_e32(v[0], s[0]), v_mov_b32_e32(v[1], s[1]), global_load_b32(vdst=v[2], addr=v[0:1], saddr=OFF),
|
||||
s_waitcnt(simm16=waitcnt(vmcnt=0)), v_add_f32_e32(v[2], 1.0, v[2]), global_store_b32(addr=v[0:1], data=v[2], saddr=OFF),
|
||||
s_endpgm()]
|
||||
disasm = assemble_and_disassemble(instructions)
|
||||
# Verify key instructions are present
|
||||
self.assertTrue(any('global_load' in d for d in disasm))
|
||||
self.assertTrue(any('v_add_f32' in d for d in disasm))
|
||||
self.assertTrue(any('global_store' in d for d in disasm))
|
||||
self.assertTrue(any('s_endpgm' in d for d in disasm))
|
||||
|
||||
def test_bytes_roundtrip(self):
|
||||
"""Test that our bytes match what AMD assembler produces."""
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
|
||||
# Simple instruction
|
||||
"""Test that our bytes match what LLVM assembler produces."""
|
||||
inst = s_mov_b32(s[0], s[1])
|
||||
our_bytes = inst.to_bytes()
|
||||
|
||||
# Assemble same instruction with AMD toolchain
|
||||
asm_src = ".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\ns_mov_b32 s0, s1\n"
|
||||
compiler = HIPCompiler("gfx1100")
|
||||
lib = compiler.compile(asm_src)
|
||||
raw = disassemble(lib)
|
||||
|
||||
for line in raw.splitlines():
|
||||
if 's_mov_b32' in line and '//' in line:
|
||||
# Extract hex bytes from comment: "// 000000001300: BE800001"
|
||||
comment = line.split('//')[1].strip()
|
||||
hex_str = comment.split(':')[1].strip()
|
||||
# Convert big-endian hex string to little-endian bytes
|
||||
amd_bytes = bytes.fromhex(hex_str)[::-1] # reverse for little-endian
|
||||
self.assertEqual(our_bytes, amd_bytes, f"Bytes mismatch: ours={our_bytes.hex()} AMD={amd_bytes.hex()}")
|
||||
return
|
||||
self.fail("Could not find s_mov_b32 in disassembly")
|
||||
llvm_bytes = llvm_assemble(["s_mov_b32 s0, s1"], "gfx1100", "+real-true16,+wavefrontsize32")[0]
|
||||
self.assertEqual(our_bytes, llvm_bytes, f"Bytes mismatch: ours={our_bytes.hex()} LLVM={llvm_bytes.hex()}")
|
||||
|
||||
class TestTinygradIntegration(unittest.TestCase):
|
||||
"""Test that we can parse disassembled tinygrad kernels."""
|
||||
"""Test that we can parse tinygrad kernel disassembly."""
|
||||
|
||||
def _get_kernel_code(self, op_fn) -> bytes:
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.codegen import get_program
|
||||
from tinygrad.renderer.llvmir import AMDLLVMRenderer
|
||||
from tinygrad.runtime.support.compiler_amd import AMDLLVMCompiler
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
from tinygrad.uop.ops import Ops
|
||||
|
||||
result = op_fn(Tensor)
|
||||
schedule = result.schedule()
|
||||
sink_items = [si for si in schedule if si.ast.op == Ops.SINK]
|
||||
assert len(sink_items) > 0, "No SINK in schedule"
|
||||
renderer = AMDLLVMRenderer('gfx1100')
|
||||
prg = get_program(sink_items[0].ast, renderer)
|
||||
lib = AMDLLVMCompiler('gfx1100').compile(prg.src)
|
||||
return next(s.content for s in elf_loader(lib)[1] if s.name == ".text")
|
||||
|
||||
def test_simple_add_kernel(self):
|
||||
"""Generate a simple add kernel from tinygrad and verify disassembly."""
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.codegen import get_program
|
||||
from tinygrad.renderer.cstyle import AMDHIPRenderer
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
from tinygrad.uop.ops import Ops
|
||||
|
||||
# Create a computation that generates a real kernel
|
||||
a = Tensor([1.0, 2.0, 3.0, 4.0]).realize()
|
||||
b = Tensor([5.0, 6.0, 7.0, 8.0]).realize()
|
||||
c = a + b
|
||||
|
||||
# Get schedule and find SINK
|
||||
schedule = c.schedule()
|
||||
sink_items = [si for si in schedule if si.ast.op == Ops.SINK]
|
||||
self.assertTrue(len(sink_items) > 0, "No SINK in schedule")
|
||||
|
||||
# Generate program
|
||||
renderer = AMDHIPRenderer('gfx1100')
|
||||
prg = get_program(sink_items[0].ast, renderer)
|
||||
self.assertIsNotNone(prg.src)
|
||||
|
||||
# Compile and disassemble
|
||||
compiler = HIPCompiler('gfx1100')
|
||||
lib = compiler.compile(prg.src)
|
||||
raw_disasm = disassemble(lib)
|
||||
instrs = parse_disassembly(raw_disasm)
|
||||
|
||||
# Verify we got some instructions
|
||||
code = self._get_kernel_code(lambda T: T([1.0, 2.0, 3.0, 4.0]).realize() + T([5.0, 6.0, 7.0, 8.0]).realize())
|
||||
instrs = llvm_disasm(code, "gfx1100", "+real-true16,+wavefrontsize32")
|
||||
self.assertTrue(len(instrs) > 0, "No instructions in disassembly")
|
||||
# Should have an endpgm
|
||||
self.assertTrue(any('s_endpgm' in i for i in instrs), "Missing s_endpgm")
|
||||
|
||||
def test_matmul_kernel(self):
|
||||
"""Generate a matmul kernel and verify disassembly has expected patterns."""
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.codegen import get_program
|
||||
from tinygrad.renderer.cstyle import AMDHIPRenderer
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
from tinygrad.uop.ops import Ops
|
||||
|
||||
# Create a small matmul
|
||||
a = Tensor.rand(4, 4).realize()
|
||||
b = Tensor.rand(4, 4).realize()
|
||||
c = a @ b
|
||||
|
||||
# Get schedule
|
||||
schedule = c.schedule()
|
||||
sink_items = [si for si in schedule if si.ast.op == Ops.SINK]
|
||||
self.assertTrue(len(sink_items) > 0)
|
||||
|
||||
# Generate and compile
|
||||
renderer = AMDHIPRenderer('gfx1100')
|
||||
prg = get_program(sink_items[0].ast, renderer)
|
||||
compiler = HIPCompiler('gfx1100')
|
||||
lib = compiler.compile(prg.src)
|
||||
raw_disasm = disassemble(lib)
|
||||
instrs = parse_disassembly(raw_disasm)
|
||||
|
||||
# Matmul should have multiply and add instructions
|
||||
code = self._get_kernel_code(lambda T: T.rand(4, 4).realize() @ T.rand(4, 4).realize())
|
||||
instrs = llvm_disasm(code, "gfx1100", "+real-true16,+wavefrontsize32")
|
||||
has_mul = any('mul' in i.lower() for i in instrs)
|
||||
has_add = any('add' in i.lower() for i in instrs)
|
||||
self.assertTrue(has_mul or has_add, "Matmul should have mul/add ops")
|
||||
|
||||
def test_disasm_to_bytes_roundtrip(self):
|
||||
"""Parse disassembled instructions and verify we can re-encode some of them."""
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.codegen import get_program
|
||||
from tinygrad.renderer.cstyle import AMDHIPRenderer
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
from tinygrad.uop.ops import Ops
|
||||
|
||||
# Simple kernel
|
||||
a = Tensor([1.0, 2.0, 3.0, 4.0]).realize()
|
||||
b = (a * 2.0)
|
||||
|
||||
schedule = b.schedule()
|
||||
sink_items = [si for si in schedule if si.ast.op == Ops.SINK]
|
||||
if not sink_items: return # skip if no kernel
|
||||
|
||||
renderer = AMDHIPRenderer('gfx1100')
|
||||
prg = get_program(sink_items[0].ast, renderer)
|
||||
compiler = HIPCompiler('gfx1100')
|
||||
lib = compiler.compile(prg.src)
|
||||
raw_disasm = disassemble(lib)
|
||||
|
||||
# Find s_endpgm and verify we can encode it
|
||||
for line in raw_disasm.splitlines():
|
||||
if 's_endpgm' in line and '//' in line:
|
||||
# Extract bytes from comment
|
||||
comment = line.split('//')[1].strip()
|
||||
hex_str = comment.split(':')[1].strip()
|
||||
amd_bytes = bytes.fromhex(hex_str)[::-1]
|
||||
|
||||
# Our encoding
|
||||
our_inst = s_endpgm()
|
||||
our_bytes = our_inst.to_bytes()
|
||||
|
||||
self.assertEqual(our_bytes, amd_bytes, f"s_endpgm mismatch: ours={our_bytes.hex()} AMD={amd_bytes.hex()}")
|
||||
return
|
||||
"""Verify s_endpgm encoding matches between our DSL and LLVM."""
|
||||
our_bytes = s_endpgm().to_bytes()
|
||||
llvm_bytes = llvm_assemble(["s_endpgm"], "gfx1100", "+real-true16,+wavefrontsize32")[0]
|
||||
self.assertEqual(our_bytes, llvm_bytes, f"s_endpgm mismatch: ours={our_bytes.hex()} LLVM={llvm_bytes.hex()}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -8,11 +8,11 @@ Only compute-relevant instruction formats are tested. Graphics-only formats not
|
|||
- VIMAGE/VSAMPLE: image sampling instructions (RDNA4)
|
||||
- VBUFFER: buffer instructions (RDNA4)
|
||||
"""
|
||||
import unittest, re, subprocess, functools
|
||||
import unittest, re, functools
|
||||
from tinygrad.helpers import fetch
|
||||
from test.amd.disasm import disasm
|
||||
from tinygrad.renderer.amd import decode_inst, detect_format
|
||||
from test.amd.helpers import get_llvm_mc, get_target, get_mattr
|
||||
from test.amd.helpers import llvm_assemble, llvm_filter_valid_asm, get_target, get_mattr
|
||||
|
||||
LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/llvmorg-21.1.0/llvm/test/MC/AMDGPU"
|
||||
|
||||
|
|
@ -74,42 +74,13 @@ def _get_tests_uncached(f: str, arch: str) -> list[tuple[str, bytes]]:
|
|||
# Exclude v_interp_* (graphics-only, not on CDNA)
|
||||
if arch == "cdna": tests = [(asm, data) for asm, data in tests if not asm.startswith('v_interp_')]
|
||||
# Filter out tests where original ASM isn't valid on target (e.g., gfx9 tests with gfx942/gfx950 constraints)
|
||||
if arch == "cdna" and not ('gfx942' in f or 'gfx950' in f or 'gfx90a' in f): tests = _filter_valid_asm(tests, arch)
|
||||
if arch == "cdna" and not ('gfx942' in f or 'gfx950' in f or 'gfx90a' in f):
|
||||
tests = llvm_filter_valid_asm(tests, get_target(arch), get_mattr(arch))
|
||||
return tests
|
||||
|
||||
@functools.cache
|
||||
def _get_tests(f: str, arch: str) -> list[tuple[str, bytes]]: return _get_tests_uncached(f, arch)
|
||||
|
||||
def _compile_asm_batch(instrs: list[str], arch: str = "rdna3", mcpu: str|None = None) -> list[bytes]:
|
||||
if not instrs: return []
|
||||
mcpu, mattr = mcpu or get_target(arch), get_mattr(arch)
|
||||
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', f'-mattr={mattr}', '-show-encoding'],
|
||||
input=".text\n" + "\n".join(instrs) + "\n", capture_output=True, text=True, timeout=30)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-mc failed: {result.stderr.strip()}")
|
||||
return [bytes.fromhex(line.split('encoding:')[1].strip()[1:-1].replace('0x', '').replace(',', '').replace(' ', ''))
|
||||
for line in result.stdout.split('\n') if 'encoding:' in line]
|
||||
|
||||
def _filter_valid_asm(tests: list[tuple[str, bytes]], arch: str) -> list[tuple[str, bytes]]:
|
||||
"""Filter out tests where the original ASM isn't valid on the target (e.g., gfx9 tests with gfx942/gfx950 constraints)."""
|
||||
if not tests: return []
|
||||
mcpu = get_target(arch)
|
||||
# Batch assemble all instructions, parse stderr to find which lines failed
|
||||
instrs = [asm for asm, _ in tests]
|
||||
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', '-show-encoding'],
|
||||
input=".text\n" + "\n".join(instrs) + "\n", capture_output=True, text=True, timeout=30)
|
||||
# Parse error lines from stderr (format: "<stdin>:N:..." where N is 1-indexed, line 1 is ".text")
|
||||
failed_lines = set()
|
||||
for line in result.stderr.split('\n'):
|
||||
if m := re.match(r'<stdin>:(\d+):', line): failed_lines.add(int(m.group(1)) - 1) # -1 for .text, so line 2 -> index 1 -> tests[0]
|
||||
# Also filter out tests where LLVM roundtrip doesn't match original (reserved bits set in original)
|
||||
valid = [(asm, data) for i, (asm, data) in enumerate(tests) if (i + 1) not in failed_lines]
|
||||
if not valid: return []
|
||||
llvm_result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', '-show-encoding'],
|
||||
input=".text\n" + "\n".join(asm for asm, _ in valid) + "\n", capture_output=True, text=True, timeout=30)
|
||||
llvm_bytes = [bytes.fromhex(line.split('encoding:')[1].strip()[1:-1].replace('0x', '').replace(',', '').replace(' ', ''))
|
||||
for line in llvm_result.stdout.split('\n') if 'encoding:' in line]
|
||||
return [(asm, data) for (asm, data), lb in zip(valid, llvm_bytes) if lb == data]
|
||||
|
||||
def _make_test(f: str, arch: str, test_type: str):
|
||||
def test(self):
|
||||
tests = _get_tests(f, arch)
|
||||
|
|
@ -160,7 +131,7 @@ def _make_test(f: str, arch: str, test_type: str):
|
|||
print(f"{name}: {len(to_test)} passed, {skipped} skipped")
|
||||
self.assertEqual(skipped, 0, f"{name}: {skipped} tests skipped, expected 0")
|
||||
# Compare disasm->reassemble with original encoding (filter reserved bit cases where LLVM can't reproduce)
|
||||
llvm_bytes = _compile_asm_batch([t[1] for t in to_test], arch, mcpu)
|
||||
llvm_bytes = llvm_assemble([t[1] for t in to_test], mcpu, get_mattr(arch))
|
||||
valid = [(enc, d, llvm) for (enc, d), llvm in zip(to_test, llvm_bytes) if llvm == enc]
|
||||
print(f"{name}: {len(valid)}/{len(to_test)} matched LLVM encoding")
|
||||
for enc, _, llvm in valid: self.assertEqual(llvm, enc)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Test that invalid instructions raise exceptions through the mock GPU stack."""
|
||||
import unittest, subprocess, os, time
|
||||
import unittest, subprocess, os, sys, time
|
||||
|
||||
class TestMockGPUInvalidInstruction(unittest.TestCase):
|
||||
def test_unsupported_instruction_raises(self):
|
||||
|
|
@ -43,7 +43,7 @@ dev.synchronize()
|
|||
env["HCQDEV_WAIT_TIMEOUT_MS"] = "10000"
|
||||
|
||||
st = time.perf_counter()
|
||||
result = subprocess.run(["python", "-c", test_code], env=env, capture_output=True, text=True, timeout=60)
|
||||
result = subprocess.run([sys.executable, "-c", test_code], env=env, capture_output=True, text=True, timeout=60)
|
||||
elapsed = time.perf_counter() - st
|
||||
|
||||
self.assertNotEqual(result.returncode, 0, "should have raised")
|
||||
|
|
|
|||
|
|
@ -1,27 +1,14 @@
|
|||
#!/usr/bin/env python3
|
||||
import unittest, subprocess
|
||||
import unittest
|
||||
from tinygrad.runtime.autogen.amd.rdna3.ins import *
|
||||
from test.amd.helpers import get_llvm_mc
|
||||
from test.amd.helpers import llvm_assemble
|
||||
from test.amd.disasm import disasm
|
||||
|
||||
def llvm_assemble(asm: str) -> bytes:
|
||||
"""Assemble using llvm-mc and return bytes."""
|
||||
result = subprocess.run(
|
||||
[get_llvm_mc(), "-triple=amdgcn", "-mcpu=gfx1100", "-show-encoding"],
|
||||
input=asm, capture_output=True, text=True
|
||||
)
|
||||
out = b''
|
||||
for line in result.stdout.split('\n'):
|
||||
if 'encoding:' in line:
|
||||
enc = line.split('encoding:')[1].strip()
|
||||
enc = enc.strip('[]').replace('0x', '').replace(',', '')
|
||||
out += bytes.fromhex(enc)
|
||||
if not out: raise ValueError(f"no encoding found: {result.stdout} {result.stderr}")
|
||||
return out
|
||||
def _asm(asm: str) -> bytes: return llvm_assemble([asm], 'gfx1100', '+real-true16,+wavefrontsize32')[0]
|
||||
|
||||
class TestRDNA3Asm(unittest.TestCase):
|
||||
def test_full_program(self):
|
||||
"""Test the full program from rdna3fun.py matches llvm-mc output."""
|
||||
"""Test the full program from rdna3fun.py matches LLVM output."""
|
||||
program = [
|
||||
v_bfe_u32(v[1], v[0], 10, 10),
|
||||
s_load_b128(s[4:7], s[0:1], NULL),
|
||||
|
|
@ -45,52 +32,35 @@ class TestRDNA3Asm(unittest.TestCase):
|
|||
s_endpgm(),
|
||||
]
|
||||
|
||||
asm = """
|
||||
v_bfe_u32 v1, v0, 10, 10
|
||||
s_load_b128 s[4:7], s[0:1], null
|
||||
v_and_b32_e32 v0, 0x3FF, v0
|
||||
s_mulk_i32 s3, 0x87
|
||||
v_mad_u64_u32 v[1:2], null, s2, 3, v[1:2]
|
||||
v_mul_u32_u24_e32 v0, 45, v0
|
||||
v_ashrrev_i32_e32 v2, 31, v1
|
||||
v_add3_u32 v0, v0, s3, v1
|
||||
v_lshlrev_b64 v[2:3], 2, v[1:2]
|
||||
v_ashrrev_i32_e32 v1, 31, v0
|
||||
v_lshlrev_b64 v[0:1], 2, v[0:1]
|
||||
s_waitcnt lgkmcnt(0)
|
||||
v_add_co_u32 v2, vcc_lo, s6, v2
|
||||
v_add_co_ci_u32_e32 v3, vcc_lo, s7, v3, vcc_lo
|
||||
v_add_co_u32 v0, vcc_lo, s4, v0
|
||||
global_load_b32 v2, v[2:3], off
|
||||
v_add_co_ci_u32_e32 v1, vcc_lo, s5, v1, vcc_lo
|
||||
s_waitcnt vmcnt(0)
|
||||
global_store_b32 v[0:1], v2, off
|
||||
s_endpgm
|
||||
"""
|
||||
expected = llvm_assemble(asm)
|
||||
for inst,rt in zip(program, asm.strip().split("\n")): print(f"{disasm(inst):50s} {rt}")
|
||||
actual = b''.join(inst.to_bytes() for inst in program)
|
||||
self.assertEqual(actual, expected)
|
||||
asm_lines = [
|
||||
"v_bfe_u32 v1, v0, 10, 10", "s_load_b128 s[4:7], s[0:1], null", "v_and_b32_e32 v0, 0x3FF, v0",
|
||||
"s_mulk_i32 s3, 0x87", "v_mad_u64_u32 v[1:2], null, s2, 3, v[1:2]", "v_mul_u32_u24_e32 v0, 45, v0",
|
||||
"v_ashrrev_i32_e32 v2, 31, v1", "v_add3_u32 v0, v0, s3, v1", "v_lshlrev_b64 v[2:3], 2, v[1:2]",
|
||||
"v_ashrrev_i32_e32 v1, 31, v0", "v_lshlrev_b64 v[0:1], 2, v[0:1]", "s_waitcnt lgkmcnt(0)",
|
||||
"v_add_co_u32 v2, vcc_lo, s6, v2", "v_add_co_ci_u32_e32 v3, vcc_lo, s7, v3, vcc_lo",
|
||||
"v_add_co_u32 v0, vcc_lo, s4, v0", "global_load_b32 v2, v[2:3], off",
|
||||
"v_add_co_ci_u32_e32 v1, vcc_lo, s5, v1, vcc_lo", "s_waitcnt vmcnt(0)",
|
||||
"global_store_b32 v[0:1], v2, off", "s_endpgm",
|
||||
]
|
||||
expected = llvm_assemble(asm_lines, 'gfx1100', '+real-true16,+wavefrontsize32')
|
||||
for inst, rt in zip(program, asm_lines): print(f"{disasm(inst):50s} {rt}")
|
||||
for inst, exp in zip(program, expected): self.assertEqual(inst.to_bytes(), exp)
|
||||
|
||||
def test_sop2_s_add_u32(self):
|
||||
inst = SOP2(SOP2Op.S_ADD_U32, s[3], s[0], s[1])
|
||||
expected = llvm_assemble("s_add_u32 s3, s0, s1")
|
||||
self.assertEqual(inst.to_bytes(), expected)
|
||||
self.assertEqual(inst.to_bytes(), _asm("s_add_u32 s3, s0, s1"))
|
||||
|
||||
def test_vop2_v_and_b32_inline_const(self):
|
||||
inst = v_and_b32_e32(v[0], 10, v[0])
|
||||
expected = llvm_assemble("v_and_b32_e32 v0, 10, v0")
|
||||
self.assertEqual(inst.to_bytes(), expected)
|
||||
self.assertEqual(inst.to_bytes(), _asm("v_and_b32_e32 v0, 10, v0"))
|
||||
|
||||
def test_sopp_s_endpgm(self):
|
||||
inst = s_endpgm()
|
||||
expected = llvm_assemble("s_endpgm")
|
||||
self.assertEqual(inst.to_bytes(), expected)
|
||||
self.assertEqual(inst.to_bytes(), _asm("s_endpgm"))
|
||||
|
||||
def test_sop1_s_mov_b32(self):
|
||||
inst = s_mov_b32(s[0], s[1])
|
||||
expected = llvm_assemble("s_mov_b32 s0, s1")
|
||||
self.assertEqual(inst.to_bytes(), expected)
|
||||
self.assertEqual(inst.to_bytes(), _asm("s_mov_b32 s0, s1"))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Roundtrip tests: generate tinygrad kernels, decode instructions, re-encode, verify match."""
|
||||
import unittest, io, sys, re, subprocess, os
|
||||
import unittest, io, sys, re
|
||||
from tinygrad import Device
|
||||
from tinygrad.renderer.amd import detect_format
|
||||
from test.amd.helpers import get_llvm_mc, get_llvm_objdump, get_target, get_mattr
|
||||
from test.amd.helpers import llvm_assemble, llvm_disasm, get_target, get_mattr
|
||||
from test.amd.disasm import disasm
|
||||
|
||||
def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]:
|
||||
|
|
@ -31,45 +31,18 @@ def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]:
|
|||
|
||||
def compile_asm(instr: str, arch: str = 'rdna3') -> bytes:
|
||||
"""Compile a single instruction using LLVM."""
|
||||
return compile_asm_batch([instr], arch)[0]
|
||||
return llvm_assemble([instr], get_target(arch), get_mattr(arch))[0]
|
||||
|
||||
def compile_asm_batch(instrs: list[str], arch: str = 'rdna3') -> list[bytes]:
|
||||
"""Compile multiple instructions with a single llvm-mc call."""
|
||||
if not instrs: return []
|
||||
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={get_target(arch)}', f'-mattr={get_mattr(arch)}', '-show-encoding'],
|
||||
input=".text\n" + "\n".join(instrs) + "\n", capture_output=True, text=True)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-mc batch failed: {result.stderr.strip()}")
|
||||
encodings = []
|
||||
for line in result.stdout.split('\n'):
|
||||
if 'encoding:' in line:
|
||||
enc = line.split('encoding:')[1].strip()
|
||||
if enc.startswith('[') and enc.endswith(']'):
|
||||
encodings.append(bytes.fromhex(enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')))
|
||||
if len(encodings) != len(instrs): raise RuntimeError(f"expected {len(instrs)} encodings, got {len(encodings)}")
|
||||
return encodings
|
||||
"""Compile multiple instructions with a single LLVM emission."""
|
||||
return llvm_assemble(instrs, get_target(arch), get_mattr(arch))
|
||||
|
||||
def compile_and_disasm_batch(instrs: list[str], arch: str = 'rdna3') -> list[str]:
|
||||
"""Compile instructions with LLVM and get LLVM's disassembly."""
|
||||
import tempfile
|
||||
if not instrs: return []
|
||||
mcpu, mattr = get_target(arch), get_mattr(arch)
|
||||
src = ".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n" + "\n".join(f" {instr}" for instr in instrs) + "\n"
|
||||
with tempfile.NamedTemporaryFile(suffix='.o', delete=False) as f:
|
||||
obj_path = f.name
|
||||
try:
|
||||
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', f'-mattr={mattr}', '-filetype=obj', '-o', obj_path],
|
||||
input=src, capture_output=True, text=True)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-mc failed: {result.stderr.strip()}")
|
||||
result = subprocess.run([get_llvm_objdump(), '-d', f'--mcpu={mcpu}', obj_path], capture_output=True, text=True)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-objdump failed: {result.stderr.strip()}")
|
||||
results: list[str] = []
|
||||
for line in result.stdout.splitlines():
|
||||
if '//' not in line: continue
|
||||
instr = line.split('//')[0].strip()
|
||||
if instr: results.append(instr)
|
||||
return results[:len(instrs)]
|
||||
finally:
|
||||
os.unlink(obj_path)
|
||||
code = b''.join(llvm_assemble(instrs, mcpu, mattr))
|
||||
return llvm_disasm(code, mcpu, mattr)[:len(instrs)]
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "AMD", "requires AMD device")
|
||||
class TestTinygradKernelRoundtrip(unittest.TestCase):
|
||||
|
|
@ -174,12 +147,12 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
|
|||
if our_disasm is None:
|
||||
disasm_skipped += 1
|
||||
elif idx in disasm_llvm_map:
|
||||
llvm_disasm = disasm_llvm_map[idx]
|
||||
if our_disasm == llvm_disasm:
|
||||
llvm_disasm_str = disasm_llvm_map[idx]
|
||||
if our_disasm == llvm_disasm_str:
|
||||
disasm_passed += 1
|
||||
else:
|
||||
disasm_failed += 1
|
||||
disasm_failures.append(f"K{ki}@{offset}: ours='{our_disasm}' llvm='{llvm_disasm}'")
|
||||
disasm_failures.append(f"K{ki}@{offset}: ours='{our_disasm}' llvm='{llvm_disasm_str}'")
|
||||
else:
|
||||
disasm_skipped += 1
|
||||
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ def run_rocprof_decoder(blobs: list[bytes], lib: bytes, base: int, target: str):
|
|||
try: rocprof.rocprof_trace_decoder_parse_data(copy_cb, trace_cb, isa_cb, None)
|
||||
except Exception as e: exc = e
|
||||
(t:=threading.Thread(target=worker, daemon=True)).start()
|
||||
t.join(timeout=1)
|
||||
t.join(timeout=5)
|
||||
if exc is not None: raise exc
|
||||
if t.is_alive(): raise RuntimeError("rocprof decoder timeout")
|
||||
return occupancy_records, wave_insts
|
||||
|
|
|
|||
|
|
@ -80,6 +80,7 @@ def extract_packet_encodings():
|
|||
|
||||
def extract_cdna_packet_sizes():
|
||||
"""Extract CDNA pkt_fmt -> size mapping by running rocprof decoder to populate its hash table."""
|
||||
if not _load_lib(): return None
|
||||
from test.amd.test_sqtt_examples import run_rocprof_decoder
|
||||
|
||||
if not (pkl_path := next((EXAMPLES_DIR / "gfx950").glob("*.pkl"), None)): return None
|
||||
|
|
@ -119,8 +120,7 @@ class TestSQTTMatchesBinary(unittest.TestCase):
|
|||
def test_cdna_packet_sizes(self):
|
||||
"""Extract and verify CDNA pkt_fmt -> size mapping from rocprof's hash table."""
|
||||
if not (EXAMPLES_DIR / "gfx950").exists(): self.skipTest("no CDNA examples")
|
||||
pkt_sizes = extract_cdna_packet_sizes()
|
||||
self.assertIsNotNone(pkt_sizes, "failed to extract CDNA packet sizes")
|
||||
if not (pkt_sizes := extract_cdna_packet_sizes()): self.skipTest("rocprof-trace-decoder not installed")
|
||||
for pkt_fmt, size in CDNA_PKT_SIZES.items():
|
||||
with self.subTest(pkt_fmt=pkt_fmt): self.assertEqual(pkt_sizes.get(pkt_fmt), size)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,37 +9,47 @@ from __future__ import annotations
|
|||
import ctypes, functools, re, platform, subprocess, tempfile
|
||||
from typing import Any, Callable
|
||||
|
||||
# Set/restore DAZ+FTZ (denormals-are-zero + flush-to-zero) in MXCSR to match RDNA3 default float mode
|
||||
# Set/restore DAZ+FTZ (denormals-are-zero + flush-to-zero) to match RDNA3 default float mode
|
||||
# x86: MXCSR bits DAZ(6)+FTZ(15), ARM64: FPCR bit FZ(24)
|
||||
# Only applied during emulator execution, restored afterward to avoid breaking hypothesis tests
|
||||
@functools.cache
|
||||
def _get_mxcsr_lib():
|
||||
if platform.machine() not in ('x86_64', 'AMD64'): return None
|
||||
try:
|
||||
def _get_ftz_lib():
|
||||
machine = platform.machine()
|
||||
if machine in ('x86_64', 'AMD64'):
|
||||
src = b'''
|
||||
unsigned int get_mxcsr(void){unsigned int m;__asm__ __volatile__("stmxcsr %0":"=m"(m));return m;}
|
||||
void set_mxcsr(unsigned int m){__asm__ __volatile__("ldmxcsr %0"::"m"(m));}
|
||||
unsigned int get_fpcr(void){unsigned int m;__asm__ __volatile__("stmxcsr %0":"=m"(m));return m;}
|
||||
void set_fpcr(unsigned int m){__asm__ __volatile__("ldmxcsr %0"::"m"(m));}
|
||||
'''
|
||||
ftz_bits = 0x8040 # DAZ (bit 6) + FTZ (bit 15)
|
||||
elif machine in ('arm64', 'aarch64'):
|
||||
src = b'''
|
||||
unsigned int get_fpcr(void){unsigned long long v;__asm__ __volatile__("mrs %0,fpcr":"=r"(v));return(unsigned int)v;}
|
||||
void set_fpcr(unsigned int m){unsigned long long v=m;__asm__ __volatile__("msr fpcr,%0"::"r"(v));}
|
||||
'''
|
||||
ftz_bits = 1 << 24 # FZ (bit 24)
|
||||
else: return None, 0
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix='.so', delete=False) as f:
|
||||
subprocess.check_output(['clang', '-shared', '-O2', '-x', 'c', '-', '-o', f.name], input=src)
|
||||
lib = ctypes.CDLL(f.name)
|
||||
lib.get_mxcsr.restype = ctypes.c_uint32
|
||||
lib.set_mxcsr.argtypes = [ctypes.c_uint32]
|
||||
return lib
|
||||
except Exception: return None
|
||||
lib.get_fpcr.restype = ctypes.c_uint32
|
||||
lib.set_fpcr.argtypes = [ctypes.c_uint32]
|
||||
return lib, ftz_bits
|
||||
except Exception: return None, 0
|
||||
|
||||
class _MXCSRContext:
|
||||
"""Context manager to set DAZ+FTZ during emulator execution and restore afterward."""
|
||||
__slots__ = ('_saved',)
|
||||
def __enter__(self):
|
||||
lib = _get_mxcsr_lib()
|
||||
lib, ftz_bits = _get_ftz_lib()
|
||||
if lib is None: return self
|
||||
self._saved = lib.get_mxcsr()
|
||||
lib.set_mxcsr(self._saved | 0x8040) # DAZ (bit 6) + FTZ (bit 15)
|
||||
self._saved = lib.get_fpcr()
|
||||
lib.set_fpcr(self._saved | ftz_bits)
|
||||
return self
|
||||
def __exit__(self, *args):
|
||||
lib = _get_mxcsr_lib()
|
||||
lib, _ = _get_ftz_lib()
|
||||
if lib is None or not hasattr(self, '_saved'): return
|
||||
lib.set_mxcsr(self._saved)
|
||||
lib.set_fpcr(self._saved)
|
||||
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
|
||||
from tinygrad.dtype import dtypes
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue