amd/sqtt: add rdna4 and cdna sqtt examples (#14251)

* amd/sqtt: add rdna4 and cdna sqtt examples

* work

* comment out rdna and cdna tests
This commit is contained in:
qazal 2026-01-20 07:11:48 -05:00 committed by GitHub
commit 4548fcc1b8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 52 additions and 16 deletions

3
.gitignore vendored
View file

@ -58,10 +58,11 @@ weights
*.lprof
comgr_*
*.pkl
!extra/sqtt/examples/**/*.pkl
site/
profile_stats
*.log
target
.mypy_cache
mutants
.mutmut-cache
.mutmut-cache

View file

@ -26,9 +26,11 @@ def get_llvm_objdump():
ARCH_TO_TARGET:dict[str, list[str]] = {
"rdna3":["gfx1100"],
"rdna4":["gfx1200"],
"cdna":["gfx942"],
"cdna":["gfx942", "gfx950"],
}
TARGET_TO_ARCH:dict[str, str] = {t:arch for arch,targets in ARCH_TO_TARGET.items() for t in targets}
def get_target(arch:str) -> str: return ARCH_TO_TARGET[arch][0]
def get_mattr(arch:str) -> str:

View file

@ -10,6 +10,7 @@ from extra.assembly.amd.autogen.rdna3.ins import SOPP
from extra.assembly.amd.autogen.rdna3.enum import SOPPOp
from extra.assembly.amd.sqtt import (decode, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, VALUINST, IMMEDIATE, IMMEDIATE_MASK,
ALUEXEC, VMEMEXEC, PACKET_TYPES, InstOp, print_packets)
from extra.assembly.amd.test.helpers import TARGET_TO_ARCH
EXAMPLES_DIR = Path(__file__).parent.parent.parent.parent / "sqtt/examples"
# INST ops for non-traced SIMDs (excluded from instruction count)
@ -23,7 +24,7 @@ OTHER_SIMD_OPS = {InstOp.OTHER_LDS_LOAD, InstOp.OTHER_LDS_STORE, InstOp.OTHER_LD
# ROCPROF DECODER
# ═══════════════════════════════════════════════════════════════════════════════
def run_rocprof_decoder(blobs: list[bytes], lib: bytes, base: int):
def run_rocprof_decoder(blobs: list[bytes], lib: bytes, base: int, target: str):
"""Run rocprof decoder on SQTT blobs, returning raw occupancy and instruction records."""
image, sections, _ = elf_loader(lib)
text = next((sh for sh in sections if sh.name == ".text"), None)
@ -58,6 +59,7 @@ def run_rocprof_decoder(blobs: list[bytes], lib: bytes, base: int):
wave_insts.append([(inst.time, inst.stall) for inst in insts])
return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_SUCCESS
arch = TARGET_TO_ARCH[target]
@rocprof.rocprof_trace_decoder_isa_callback_t
def isa_cb(instr_ptr, mem_size_ptr, size_ptr, pc, _):
offset = pc.address - base
@ -65,8 +67,9 @@ def run_rocprof_decoder(blobs: list[bytes], lib: bytes, base: int):
mem_size_ptr[0] = 0
return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_SUCCESS
try:
inst = decode_inst(image[offset:])
inst = decode_inst(image[offset:], arch=arch)
mem_size_ptr[0] = inst._size()
# this could be an error in our decode_inst
except (ValueError, AssertionError):
mem_size_ptr[0] = 0
return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_SUCCESS
@ -89,10 +92,12 @@ def run_rocprof_decoder(blobs: list[bytes], lib: bytes, base: int):
return occupancy_records, wave_insts
class TestSQTTExamples(unittest.TestCase):
target = "gfx1100"
@classmethod
def setUpClass(cls):
cls.examples = {}
for pkl_path in sorted(EXAMPLES_DIR.glob("*.pkl")):
for pkl_path in sorted((EXAMPLES_DIR/cls.target).glob("*.pkl")):
with open(pkl_path, "rb") as f:
data = pickle.load(f)
sqtt_events = [e for e in data if type(e).__name__ == "ProfileSQTTEvent"]
@ -141,25 +146,26 @@ class TestSQTTExamples(unittest.TestCase):
all_packets = [p for e in events for p in decode(e.blob)]
self.assertGreater(len([p for p in all_packets if isinstance(p, INST)]), 0, f"no INST packets in {name}")
expected = {
"profile_empty_run_0": [1803, 1908, 1928, 1979, 2006, 1912],
"profile_empty_run_1": [1803, 1908, 1928, 1979, 2006, 1912],
"profile_gemm_run_0": [2531, 1844, 1864, 1915, 1942, 1848, 3074, 1919, 1939, 1990, 2017, 1923, 19026, 1919, 1939, 1990, 2017, 1929],
"profile_gemm_run_1": [2554, 1844, 1864, 1915, 1942, 1848, 3084, 1919, 1939, 1990, 2017, 1923, 19010, 1919, 1939, 1990, 2017, 1923],
"profile_plus_run_0": [1900, 1908, 1928, 1979, 2006, 1912],
"profile_plus_run_1": [1856, 1908, 1928, 1979, 2006, 1912],
}
def test_packet_counts(self):
expected = {
"profile_empty_run_0": [559, 600],
"profile_empty_run_1": [517, 570],
"profile_gemm_run_0": [1489, 604, 1789, 466, 17570, 407],
"profile_gemm_run_1": [1453, 604, 1871, 493, 17827, 460],
"profile_plus_run_0": [695, 668],
"profile_plus_run_1": [663, 593],
}
for name, (events, *_) in self.examples.items():
with self.subTest(example=name):
if not self.expected.get(name): continue
counts = [len(list(decode(e.blob))) for e in events]
self.assertEqual(counts, expected[name], f"packet count mismatch in {name}")
self.assertEqual(counts, self.expected[name], f"packet count mismatch in {name}")
def test_rocprof_wave_times_match(self):
"""Wave start/end times must match rocprof exactly."""
for name, (events, lib, base) in self.examples.items():
with self.subTest(example=name):
occupancy, _ = run_rocprof_decoder([e.blob for e in events], lib, base)
occupancy, _ = run_rocprof_decoder([e.blob for e in events], lib, base, self.target)
# extract from rocprof occupancy records
roc_starts: dict[tuple[int, int, int], int] = {}
roc_waves: list[tuple[int, int]] = []
@ -181,7 +187,7 @@ class TestSQTTExamples(unittest.TestCase):
"""Instruction times must match rocprof exactly (excluding s_endpgm)."""
for name, (events, lib, base) in self.examples.items():
with self.subTest(example=name):
_, wave_insts = run_rocprof_decoder([e.blob for e in events], lib, base)
_, wave_insts = run_rocprof_decoder([e.blob for e in events], lib, base, self.target)
# skip last inst per wave (s_endpgm) - it needs special handling (time + duration instead of time + stall)
roc_insts = [time + stall for insts in wave_insts for time, stall in insts[:-1]]
# extract from our decoder
@ -195,5 +201,9 @@ class TestSQTTExamples(unittest.TestCase):
for _ in range(bin(p.mask).count('1')): our_insts.append(p._time)
self.assertEqual(sorted(our_insts), sorted(roc_insts), f"instruction times mismatch in {name}")
#class TestSQTTExamplesRDNA4(TestSQTTExamples): target = "gfx1200"
#class TestSQTTExamplesCDNA(TestSQTTExamples): target = "gfx950"
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,23 @@
import os, subprocess
from pathlib import Path
from tinygrad.helpers import temp
EXAMPLES_DIR = Path(__file__).parent
PROFILE_PATH = Path(temp("profile.pkl", append_user=True))
EXAMPLES = [
"test.test_custom_kernel.TestCustomKernel.test_empty",
"test.test_tiny.TestTiny.test_plus",
"test.test_tiny.TestTiny.test_gemm",
]
if __name__ == "__main__":
arch = subprocess.check_output(["python", "-c", "from tinygrad import Device; print(Device['AMD'].arch)"], text=True,
env={**os.environ, "DEBUG":"0"}).rstrip()
(EXAMPLES_DIR/arch).mkdir(exist_ok=True)
for test in EXAMPLES:
for i in range(2):
subprocess.run(["python", "-m", "unittest", test], cwd=EXAMPLES_DIR.parent.parent.parent,
env={**os.environ, "AMD":"1", "SQTT_LIMIT_SE":"-1", "VIZ":"-2"}, check=True)
PROFILE_PATH.rename(dest:=EXAMPLES_DIR/arch/f"profile_{test.split('.')[-1].replace('test_', '')}_run_{i}.pkl")
print(f"saved SQTT trace to {dest}")

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.