mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
revert-943
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
71b9bf1c58 |
1 changed files with 2 additions and 6 deletions
|
|
@ -1,6 +1,5 @@
|
|||
import numpy as np
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.dtype import _to_np_dtype
|
||||
from tinygrad import dtypes, Tensor
|
||||
|
||||
dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
|
||||
|
|
@ -14,15 +13,12 @@ K = getenv("K", N)
|
|||
CNT = getenv("CNT", 10)
|
||||
ATOL = getenv("ATOL", 1e-4)
|
||||
RTOL = getenv("RTOL", 3e-2)
|
||||
INT_LOW = getenv("INT_LOW", 0)
|
||||
INT_HIGH = getenv("INT_HIGH", 10)
|
||||
|
||||
if __name__ == "__main__":
|
||||
def init_matrix(rows, cols):
|
||||
rng = np.random.default_rng()
|
||||
if dtype_in in dtypes.ints:
|
||||
return Tensor(rng.integers(INT_LOW, INT_HIGH, (rows, cols), dtype=_to_np_dtype(dtype_in))).realize()
|
||||
return Tensor(rng.random((rows, cols), dtype=np.float32).astype(_to_np_dtype(dtype_in))).realize()
|
||||
return Tensor.randint((rows, cols), dtype=dtype_in).realize()
|
||||
return Tensor.rand(rows, cols, dtype=dtype_in).realize()
|
||||
|
||||
a, b = init_matrix(M, K), init_matrix(K, N)
|
||||
for i in range(CNT):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue