This commit is contained in:
George Hotz 2026-05-27 01:18:39 +00:00
commit 1d55b5618c
2 changed files with 3 additions and 7 deletions

View file

@ -129,16 +129,15 @@ class TestSQTTMapBase(unittest.TestCase):
# NOTE(review): this text comes from a diff view whose +/- markers and indentation were lost.
# Consecutive statements that differ only by the leading `*no_rewrites` argument look like the
# removed/added pair of a single statement (the commit header reports 3 additions and
# 7 deletions) -- confirm against the real file before treating both as live code.
def test_sqtt_cli(self):
# Visit every *.pkl file under EXAMPLES_DIR/<self.target>, in sorted path order.
for pkl_path in sorted((EXAMPLES_DIR/self.target).glob("*.pkl")):
# Extra CLI arguments: the `--rewrites-path` flag followed by an empty string.
no_rewrites = ("--rewrites-path", "")
# Run the CLI on this profile with `--ls` (first form passes no_rewrites, second form does not).
out = run_cli(*no_rewrites, "--profile-path", str(pkl_path), "--ls")
out = run_cli("--profile-path", str(pkl_path), "--ls")
# Keep the whitespace-stripped "value" of every returned row whose value contains "SQTT".
sqtt_traces = [l["value"].strip() for l in out if "SQTT" in l["value"]]
for name in sqtt_traces:
# Select one trace with `-s`, passing its name through ansistrip first
# (presumably removes ANSI escape codes -- verify against the helper).
lines = run_cli(*no_rewrites, "--profile-path", str(pkl_path), "-s", ansistrip(name))
lines = run_cli("--profile-path", str(pkl_path), "-s", ansistrip(name))
# The first returned row's value must contain "Clk".
self.assertIn("Clk", lines[0]["value"])
# From the third row onward, collect "clk" for rows whose "unit" contains "WAVE";
# the collected values must already be in sorted order.
waves = [r["clk"] for r in lines[2:] if "WAVE" in r["unit"]]
self.assertEqual(waves, sorted(waves), f"wave timestamps not monotonic in {name}")
# With DEBUG=2 set through Context, select "AMD" and compare the number of returned rows
# with the length of element [1] of self.examples for this pickle's stem.
with Context(DEBUG=2):
kernels = run_cli(*no_rewrites, "--profile-path", str(pkl_path), "-s", "AMD")
kernels = run_cli("--profile-path", str(pkl_path), "-s", "AMD")
self.assertEqual(len(kernels), len(self.examples[pkl_path.stem][1]))
# Concrete subclass of TestSQTTMapBase that sets the `target` class attribute to "gfx1100".
class TestSQTTMapRDNA3(TestSQTTMapBase):
  target = "gfx1100"

View file

@ -74,9 +74,6 @@ class TestQuantizeFP8(unittest.TestCase):
@needs_second_gpu
def test_multi(self):
devs = tuple(f"{Device.DEFAULT}:{i}" for i in range(8))
try:
for dev in devs: Device[dev]
except Exception as e: self.skipTest(f"8 devices not available: {e}")
x = Tensor.empty(2048*8, 1024, dtype=dtypes.bfloat16, device=devs).uop.multi(0)
x = Tensor(x, device=devs)
amax_state = Tensor.full((), 2.0, dtype=dtypes.float32, device=devs).contiguous()