mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
amd/sqtt: add rdna4 and cdna sqtt examples (#14251)
* amd/sqtt: add rdna4 and cdna sqtt examples * work * comment out rdna and cdna tests
This commit is contained in:
parent
2dc281b32a
commit
4548fcc1b8
34 changed files with 52 additions and 16 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -58,10 +58,11 @@ weights
|
|||
*.lprof
|
||||
comgr_*
|
||||
*.pkl
|
||||
!extra/sqtt/examples/**/*.pkl
|
||||
site/
|
||||
profile_stats
|
||||
*.log
|
||||
target
|
||||
.mypy_cache
|
||||
mutants
|
||||
.mutmut-cache
|
||||
.mutmut-cache
|
||||
|
|
|
|||
|
|
@ -26,9 +26,11 @@ def get_llvm_objdump():
|
|||
ARCH_TO_TARGET:dict[str, list[str]] = {
|
||||
"rdna3":["gfx1100"],
|
||||
"rdna4":["gfx1200"],
|
||||
"cdna":["gfx942"],
|
||||
"cdna":["gfx942", "gfx950"],
|
||||
}
|
||||
|
||||
TARGET_TO_ARCH:dict[str, str] = {t:arch for arch,targets in ARCH_TO_TARGET.items() for t in targets}
|
||||
|
||||
def get_target(arch:str) -> str: return ARCH_TO_TARGET[arch][0]
|
||||
|
||||
def get_mattr(arch:str) -> str:
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from extra.assembly.amd.autogen.rdna3.ins import SOPP
|
|||
from extra.assembly.amd.autogen.rdna3.enum import SOPPOp
|
||||
from extra.assembly.amd.sqtt import (decode, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, VALUINST, IMMEDIATE, IMMEDIATE_MASK,
|
||||
ALUEXEC, VMEMEXEC, PACKET_TYPES, InstOp, print_packets)
|
||||
from extra.assembly.amd.test.helpers import TARGET_TO_ARCH
|
||||
|
||||
EXAMPLES_DIR = Path(__file__).parent.parent.parent.parent / "sqtt/examples"
|
||||
# INST ops for non-traced SIMDs (excluded from instruction count)
|
||||
|
|
@ -23,7 +24,7 @@ OTHER_SIMD_OPS = {InstOp.OTHER_LDS_LOAD, InstOp.OTHER_LDS_STORE, InstOp.OTHER_LD
|
|||
# ROCPROF DECODER
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def run_rocprof_decoder(blobs: list[bytes], lib: bytes, base: int):
|
||||
def run_rocprof_decoder(blobs: list[bytes], lib: bytes, base: int, target: str):
|
||||
"""Run rocprof decoder on SQTT blobs, returning raw occupancy and instruction records."""
|
||||
image, sections, _ = elf_loader(lib)
|
||||
text = next((sh for sh in sections if sh.name == ".text"), None)
|
||||
|
|
@ -58,6 +59,7 @@ def run_rocprof_decoder(blobs: list[bytes], lib: bytes, base: int):
|
|||
wave_insts.append([(inst.time, inst.stall) for inst in insts])
|
||||
return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_SUCCESS
|
||||
|
||||
arch = TARGET_TO_ARCH[target]
|
||||
@rocprof.rocprof_trace_decoder_isa_callback_t
|
||||
def isa_cb(instr_ptr, mem_size_ptr, size_ptr, pc, _):
|
||||
offset = pc.address - base
|
||||
|
|
@ -65,8 +67,9 @@ def run_rocprof_decoder(blobs: list[bytes], lib: bytes, base: int):
|
|||
mem_size_ptr[0] = 0
|
||||
return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_SUCCESS
|
||||
try:
|
||||
inst = decode_inst(image[offset:])
|
||||
inst = decode_inst(image[offset:], arch=arch)
|
||||
mem_size_ptr[0] = inst._size()
|
||||
# this could be an error in our decode_inst
|
||||
except (ValueError, AssertionError):
|
||||
mem_size_ptr[0] = 0
|
||||
return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_SUCCESS
|
||||
|
|
@ -89,10 +92,12 @@ def run_rocprof_decoder(blobs: list[bytes], lib: bytes, base: int):
|
|||
return occupancy_records, wave_insts
|
||||
|
||||
class TestSQTTExamples(unittest.TestCase):
|
||||
target = "gfx1100"
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.examples = {}
|
||||
for pkl_path in sorted(EXAMPLES_DIR.glob("*.pkl")):
|
||||
for pkl_path in sorted((EXAMPLES_DIR/cls.target).glob("*.pkl")):
|
||||
with open(pkl_path, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
sqtt_events = [e for e in data if type(e).__name__ == "ProfileSQTTEvent"]
|
||||
|
|
@ -141,25 +146,26 @@ class TestSQTTExamples(unittest.TestCase):
|
|||
all_packets = [p for e in events for p in decode(e.blob)]
|
||||
self.assertGreater(len([p for p in all_packets if isinstance(p, INST)]), 0, f"no INST packets in {name}")
|
||||
|
||||
expected = {
|
||||
"profile_empty_run_0": [1803, 1908, 1928, 1979, 2006, 1912],
|
||||
"profile_empty_run_1": [1803, 1908, 1928, 1979, 2006, 1912],
|
||||
"profile_gemm_run_0": [2531, 1844, 1864, 1915, 1942, 1848, 3074, 1919, 1939, 1990, 2017, 1923, 19026, 1919, 1939, 1990, 2017, 1929],
|
||||
"profile_gemm_run_1": [2554, 1844, 1864, 1915, 1942, 1848, 3084, 1919, 1939, 1990, 2017, 1923, 19010, 1919, 1939, 1990, 2017, 1923],
|
||||
"profile_plus_run_0": [1900, 1908, 1928, 1979, 2006, 1912],
|
||||
"profile_plus_run_1": [1856, 1908, 1928, 1979, 2006, 1912],
|
||||
}
|
||||
def test_packet_counts(self):
|
||||
expected = {
|
||||
"profile_empty_run_0": [559, 600],
|
||||
"profile_empty_run_1": [517, 570],
|
||||
"profile_gemm_run_0": [1489, 604, 1789, 466, 17570, 407],
|
||||
"profile_gemm_run_1": [1453, 604, 1871, 493, 17827, 460],
|
||||
"profile_plus_run_0": [695, 668],
|
||||
"profile_plus_run_1": [663, 593],
|
||||
}
|
||||
for name, (events, *_) in self.examples.items():
|
||||
with self.subTest(example=name):
|
||||
if not self.expected.get(name): continue
|
||||
counts = [len(list(decode(e.blob))) for e in events]
|
||||
self.assertEqual(counts, expected[name], f"packet count mismatch in {name}")
|
||||
self.assertEqual(counts, self.expected[name], f"packet count mismatch in {name}")
|
||||
|
||||
def test_rocprof_wave_times_match(self):
|
||||
"""Wave start/end times must match rocprof exactly."""
|
||||
for name, (events, lib, base) in self.examples.items():
|
||||
with self.subTest(example=name):
|
||||
occupancy, _ = run_rocprof_decoder([e.blob for e in events], lib, base)
|
||||
occupancy, _ = run_rocprof_decoder([e.blob for e in events], lib, base, self.target)
|
||||
# extract from rocprof occupancy records
|
||||
roc_starts: dict[tuple[int, int, int], int] = {}
|
||||
roc_waves: list[tuple[int, int]] = []
|
||||
|
|
@ -181,7 +187,7 @@ class TestSQTTExamples(unittest.TestCase):
|
|||
"""Instruction times must match rocprof exactly (excluding s_endpgm)."""
|
||||
for name, (events, lib, base) in self.examples.items():
|
||||
with self.subTest(example=name):
|
||||
_, wave_insts = run_rocprof_decoder([e.blob for e in events], lib, base)
|
||||
_, wave_insts = run_rocprof_decoder([e.blob for e in events], lib, base, self.target)
|
||||
# skip last inst per wave (s_endpgm) - it needs special handling (time + duration instead of time + stall)
|
||||
roc_insts = [time + stall for insts in wave_insts for time, stall in insts[:-1]]
|
||||
# extract from our decoder
|
||||
|
|
@ -195,5 +201,9 @@ class TestSQTTExamples(unittest.TestCase):
|
|||
for _ in range(bin(p.mask).count('1')): our_insts.append(p._time)
|
||||
self.assertEqual(sorted(our_insts), sorted(roc_insts), f"instruction times mismatch in {name}")
|
||||
|
||||
#class TestSQTTExamplesRDNA4(TestSQTTExamples): target = "gfx1200"
|
||||
|
||||
#class TestSQTTExamplesCDNA(TestSQTTExamples): target = "gfx950"
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
23
extra/sqtt/examples/generate_examples.py
Normal file
23
extra/sqtt/examples/generate_examples.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
import os, subprocess
|
||||
from pathlib import Path
|
||||
from tinygrad.helpers import temp
|
||||
|
||||
EXAMPLES_DIR = Path(__file__).parent
|
||||
PROFILE_PATH = Path(temp("profile.pkl", append_user=True))
|
||||
|
||||
EXAMPLES = [
|
||||
"test.test_custom_kernel.TestCustomKernel.test_empty",
|
||||
"test.test_tiny.TestTiny.test_plus",
|
||||
"test.test_tiny.TestTiny.test_gemm",
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
arch = subprocess.check_output(["python", "-c", "from tinygrad import Device; print(Device['AMD'].arch)"], text=True,
|
||||
env={**os.environ, "DEBUG":"0"}).rstrip()
|
||||
(EXAMPLES_DIR/arch).mkdir(exist_ok=True)
|
||||
for test in EXAMPLES:
|
||||
for i in range(2):
|
||||
subprocess.run(["python", "-m", "unittest", test], cwd=EXAMPLES_DIR.parent.parent.parent,
|
||||
env={**os.environ, "AMD":"1", "SQTT_LIMIT_SE":"-1", "VIZ":"-2"}, check=True)
|
||||
PROFILE_PATH.rename(dest:=EXAMPLES_DIR/arch/f"profile_{test.split('.')[-1].replace('test_', '')}_run_{i}.pkl")
|
||||
print(f"saved SQTT trace to {dest}")
|
||||
BIN
extra/sqtt/examples/gfx1100/profile_empty_run_0.pkl
Normal file
BIN
extra/sqtt/examples/gfx1100/profile_empty_run_0.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1100/profile_empty_run_1.pkl
Normal file
BIN
extra/sqtt/examples/gfx1100/profile_empty_run_1.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1100/profile_gemm_run_0.pkl
Normal file
BIN
extra/sqtt/examples/gfx1100/profile_gemm_run_0.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1100/profile_gemm_run_1.pkl
Normal file
BIN
extra/sqtt/examples/gfx1100/profile_gemm_run_1.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1100/profile_plus_run_0.pkl
Normal file
BIN
extra/sqtt/examples/gfx1100/profile_plus_run_0.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1100/profile_plus_run_1.pkl
Normal file
BIN
extra/sqtt/examples/gfx1100/profile_plus_run_1.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1200/profile_empty_run_0.pkl
Normal file
BIN
extra/sqtt/examples/gfx1200/profile_empty_run_0.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1200/profile_empty_run_1.pkl
Normal file
BIN
extra/sqtt/examples/gfx1200/profile_empty_run_1.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1200/profile_gemm_run_0.pkl
Normal file
BIN
extra/sqtt/examples/gfx1200/profile_gemm_run_0.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1200/profile_gemm_run_1.pkl
Normal file
BIN
extra/sqtt/examples/gfx1200/profile_gemm_run_1.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1200/profile_plus_run_0.pkl
Normal file
BIN
extra/sqtt/examples/gfx1200/profile_plus_run_0.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1200/profile_plus_run_1.pkl
Normal file
BIN
extra/sqtt/examples/gfx1200/profile_plus_run_1.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1200/profile_test_empty_0.pkl
Normal file
BIN
extra/sqtt/examples/gfx1200/profile_test_empty_0.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1200/profile_test_empty_1.pkl
Normal file
BIN
extra/sqtt/examples/gfx1200/profile_test_empty_1.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1200/profile_test_gemm_0.pkl
Normal file
BIN
extra/sqtt/examples/gfx1200/profile_test_gemm_0.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1200/profile_test_gemm_1.pkl
Normal file
BIN
extra/sqtt/examples/gfx1200/profile_test_gemm_1.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1200/profile_test_plus_0.pkl
Normal file
BIN
extra/sqtt/examples/gfx1200/profile_test_plus_0.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1200/profile_test_plus_1.pkl
Normal file
BIN
extra/sqtt/examples/gfx1200/profile_test_plus_1.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx950/profile_empty_run_0.pkl
Normal file
BIN
extra/sqtt/examples/gfx950/profile_empty_run_0.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx950/profile_empty_run_1.pkl
Normal file
BIN
extra/sqtt/examples/gfx950/profile_empty_run_1.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx950/profile_gemm_run_0.pkl
Normal file
BIN
extra/sqtt/examples/gfx950/profile_gemm_run_0.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx950/profile_gemm_run_1.pkl
Normal file
BIN
extra/sqtt/examples/gfx950/profile_gemm_run_1.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx950/profile_plus_run_0.pkl
Normal file
BIN
extra/sqtt/examples/gfx950/profile_plus_run_0.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx950/profile_plus_run_1.pkl
Normal file
BIN
extra/sqtt/examples/gfx950/profile_plus_run_1.pkl
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Add table
Add a link
Reference in a new issue