tinygrad/test/external/external_fuzz_beam_timeout_recovery.py
2026-05-01 12:56:43 +03:00

31 lines
1.1 KiB
Python

#!/usr/bin/env python3
"""
Stress test for beam timeout + device recovery on AM devices.
Usage:
DEV=AMD python test/external/external_fuzz_beam_timeout_recovery.py
"""
from tinygrad import Tensor, Device
from tinygrad.helpers import Context
from tinygrad.runtime.ops_amd import AMDDevice
if __name__ == "__main__":
dev = Device["AMD"]
assert isinstance(dev, AMDDevice) and dev.is_am(), "not am"
N = 10000
for i in range(N):
with Context(DEBUG=0, BEAM=0):
a = Tensor.rand(4096, 4096, device="AMD").contiguous().realize()
b = Tensor.rand(4096, 4096, device="AMD").contiguous().realize()
c = a.matmul(b)
c.realize()
try: dev.synchronize(timeout=1)
except RuntimeError as e: print(e)
with Context(DEBUG=0, BEAM=0):
a = Tensor.ones(512, 512, device="AMD").contiguous().realize()
b = Tensor.ones(512, 512, device="AMD").contiguous().realize()
result = a.matmul(b).realize()[0, 0].item()
assert result == 512.0, f"iter {i}: got {result}"
print(f" iter {i+1}/{N}: ok")
print(f"=== All {N} iterations passed ===")