mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
31 lines
1.1 KiB
Python
31 lines
1.1 KiB
Python
#!/usr/bin/env python3
"""
Stress test for beam timeout + device recovery on AM devices.

Each iteration has two phases:
  1. Queue a 4096x4096 random matmul on the AMD device and call
     dev.synchronize(timeout=1). A RuntimeError raised by that call is
     printed and otherwise ignored.
  2. Run a 512x512 all-ones matmul and check that element [0, 0] equals 512.0,
     i.e. that the device still produces correct results afterwards.

Usage:
  DEV=AMD python test/external/external_fuzz_beam_timeout_recovery.py
"""
from tinygrad import Tensor, Device
from tinygrad.helpers import Context
from tinygrad.runtime.ops_amd import AMDDevice

if __name__ == "__main__":
  dev = Device["AMD"]
  # The script is only meaningful when the AMD device reports is_am().
  assert isinstance(dev, AMDDevice) and dev.is_am(), "not am"

  N = 10000
  for i in range(N):
    # Phase 1: large matmul followed by a synchronize with a timeout.
    with Context(DEBUG=0, BEAM=0):
      lhs = Tensor.rand(4096, 4096, device="AMD").contiguous().realize()
      rhs = Tensor.rand(4096, 4096, device="AMD").contiguous().realize()
      out = lhs.matmul(rhs)
      out.realize()
      try:
        dev.synchronize(timeout=1)
      except RuntimeError as e:
        # Report the failure and fall through to the recovery check below.
        print(e)

    # Phase 2: small matmul with a known answer (sum of 512 ones == 512).
    # lhs/rhs are rebound here on purpose, releasing the large operands.
    with Context(DEBUG=0, BEAM=0):
      lhs = Tensor.ones(512, 512, device="AMD").contiguous().realize()
      rhs = Tensor.ones(512, 512, device="AMD").contiguous().realize()
      result = lhs.matmul(rhs).realize()[0, 0].item()
      assert result == 512.0, f"iter {i}: got {result}"

    print(f" iter {i+1}/{N}: ok")

  print(f"=== All {N} iterations passed ===")