mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
move is_dtype_supported to renderer (#16226)
This commit is contained in:
parent
d548f8d0f3
commit
172f9493e1
49 changed files with 211 additions and 245 deletions
|
|
@ -3,7 +3,7 @@ os.environ["WQKV"] = "1"
|
|||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, nn, dtypes
|
||||
from tinygrad.device import is_dtype_supported, Device
|
||||
from tinygrad.device import Device
|
||||
from examples.mlperf.models.llama import Transformer
|
||||
from examples.mlperf.models.flat_llama import FlatTransformer
|
||||
|
||||
|
|
@ -111,7 +111,7 @@ class TestFlatLlama(unittest.TestCase):
|
|||
self.assertEqual(ref_logits.shape, flat_logits.shape)
|
||||
np.testing.assert_allclose(flat_logits, ref_logits, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), "fp8 not supported on this device")
|
||||
@unittest.skipUnless(dtypes.fp8e4m3 in Device[Device.DEFAULT].renderer.supported_dtypes(), "fp8 not supported on this device")
|
||||
def test_forward_fp8(self):
|
||||
import examples.mlperf.models.flat_llama as flat_llama_mod
|
||||
old_fp8 = flat_llama_mod.FP8
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from tinygrad import Tensor, dtypes, nn
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad import Tensor, Device, dtypes, nn
|
||||
from typing import Optional, Union, List, Any, Tuple, Callable
|
||||
import math
|
||||
|
||||
|
|
@ -13,7 +12,7 @@ def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000):
|
|||
freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp()
|
||||
args = timesteps.unsqueeze(1) * freqs.unsqueeze(0)
|
||||
out = Tensor.cat(args.cos(), args.sin(), dim=-1)
|
||||
return out.cast(mixed_precision_dtype) if is_dtype_supported(mixed_precision_dtype) else out
|
||||
return out.cast(mixed_precision_dtype) if mixed_precision_dtype in Device[Device.DEFAULT].renderer.supported_dtypes() else out
|
||||
|
||||
class ResBlock:
|
||||
def __init__(self, channels:int, emb_channels:int, out_channels:int, num_groups:int=32):
|
||||
|
|
@ -238,7 +237,7 @@ class UNetModel:
|
|||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + y.sequential(self.label_emb[0])
|
||||
|
||||
if is_dtype_supported(mixed_precision_dtype):
|
||||
if mixed_precision_dtype in Device[Device.DEFAULT].renderer.supported_dtypes():
|
||||
emb = emb.cast(mixed_precision_dtype)
|
||||
ctx = ctx.cast(mixed_precision_dtype)
|
||||
x = x .cast(mixed_precision_dtype)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor, Device, dtypes, Context
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import getenv, system, DEV
|
||||
from extra.gemm.cdna_asm_gemm import asm_gemm
|
||||
from test.helpers import needs_second_gpu
|
||||
|
|
@ -85,7 +84,7 @@ def verify_asm_gemm_k_sharded_3d(batch:int, M:int, N:int, K:int, dtype=dtypes.fl
|
|||
|
||||
# 128x smaller than usual
|
||||
# uses the UOp GEMM, runs on non CDNA4 and CI
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
@unittest.skipUnless(dtypes.half in Device[Device.DEFAULT].renderer.supported_dtypes(), "need half")
|
||||
class TestGemm(unittest.TestCase):
|
||||
def setUp(self):
|
||||
if is_cdna4(): self.skipTest("shapes are too small for the assembly GEMM")
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import unittest, math
|
|||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.dtype import DTYPES_DICT
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
from tinygrad.device import is_dtype_supported
|
||||
import numpy as np
|
||||
from test.helpers import not_support_multi_device
|
||||
|
||||
|
|
@ -39,12 +38,10 @@ class TestMovedConstFolding(unittest.TestCase):
|
|||
|
||||
def test_cast_padded(self):
|
||||
# NOTE: it's always 1 kernel when calling .numpy, limitation of _check_ast_count
|
||||
if is_dtype_supported(dtypes.int16):
|
||||
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
|
||||
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
|
||||
if is_dtype_supported(dtypes.uint16):
|
||||
_check_ast_count(1, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
|
||||
np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
|
||||
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
|
||||
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
|
||||
_check_ast_count(1, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
|
||||
np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
|
||||
# folded
|
||||
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
|
||||
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0])
|
||||
|
|
@ -114,7 +111,7 @@ class TestReduceOpsConstFolding(unittest.TestCase):
|
|||
def test_sum_output_dtype(self):
|
||||
# sum output dtype can be different from input
|
||||
for dt in DTYPES_DICT.values():
|
||||
if is_dtype_supported(dt):
|
||||
if dt in Device[Device.DEFAULT].renderer.supported_dtypes():
|
||||
t = Tensor.ones(16, dtype=dt).reshape(4, 4)
|
||||
assert t.sum().dtype == t.contiguous().sum().dtype
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import contextlib, unittest, math
|
|||
import numpy as np
|
||||
import torch
|
||||
from typing import Any, List
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import getenv, DEBUG, CI, EMULATED_DTYPES
|
||||
from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8, _to_np_dtype, _to_torch_dtype, truncate
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
|
|
@ -17,11 +16,13 @@ pytestmark = pytest.mark.filterwarnings("ignore")
|
|||
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
|
||||
settings.load_profile("my_profile")
|
||||
|
||||
supported_dtypes = Device[Device.DEFAULT].renderer.supported_dtypes()
|
||||
|
||||
def get_available_cast_dtypes(dtype: DType) -> List[DType]:
|
||||
dts = [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) or v in dtypes.fp8s+(dtypes.half,dtypes.bfloat16,dtypes.long)]
|
||||
if dtype in (dtypes.long, dtypes.ulong) and (not is_dtype_supported(dtype) or dtypes.long in EMULATED_DTYPES.tolist(dtypes)):
|
||||
dts = [v for k, v in DTYPES_DICT.items() if v != dtype and v in supported_dtypes or v in dtypes.fp8s+(dtypes.half,dtypes.bfloat16,dtypes.long)]
|
||||
if dtype in (dtypes.long, dtypes.ulong) and (dtype not in supported_dtypes or dtypes.long in EMULATED_DTYPES.tolist(dtypes)):
|
||||
return [dt for dt in dts if dt != dtypes.double] # can't bitcast with no 64-bit support
|
||||
if not is_dtype_supported(dtype) and dtype not in dtypes.fp8s+(dtypes.half,dtypes.bfloat16): return []
|
||||
if dtype not in supported_dtypes and dtype not in dtypes.fp8s+(dtypes.half,dtypes.bfloat16): return []
|
||||
return dts
|
||||
|
||||
def _to_torch_storage_type(dtype:DType):
|
||||
|
|
@ -60,7 +61,7 @@ class TestDType(unittest.TestCase):
|
|||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if cls.DTYPE is None: raise unittest.SkipTest("base class")
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 0x10, allow_subnormal=is_dtype_supported(cls.DTYPE))
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 0x10, allow_subnormal=cls.DTYPE in supported_dtypes)
|
||||
|
||||
def test_to_np(self):
|
||||
_test_to_np(Tensor(self.DATA, dtype=self.DTYPE), _to_np_dtype(self.DTYPE), np.array(self.DATA, dtype=_to_np_dtype(self.DTYPE)))
|
||||
|
|
@ -115,7 +116,7 @@ class TestDType(unittest.TestCase):
|
|||
for dt in dtypes:
|
||||
arr = np.asarray(data).astype(dt)
|
||||
tensor = Tensor(arr)
|
||||
if not is_dtype_supported(tensor.dtype): continue
|
||||
if tensor.dtype not in supported_dtypes: continue
|
||||
tin = tensor.numpy()
|
||||
tor = torch.as_tensor(arr).detach().numpy()
|
||||
assert dt == tin.dtype == tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}"
|
||||
|
|
@ -271,7 +272,7 @@ class TestFloatDType(TestDType):
|
|||
_test_op(lambda: Tensor([-0.9, -0.3, 1.2], dtype=dtypes.float32).cast(dtypes.uint32), dtypes.uint32,
|
||||
[0, 0, 1])
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.double), f"no double on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.double in supported_dtypes, f"no double on {Device.DEFAULT}")
|
||||
class TestDoubleDType(TestDType):
|
||||
DTYPE = dtypes.double
|
||||
@unittest.skipIf((CI and Device.DEFAULT in {"CUDA", "NV"}) or \
|
||||
|
|
@ -478,7 +479,7 @@ class TestImplicitFunctionTypeChange(unittest.TestCase):
|
|||
class TestTensorMethod(unittest.TestCase):
|
||||
@given(strat.sampled_from(core_dtypes))
|
||||
def test_abs_diff(self, dt):
|
||||
if dt == dtypes.bool or not is_dtype_supported(dt): return
|
||||
if dt == dtypes.bool or dt not in supported_dtypes: return
|
||||
a, b = Tensor([2], dtype=dt), Tensor([1], dtype=dt)
|
||||
ret = (a - b).abs()
|
||||
np.testing.assert_allclose(ret.numpy(), np.abs(a.numpy()-b.numpy()))
|
||||
|
|
@ -486,11 +487,11 @@ class TestTensorMethod(unittest.TestCase):
|
|||
class TestDtypeUsage(unittest.TestCase):
|
||||
def test_max_w_alu(self):
|
||||
for d in dtypes.ints:
|
||||
if is_dtype_supported(d):
|
||||
if d in supported_dtypes:
|
||||
t = Tensor([[1, 2], [3, 4]], dtype=d)
|
||||
(t*t).max().item()
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.bfloat16 in supported_dtypes, f"no bfloat16 on {Device.DEFAULT}")
|
||||
class TestOpsBFloat16(unittest.TestCase):
|
||||
def test_cast(self):
|
||||
# TODO: helper_test_op breaks in unrelated part
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from tinygrad import Context, Tensor, dtypes, Device
|
|||
from tinygrad.dtype import DType, truncate, fp8_to_float
|
||||
from tinygrad.helpers import CI, EMULATED_DTYPES, DEV, getenv
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.runtime.ops_python import from_storage_scalar
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
from tinygrad.renderer.nir import NIRRenderer
|
||||
|
|
@ -41,6 +40,8 @@ if ((DEV.interface.startswith("MOCK") and Device.DEFAULT in {"NV", "CUDA"})
|
|||
# transcendental isn't accurate enough
|
||||
if Ops.SQRT not in Device[Device.DEFAULT].renderer.code_for_op: unary_operations.remove((Tensor.sqrt, np.sqrt))
|
||||
|
||||
supported_dtypes = Device[Device.DEFAULT].renderer.supported_dtypes()
|
||||
|
||||
class ht:
|
||||
float64 = strat.floats(width=64, allow_subnormal=False)
|
||||
float32 = strat.floats(width=32, allow_subnormal=False)
|
||||
|
|
@ -71,7 +72,7 @@ def universal_test(a, b, dtype, op):
|
|||
numpy_value = truncate[dtype](op[1](ta.numpy(), tb.numpy()).item())
|
||||
else: tensor_value, numpy_value = (op[0](ta, tb)).numpy(), op[1](ta.numpy(), tb.numpy())
|
||||
if dtype in dtypes.floats:
|
||||
if not is_dtype_supported(dtype) or dtype in EMULATED_DTYPES.tolist(dtypes): # denormals are zero
|
||||
if dtype not in supported_dtypes or dtype in EMULATED_DTYPES.tolist(dtypes): # denormals are zero
|
||||
fe, fm = dtypes.finfo(dtype)
|
||||
atol, rtol = 2 ** (2 - (1 << (fe - 1))), 2 ** (-fm)
|
||||
else: atol, rtol = {dtypes.bfloat16:(1e-3, 1e-2), dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2:(1.0, 5e-1),
|
||||
|
|
@ -87,7 +88,8 @@ def universal_test_unary(a, dtype, op):
|
|||
if op[0] == Tensor.log and a <= 0: return
|
||||
if dtype in dtypes.fp8s:
|
||||
# denormals are zero
|
||||
if dtype in EMULATED_DTYPES.tolist(dtypes) or not is_dtype_supported(dtype) and abs(ta.numpy().item()) < 0.015625: return
|
||||
if (dtype in EMULATED_DTYPES.tolist(dtypes) or dtype not in supported_dtypes
|
||||
and abs(ta.numpy().item()) < 0.015625): return
|
||||
tensor_value = fp8_to_float(op[0](ta.realize()).bitcast(dtypes.uint8).item(), dtype)
|
||||
numpy_value = truncate[dtype](v:=op[1](ta.numpy()).item())
|
||||
# cuda cast f32 inf to f8 MAX, amd cast it to nan(E4M3)/inf(E5M2)
|
||||
|
|
@ -119,14 +121,14 @@ def universal_test_midcast(a, b, c, op1, op2, d1:DType, d2:DType):
|
|||
np.testing.assert_allclose(tensor_value, numpy_value, rtol=1e-6 if isinstance(Device[Device.DEFAULT].renderer, PTXRenderer) else 1e-7)
|
||||
|
||||
class TestDTypeALU(unittest.TestCase):
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float64), f"no float64 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.float64 in supported_dtypes, f"no float64 on {Device.DEFAULT}")
|
||||
@given(ht.float64, ht.float64, strat.sampled_from(binary_operations))
|
||||
def test_float64(self, a, b, op): universal_test(a, b, dtypes.float64, op)
|
||||
|
||||
@given(ht.float32, ht.float32, strat.sampled_from(binary_operations))
|
||||
def test_float32(self, a, b, op): universal_test(a, b, dtypes.float32, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), f"no float16 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.float16 in supported_dtypes, f"no float16 on {Device.DEFAULT}")
|
||||
@given(ht.float16, ht.float16, strat.sampled_from(binary_operations))
|
||||
def test_float16(self, a, b, op): universal_test(a, b, dtypes.float16, op)
|
||||
|
||||
|
|
@ -134,7 +136,7 @@ class TestDTypeALU(unittest.TestCase):
|
|||
@Context(EMULATED_DTYPES="half")
|
||||
def test_emulated_float16(self, a, b, op): universal_test(a, b, dtypes.float16, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.bfloat16 in supported_dtypes, f"no bfloat16 on {Device.DEFAULT}")
|
||||
@given(ht.bfloat16, ht.bfloat16, strat.sampled_from(binary_operations))
|
||||
def test_bfloat16(self, a, b, op):
|
||||
universal_test(from_storage_scalar(a, dtypes.bfloat16), from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
|
||||
|
|
@ -144,7 +146,7 @@ class TestDTypeALU(unittest.TestCase):
|
|||
def test_emulated_bfloat16(self, a, b, op):
|
||||
universal_test(from_storage_scalar(a, dtypes.bfloat16), from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), f"no fp8e4m3 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.fp8e4m3 in supported_dtypes, f"no fp8e4m3 on {Device.DEFAULT}")
|
||||
@given(ht.fp8e4m3, ht.fp8e4m3, strat.sampled_from(binary_operations))
|
||||
def test_fp8e4m3(self, a, b, op):
|
||||
universal_test(from_storage_scalar(a, dtypes.fp8e4m3), from_storage_scalar(b, dtypes.fp8e4m3), dtypes.fp8e4m3, op)
|
||||
|
|
@ -154,7 +156,7 @@ class TestDTypeALU(unittest.TestCase):
|
|||
def test_emulated_fp8e4m3(self, a, b, op):
|
||||
universal_test(from_storage_scalar(a, dtypes.fp8e4m3), from_storage_scalar(b, dtypes.fp8e4m3), dtypes.fp8e4m3, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), f"no fp8e5m2 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.fp8e5m2 in supported_dtypes, f"no fp8e5m2 on {Device.DEFAULT}")
|
||||
@given(ht.fp8e5m2, ht.fp8e5m2, strat.sampled_from(binary_operations))
|
||||
def test_fp8e5m2(self, a, b, op):
|
||||
universal_test(from_storage_scalar(a, dtypes.fp8e5m2), from_storage_scalar(b, dtypes.fp8e5m2), dtypes.fp8e5m2, op)
|
||||
|
|
@ -164,12 +166,12 @@ class TestDTypeALU(unittest.TestCase):
|
|||
def test_emulated_fp8e5m2(self, a, b, op):
|
||||
universal_test(from_storage_scalar(a, dtypes.fp8e5m2), from_storage_scalar(b, dtypes.fp8e5m2), dtypes.fp8e5m2, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3fnuz), f"no fp8e4m3fnuz on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.fp8e4m3fnuz in supported_dtypes, f"no fp8e4m3fnuz on {Device.DEFAULT}")
|
||||
@given(ht.fp8e4m3fnuz, ht.fp8e4m3fnuz, strat.sampled_from(binary_operations))
|
||||
def test_fp8e4m3fnuz(self, a, b, op):
|
||||
universal_test(from_storage_scalar(a, dtypes.fp8e4m3fnuz), from_storage_scalar(b, dtypes.fp8e4m3fnuz), dtypes.fp8e4m3fnuz, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2fnuz), f"no fp8e5m2fnuz on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.fp8e5m2fnuz in supported_dtypes, f"no fp8e5m2fnuz on {Device.DEFAULT}")
|
||||
@given(ht.fp8e5m2fnuz, ht.fp8e5m2fnuz, strat.sampled_from(binary_operations))
|
||||
def test_fp8e5m2fnuz(self, a, b, op):
|
||||
universal_test(from_storage_scalar(a, dtypes.fp8e5m2fnuz), from_storage_scalar(b, dtypes.fp8e5m2fnuz), dtypes.fp8e5m2fnuz, op)
|
||||
|
|
@ -187,7 +189,7 @@ class TestDTypeALU(unittest.TestCase):
|
|||
@given(ht.float32, strat.sampled_from(unary_operations))
|
||||
def test_float32_unary(self, a, op): universal_test_unary(a, dtypes.float32, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), f"no float16 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.float16 in supported_dtypes, f"no float16 on {Device.DEFAULT}")
|
||||
@given(ht.float16, strat.sampled_from(unary_operations))
|
||||
def test_float16_unary(self, a, op): universal_test_unary(a, dtypes.float16, op)
|
||||
|
||||
|
|
@ -195,7 +197,7 @@ class TestDTypeALU(unittest.TestCase):
|
|||
@Context(EMULATED_DTYPES="half")
|
||||
def test_emulated_float16_unary(self, a, op): universal_test_unary(a, dtypes.float16, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.bfloat16 in supported_dtypes, f"no bfloat16 on {Device.DEFAULT}")
|
||||
@given(ht.bfloat16, strat.sampled_from(unary_operations))
|
||||
def test_bfloat16_unary(self, a, op): universal_test_unary(from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
|
||||
|
||||
|
|
@ -203,7 +205,7 @@ class TestDTypeALU(unittest.TestCase):
|
|||
@Context(EMULATED_DTYPES="bfloat16")
|
||||
def test_emulated_bfloat16_unary(self, a, op): universal_test_unary(from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), f"no fp8e4m3 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.fp8e4m3 in supported_dtypes, f"no fp8e4m3 on {Device.DEFAULT}")
|
||||
@given(ht.fp8e4m3, strat.sampled_from(unary_operations))
|
||||
def test_fp8e4m3_unary(self, a, op):
|
||||
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e4m3) != 0.0)
|
||||
|
|
@ -215,7 +217,7 @@ class TestDTypeALU(unittest.TestCase):
|
|||
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e4m3) != 0.0)
|
||||
universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e4m3), dtypes.fp8e4m3, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), f"no fp8e5m2 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.fp8e5m2 in supported_dtypes, f"no fp8e5m2 on {Device.DEFAULT}")
|
||||
@given(ht.fp8e5m2, strat.sampled_from(unary_operations))
|
||||
def test_fp8e5m2_unary(self, a, op):
|
||||
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e5m2) != 0.0)
|
||||
|
|
@ -227,13 +229,13 @@ class TestDTypeALU(unittest.TestCase):
|
|||
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e5m2) != 0.0)
|
||||
universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e5m2), dtypes.fp8e5m2, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3fnuz), f"no fp8e4m3fnuz on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.fp8e4m3fnuz in supported_dtypes, f"no fp8e4m3fnuz on {Device.DEFAULT}")
|
||||
@given(ht.fp8e4m3fnuz, strat.sampled_from(unary_operations))
|
||||
def test_fp8e4m3fnuz_unary(self, a, op):
|
||||
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e4m3fnuz) != 0.0)
|
||||
universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e4m3fnuz), dtypes.fp8e4m3fnuz, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2fnuz), f"no fp8e5m2fnuz on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.fp8e5m2fnuz in supported_dtypes, f"no fp8e5m2fnuz on {Device.DEFAULT}")
|
||||
@given(ht.fp8e5m2fnuz, strat.sampled_from(unary_operations))
|
||||
def test_fp8e5m2fnuz_unary(self, a, op):
|
||||
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e5m2fnuz) != 0.0)
|
||||
|
|
@ -254,15 +256,15 @@ class TestDTypeALU(unittest.TestCase):
|
|||
@given(ht.uint8, ht.uint8, strat.sampled_from(integer_binary_operations))
|
||||
def test_uint8(self, a, b, op): universal_test(a, b, dtypes.uint8, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uint16), f"no uint16 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.uint16 in supported_dtypes, f"no uint16 on {Device.DEFAULT}")
|
||||
@given(ht.uint16, ht.uint16, strat.sampled_from(integer_binary_operations))
|
||||
def test_uint16(self, a, b, op): universal_test(a, b, dtypes.uint16, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uint32), f"no uint32 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.uint32 in supported_dtypes, f"no uint32 on {Device.DEFAULT}")
|
||||
@given(ht.uint32, ht.uint32, strat.sampled_from(integer_binary_operations))
|
||||
def test_uint32(self, a, b, op): universal_test(a, b, dtypes.uint32, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uint64), f"no uint64 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.uint64 in supported_dtypes, f"no uint64 on {Device.DEFAULT}")
|
||||
@given(ht.uint64, ht.uint64, strat.sampled_from(integer_binary_operations))
|
||||
def test_uint64(self, a, b, op): universal_test(a, b, dtypes.uint64, op)
|
||||
|
||||
|
|
@ -291,15 +293,15 @@ class TestDTypeALU(unittest.TestCase):
|
|||
@given(ht.uint8, strat.sampled_from(integer_unary_operations))
|
||||
def test_uint8_unary(self, a, op): universal_test_unary(a, dtypes.uint8, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uint16), f"no uint16 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.uint16 in supported_dtypes, f"no uint16 on {Device.DEFAULT}")
|
||||
@given(ht.uint16, strat.sampled_from(integer_unary_operations))
|
||||
def test_uint16_unary(self, a, op): universal_test_unary(a, dtypes.uint16, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uint32), f"no uint32 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.uint32 in supported_dtypes, f"no uint32 on {Device.DEFAULT}")
|
||||
@given(ht.uint32, strat.sampled_from(integer_unary_operations))
|
||||
def test_uint32_unary(self, a, op): universal_test_unary(a, dtypes.uint32, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uint64), f"no uint64 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.uint64 in supported_dtypes, f"no uint64 on {Device.DEFAULT}")
|
||||
@given(ht.uint64, strat.sampled_from(integer_unary_operations))
|
||||
def test_uint64_unary(self, a, op): universal_test_unary(a, dtypes.uint64, op)
|
||||
|
||||
|
|
@ -352,21 +354,21 @@ class TestDTypeALU(unittest.TestCase):
|
|||
@given(strat.floats(width=32, min_value=1.0, max_value=254.0, allow_subnormal=False),
|
||||
strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
|
||||
def test_float_cast_to_unsigned(self, a, float_dtype, unsigned_dtype):
|
||||
if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32
|
||||
if float_dtype not in supported_dtypes: float_dtype = dtypes.float32
|
||||
universal_test_cast(a, float_dtype, unsigned_dtype)
|
||||
|
||||
@unittest.skip("relied on hacks")
|
||||
@given(strat.floats(width=32, min_value=256.0, max_value=65000.0, allow_subnormal=False),
|
||||
strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
|
||||
def test_float_cast_to_unsigned_overflow(self, a, float_dtype, unsigned_dtype):
|
||||
if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32
|
||||
if float_dtype not in supported_dtypes: float_dtype = dtypes.float32
|
||||
universal_test_cast(a, float_dtype, unsigned_dtype)
|
||||
|
||||
@unittest.skip("relied on hacks")
|
||||
@given(strat.floats(width=32, min_value=-65000.0, max_value=-1.0, allow_subnormal=False),
|
||||
strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
|
||||
def test_float_cast_to_unsigned_underflow(self, a, float_dtype, unsigned_dtype):
|
||||
if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32
|
||||
if float_dtype not in supported_dtypes: float_dtype = dtypes.float32
|
||||
universal_test_cast(a, float_dtype, unsigned_dtype)
|
||||
|
||||
@unittest.expectedFailure
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import unittest
|
|||
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType, buffers
|
||||
from tinygrad.device import Device, Buffer, is_dtype_supported
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.engine.realize import run_linear
|
||||
from tinygrad.codegen import to_program
|
||||
|
|
@ -206,7 +206,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
def test_sum_acc_dtype(self):
|
||||
for tensor_dtype, acc_dtype in (
|
||||
(dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
|
||||
if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype):
|
||||
if tensor_dtype in (dts:=Device[Device.DEFAULT].renderer.supported_dtypes()) and acc_dtype in dts:
|
||||
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
|
||||
realized_ast = a.schedule_linear().src[-1].src[0]
|
||||
program = to_program(replace_opts(realized_ast, []), renderer=Device[Device.DEFAULT].renderer)
|
||||
|
|
@ -229,7 +229,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
(dtypes.float, dtypes.float16, dtypes.float16),
|
||||
)
|
||||
for tensor_dtype, acc_dtype, expected_dtype in tests:
|
||||
if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype) and is_dtype_supported(expected_dtype):
|
||||
if tensor_dtype in (dts:=Device[Device.DEFAULT].renderer.supported_dtypes()) and acc_dtype in dts and expected_dtype in dts:
|
||||
a, b = Tensor.rand(8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, dtype=tensor_dtype)
|
||||
helper_arg_acc_dtype(a.sum(dtype=acc_dtype), expected_dtype)
|
||||
helper_arg_acc_dtype(a.matmul(b, dtype=acc_dtype), expected_dtype)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor, dtypes, Context
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad import Tensor, Device, dtypes, Context
|
||||
from extra.llama_kernels.fused_ce import fused_ce_loss
|
||||
|
||||
def run_fused_ce(bs:int, seqlen:int, vocab:int, label_smoothing:float=0.0) -> None:
|
||||
|
|
@ -24,7 +23,7 @@ def run_fused_ce(bs:int, seqlen:int, vocab:int, label_smoothing:float=0.0) -> No
|
|||
assert loss.allclose(ref, atol=2e-3, rtol=2e-3).item(), "forward mismatch"
|
||||
assert logits.grad.allclose(logits_ref.grad, atol=2e-3, rtol=2e-3).item(), "grad mismatch"
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "need bfloat16")
|
||||
@unittest.skipUnless(dtypes.bfloat16 in Device[Device.DEFAULT].renderer.supported_dtypes(), "need bfloat16")
|
||||
class TestFusedCE(unittest.TestCase):
|
||||
def test_fused_ce_1_2_16(self): run_fused_ce(1, 2, 16, label_smoothing=0.2)
|
||||
def test_fused_ce_2_16_128(self): run_fused_ce(2, 16, 128)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import unittest, random
|
||||
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes, Variable
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
from tinygrad.helpers import getenv, prod, Context
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict
|
||||
|
|
@ -911,7 +910,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
|
|||
|
||||
@given(strat.sampled_from([dtypes.float, dtypes.int, dtypes.int64, dtypes.int16]))
|
||||
def test_ops(self, dtype):
|
||||
if not is_dtype_supported(dtype): return
|
||||
if dtype not in Device[Device.DEFAULT].renderer.supported_dtypes(): return
|
||||
t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
|
||||
t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)
|
||||
for i in range(4):
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import torch
|
|||
from tinygrad.helpers import getenv, CI, DEBUG, DEV, IMAGE, Context
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.renderer.cstyle import QCOMCLRenderer
|
||||
from tinygrad.renderer.nir import NIRRenderer
|
||||
|
||||
|
|
@ -1508,7 +1507,8 @@ class TestOps(unittest.TestCase):
|
|||
|
||||
def test_sum_dtype_arg(self):
|
||||
helper_test_op([(45,3)], lambda x: x.sum(), lambda x: x.sum(dtype=dtypes.float32))
|
||||
if is_dtype_supported(dtypes.float64): helper_test_op([(45,3)], lambda x: x.sum(dtype=torch.float64), lambda x: x.sum(dtype=dtypes.float64))
|
||||
if dtypes.float64 in Device[Device.DEFAULT].renderer.supported_dtypes():
|
||||
helper_test_op([(45,3)], lambda x: x.sum(dtype=torch.float64), lambda x: x.sum(dtype=dtypes.float64))
|
||||
|
||||
with self.assertRaises(AttributeError): Tensor([1.0, 2.0]).sum(dtype="")
|
||||
|
||||
|
|
@ -3379,7 +3379,6 @@ class TestOps(unittest.TestCase):
|
|||
t = (Tensor([0], dtype='int') | 0xFFFFFFFF).item()
|
||||
if not COMPILE_ONLY: assert t == -1
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}")
|
||||
class TestOpsUint8(unittest.TestCase):
|
||||
def test_cast(self):
|
||||
helper_test_op([(2,3,64,64)], lambda x: x.type(torch.uint8), lambda x: x.cast('uint8'), forward_only=True, low=0, high=255)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import torch
|
|||
import unittest
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.nn.optim import Adam, SGD, AdamW, Muon, LAMB
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from test.helpers import needs_second_gpu, slow
|
||||
|
||||
np.random.seed(1337)
|
||||
|
|
@ -142,7 +141,7 @@ class TestOptim(unittest.TestCase):
|
|||
|
||||
np.testing.assert_allclose(losses[0], losses[1], atol=1e-4, rtol=0)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
@unittest.skipUnless(dtypes.half in Device[Device.DEFAULT].renderer.supported_dtypes(), "need half")
|
||||
def test_mixed_precision(self):
|
||||
old_default_float, dtypes.default_float = dtypes.default_float, dtypes.half
|
||||
# weight update would overflow without upcasting
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from functools import partial
|
|||
|
||||
from tinygrad import nn, dtypes, Tensor, Device, TinyJit, Variable
|
||||
from tinygrad.helpers import getenv, CI, OSX
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.codegen import to_program
|
||||
|
||||
from tinygrad.uop.ops import Ops
|
||||
|
|
@ -86,7 +85,7 @@ class TestRandomness(unittest.TestCase):
|
|||
self.assertTrue(r1.uop.is_realized, "tensor should be realized after .realize()")
|
||||
self.assertTrue(r2.uop.is_realized, "tensor should be realized after .realize()")
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support")
|
||||
@unittest.skipUnless(dtypes.float16 in Device[Device.DEFAULT].renderer.supported_dtypes(), "need float16 support")
|
||||
def test_rand_float16(self):
|
||||
N = 128
|
||||
x = Tensor.rand((2, N, N), dtype=dtypes.float16)
|
||||
|
|
@ -211,7 +210,7 @@ class TestRandomness(unittest.TestCase):
|
|||
if not (x.src[0] == y.src[0]):
|
||||
print(f"{x.src[0]} != {y.src[0]}")
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "need bfloat16 support")
|
||||
@unittest.skipUnless(dtypes.bfloat16 in Device[Device.DEFAULT].renderer.supported_dtypes(), "need bfloat16 support")
|
||||
def test_rand_bfloat16(self):
|
||||
N = 128
|
||||
x = Tensor.rand((2, N, N), dtype=dtypes.bfloat16)
|
||||
|
|
@ -285,7 +284,7 @@ class TestRandomness(unittest.TestCase):
|
|||
|
||||
@given(strat.sampled_from([dtypes.float, dtypes.float16, dtypes.bfloat16]))
|
||||
def test_randn_finite(self, default_float):
|
||||
if not is_dtype_supported(default_float): return
|
||||
if default_float not in Device[Device.DEFAULT].renderer.supported_dtypes(): return
|
||||
old_default_float = dtypes.default_float
|
||||
# low precision can result in inf from randn
|
||||
dtypes.default_float = default_float
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.device import Device, is_dtype_supported
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.dtype import dtypes, ConstType
|
||||
from tinygrad.engine.realize import run_linear
|
||||
from tinygrad.codegen import to_program
|
||||
|
|
@ -96,7 +96,7 @@ class TestPTXFailures(unittest.TestCase):
|
|||
ret = _test_uop_result([], sink, local_size=[4, 1, 1])[0]
|
||||
np.testing.assert_equal(ret, [0, 1, 1, 1])
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
@unittest.skipUnless(dtypes.half in Device[Device.DEFAULT].renderer.supported_dtypes(), "need half")
|
||||
def test_gated_define_acc_with_half_dtype(self):
|
||||
a = Tensor.randn(32, 32, dtype=dtypes.half).realize()
|
||||
b = Tensor.randn(34, 32, dtype=dtypes.half).realize()
|
||||
|
|
|
|||
|
|
@ -8,12 +8,13 @@ from typing import cast
|
|||
from hypothesis import assume, given, strategies as strat
|
||||
|
||||
from tinygrad import nn, dtypes, Device, Tensor, Variable
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.uop.ops import UOp, Ops, UPat
|
||||
from tinygrad.helpers import CI, DEBUG, OSX, GlobalCounters, Context, getenv, all_same, temp
|
||||
from tinygrad.engine.realize import compile_linear, run_linear
|
||||
|
||||
supported_dtypes = Device[Device.DEFAULT].renderer.supported_dtypes()
|
||||
|
||||
class KernelCountException(Exception): pass
|
||||
def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True):
|
||||
if to_prerealize:
|
||||
|
|
@ -105,7 +106,7 @@ class TestSchedule(unittest.TestCase):
|
|||
run_linear(*check_schedule(a, 1))
|
||||
self.assertListEqual(a.tolist(), [[15]])
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
@unittest.skipUnless(dtypes.half in supported_dtypes, "need half")
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and OSX, "WEBGPU Metal backend is not accurate enough")
|
||||
def test_expand_buffer_before_cast(self):
|
||||
a = Tensor.randn(4, 2, 1).realize().permute((1, 0, 2))
|
||||
|
|
@ -715,7 +716,7 @@ class TestSchedule(unittest.TestCase):
|
|||
np.testing.assert_allclose(dx.numpy(), [[[[0.,3.,9.],[0,1.,3.],[0.,0.,0.]]]*3]*3)
|
||||
|
||||
# TODO like openpilot with imagef
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
@unittest.skipUnless(dtypes.half in supported_dtypes, "need half")
|
||||
def test_base_change_expand_expand(self):
|
||||
a = Tensor.ones(4, 4).contiguous().realize()
|
||||
b = a.cast(dtypes.half).expand(2, 4, 4)
|
||||
|
|
@ -761,9 +762,9 @@ class TestSchedule(unittest.TestCase):
|
|||
def test_conv2d(self): _test_conv2d(4)
|
||||
def test_conv2d_fused(self): _test_conv2d(4)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
@unittest.skipUnless(dtypes.half in supported_dtypes, "need half")
|
||||
def test_conv2d_half(self): _test_conv2d(4, dtype=dtypes.half)
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
@unittest.skipUnless(dtypes.half in supported_dtypes, "need half")
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Causes other tests to fail")
|
||||
def test_conv2d_fused_half(self): _test_conv2d(4, dtype=dtypes.half)
|
||||
|
||||
|
|
@ -870,7 +871,7 @@ class TestSchedule(unittest.TestCase):
|
|||
@given(strat.sampled_from(dtypes.all), strat.sampled_from(dtypes.all))
|
||||
@unittest.skip("kernel count depends on input")
|
||||
def test_cast_padded_const(self, dt1, dt2):
|
||||
assume(is_dtype_supported(dt1) and is_dtype_supported(dt2))
|
||||
assume(dt1 in supported_dtypes and dt2 in supported_dtypes)
|
||||
a = Tensor(1, dtype=dt1).reshape(1, 1).pad(((1, 1), None))
|
||||
casted_view = a.cast(dt2)
|
||||
run_linear(*check_schedule(casted_view, 0))
|
||||
|
|
@ -994,7 +995,7 @@ class TestSchedule(unittest.TestCase):
|
|||
self.assertIs(sched[1].ast.op, Ops.BUFFER_VIEW)
|
||||
np.testing.assert_equal(a.numpy(), [[4, 5]])
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
@unittest.skipUnless(dtypes.half in supported_dtypes, "need half")
|
||||
def test_precompute_freqs_cis(self):
|
||||
from extra.models.llama import precompute_freqs_cis
|
||||
args = {"dim":32, "end":2048, "theta":10000}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from tinygrad import Tensor, GlobalCounters, Context, Device
|
|||
from tinygrad.dtype import DTypeLike, dtypes
|
||||
from tinygrad.engine.realize import run_linear
|
||||
from tinygrad.helpers import DEBUG, get_single_element
|
||||
from tinygrad.device import is_dtype_supported
|
||||
|
||||
def single_kernel_softmax(x_in:Tensor, axis=-1, dtype:DTypeLike|None=None) -> Tensor:
|
||||
# only support axis =-1
|
||||
|
|
@ -62,7 +61,7 @@ class TestFuse(unittest.TestCase):
|
|||
b = Tensor.rand(50,50).realize()
|
||||
self._test_fuse(lambda a,b: ((a@b).relu()+a).contiguous().softmax(axis=-1), a,b, allow_multiple=True)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), f"no float16 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.float16 in Device[Device.DEFAULT].renderer.supported_dtypes(), f"no float16 on {Device.DEFAULT}")
|
||||
@unittest.skip("needs RANGEIFY>1")
|
||||
def test_fuse_softmax_dtype(self):
|
||||
a = Tensor.rand(50,50).realize()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from tinygrad import Tensor, Device, dtypes, nn
|
|||
from tinygrad.helpers import getenv, temp, mv_address
|
||||
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.dtype import DTYPES_DICT
|
||||
from tinygrad.uop.ops import UOp
|
||||
|
||||
|
|
@ -454,7 +453,7 @@ class TestTinygrad(unittest.TestCase):
|
|||
np.testing.assert_equal(Tensor(data, dtype=dtypes.float).numpy(), torch.tensor(data, dtype=torch.float).numpy())
|
||||
|
||||
def test_tensor_list_special_values(self):
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
if dtypes.float16 in Device[Device.DEFAULT].renderer.supported_dtypes():
|
||||
data = [math.nan, -math.inf, 65504, 65519, 65519.999, 65520, 65520.1]
|
||||
data = data + [-x for x in data]
|
||||
with np.errstate(over='ignore'): np.testing.assert_allclose(Tensor(data, dtype=dtypes.float16).numpy(), np.array(data).astype(np.float16))
|
||||
|
|
@ -571,7 +570,7 @@ class TestMoveTensor(unittest.TestCase):
|
|||
@given(strat.sampled_from([d0, d1]), strat.sampled_from([d0, d1]),
|
||||
strat.sampled_from([dtypes.float16, dtypes.float32]), strat.sampled_from([True, False, None]))
|
||||
def test_to_preserves(self, src, dest, dtype, requires_grad):
|
||||
if not is_dtype_supported(dtype):
|
||||
if dtype not in Device[Device.DEFAULT].renderer.supported_dtypes():
|
||||
return
|
||||
s = Tensor([1, 2, 3], device=src, dtype=dtype, requires_grad=requires_grad)
|
||||
if requires_grad: s.sum().backward()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from tinygrad.tensor import _to_np_dtype
|
|||
from tinygrad.helpers import Context, getenv, CI, DEV, OSX
|
||||
from test.backend.test_schedule import check_schedule
|
||||
from test.backend.test_dtype_alu import ht, dtypes_float
|
||||
from tinygrad.device import is_dtype_supported
|
||||
import numpy as np
|
||||
import math
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
|
|
@ -12,8 +11,10 @@ from hypothesis import given, settings, strategies as strat
|
|||
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
|
||||
settings.load_profile("my_profile")
|
||||
|
||||
supported_dtypes = Device[Device.DEFAULT].renderer.supported_dtypes()
|
||||
|
||||
class TestTranscendentalMath(unittest.TestCase):
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float64), f"no float64 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.float64 in supported_dtypes, f"no float64 on {Device.DEFAULT}")
|
||||
@unittest.skipIf(DEV.interface.startswith("MOCK") and Device.DEFAULT in {"NV", "CUDA"}, "crashed")
|
||||
@given(ht.float64, strat.sampled_from([(Tensor.exp, np.exp), (Tensor.log, np.log), (Tensor.sin, np.sin)]))
|
||||
def test_float64(self, x, op):
|
||||
|
|
@ -27,7 +28,7 @@ class TestTranscendentalMath(unittest.TestCase):
|
|||
|
||||
@unittest.skipIf(DEV.interface.startswith("MOCK") and Device.DEFAULT in {"NV", "CUDA"}, "crashed")
|
||||
@given(ht.float32, strat.sampled_from([(Tensor.exp, np.exp),(Tensor.log, np.log)] +
|
||||
([(Tensor.sin, np.sin)] if is_dtype_supported(dtypes.ulong) else [])))
|
||||
([(Tensor.sin, np.sin)] if dtypes.ulong in supported_dtypes else [])))
|
||||
def test_float32(self, x, op):
|
||||
# wrong nan behavior on Vulkan
|
||||
if (math.isnan(x) or (x < 0 and op[0] == Tensor.log)) and CI and Device.DEFAULT == "WEBGPU" and not OSX: return
|
||||
|
|
@ -36,9 +37,9 @@ class TestTranscendentalMath(unittest.TestCase):
|
|||
op[1](np.array([x], dtype=_to_np_dtype(dtypes.float32))),
|
||||
atol=2e-5, rtol=1e-5)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), f"no float16 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.float16 in supported_dtypes, f"no float16 on {Device.DEFAULT}")
|
||||
@given(ht.float16, strat.sampled_from([(Tensor.exp, np.exp),(Tensor.log, np.log)] +
|
||||
([(Tensor.sin, np.sin)] if is_dtype_supported(dtypes.ulong) else [])))
|
||||
([(Tensor.sin, np.sin)] if dtypes.ulong in supported_dtypes else [])))
|
||||
def test_float16(self, x, op):
|
||||
# wrong nan behavior on Vulkan
|
||||
if (math.isnan(x) or (x < 0 and op[0] == Tensor.log)) and CI and Device.DEFAULT == "WEBGPU" and not OSX: return
|
||||
|
|
@ -53,7 +54,7 @@ class TestTranscendentalMath(unittest.TestCase):
|
|||
def test_exp_near_inf(self, dtype_x):
|
||||
# reordering compute might return inf
|
||||
dtype, x = dtype_x
|
||||
if not is_dtype_supported(dtype): return
|
||||
if dtype not in supported_dtypes: return
|
||||
with Context(TRANSCENDENTAL=2):
|
||||
y = Tensor([x], dtype=dtype).exp().numpy()
|
||||
expected = np.exp(np.array([x], dtype=_to_np_dtype(dtype)))
|
||||
|
|
@ -61,9 +62,9 @@ class TestTranscendentalMath(unittest.TestCase):
|
|||
|
||||
class TestFromFuzzer(unittest.TestCase):
|
||||
@given(strat.sampled_from(dtypes_float))
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
|
||||
@unittest.skipUnless(dtypes.ulong in supported_dtypes, "Needs ulong")
|
||||
def test_sin(self, dtype):
|
||||
if not is_dtype_supported(dtype): return
|
||||
if dtype not in supported_dtypes: return
|
||||
if dtype == dtypes.float64:
|
||||
# crashes in CI CUDA
|
||||
if DEV.interface.startswith("MOCK") and Device.DEFAULT in {"NV", "CUDA"}: return
|
||||
|
|
@ -85,7 +86,7 @@ class TestFromFuzzer(unittest.TestCase):
|
|||
|
||||
@given(strat.sampled_from(dtypes_float))
|
||||
def test_log2(self, dtype):
|
||||
if not is_dtype_supported(dtype): return
|
||||
if dtype not in supported_dtypes: return
|
||||
if dtype == dtypes.float64:
|
||||
# crashes in CI CUDA
|
||||
if DEV.interface.startswith("MOCK") and Device.DEFAULT in {"NV", "CUDA"}: return
|
||||
|
|
@ -104,7 +105,7 @@ class TestFromFuzzer(unittest.TestCase):
|
|||
|
||||
class TestFloat16Log2(unittest.TestCase):
|
||||
"""Tests for native float16 log2 implementation (no float32 cast)"""
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), f"no float16 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.float16 in supported_dtypes, f"no float16 on {Device.DEFAULT}")
|
||||
def test_float16_log2_basic(self):
|
||||
# basic values
|
||||
test_values = [1.0, 2.0, 4.0, 0.5, 0.25, 10.0, 100.0, 1000.0]
|
||||
|
|
@ -114,7 +115,7 @@ class TestFloat16Log2(unittest.TestCase):
|
|||
expected = np.log2(np.float16(val))
|
||||
np.testing.assert_allclose(result, expected, rtol=1e-3, err_msg=f"log2({val})")
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), f"no float16 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.float16 in supported_dtypes, f"no float16 on {Device.DEFAULT}")
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and CI, "Nan handling differs on Vulkan")
|
||||
def test_float16_log2_special(self):
|
||||
# special values: inf, -inf, nan, 0, negative
|
||||
|
|
@ -128,7 +129,7 @@ class TestFloat16Log2(unittest.TestCase):
|
|||
# log2(nan) = nan
|
||||
assert np.isnan(Tensor([np.nan], dtype=dtypes.float16).log2().numpy()[0])
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), f"no float16 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.float16 in supported_dtypes, f"no float16 on {Device.DEFAULT}")
|
||||
def test_float16_log2_denormal(self):
|
||||
# test values near and below float16 min normal (6.1e-5)
|
||||
# these exercise the denormal handling path with 2^10 scaling
|
||||
|
|
@ -141,7 +142,7 @@ class TestFloat16Log2(unittest.TestCase):
|
|||
np.testing.assert_allclose(result, expected, rtol=5e-2, err_msg=f"log2({val})")
|
||||
|
||||
class TestTranscendentalSchedule(unittest.TestCase):
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
|
||||
@unittest.skipUnless(dtypes.ulong in supported_dtypes, "Needs ulong")
|
||||
def test_transcendental_sin_fusion(self):
|
||||
with Context(TRANSCENDENTAL=2):
|
||||
a = Tensor.empty(10)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ from tinygrad.uop.ops import Ops, UOp, KernelInfo, AxisType, buffers
|
|||
from tinygrad.renderer.cstyle import CStyleLanguage
|
||||
from tinygrad.engine.realize import run_linear
|
||||
from tinygrad.codegen import to_program
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
from test.helpers import to_uops_list
|
||||
|
|
@ -141,9 +140,7 @@ class TestNonFloatUOps(TestUOps):
|
|||
lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], (dtypes.int32, dtypes.int32), no_b_zero=True)
|
||||
def test_cmplt_int32(self): self._test_bop_fxn(Ops.CMPLT, lambda a,b: int(a)<int(b), (dtypes.int32, dtypes.int32))
|
||||
def test_cmpne_int32(self): self._test_bop_fxn(Ops.CMPNE, lambda a,b: int(a)!=int(b), (dtypes.int32, dtypes.int32))
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bool), "dtype not supported")
|
||||
def test_mul_bool(self): self._test_bop_fxn(Ops.MUL, lambda a,b: bool(a) and bool(b), (dtypes.bool, dtypes.bool))
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "dtype not supported")
|
||||
def test_where_float16(self):
|
||||
self._test_top_fxn(Ops.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float16, dtypes.float16))
|
||||
|
||||
|
|
|
|||
5
test/external/external_test_onnx_backend.py
vendored
5
test/external/external_test_onnx_backend.py
vendored
|
|
@ -5,7 +5,6 @@ import onnx.backend.test
|
|||
import numpy as np
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.helpers import getenv, OSX
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.nn.onnx import OnnxRunner
|
||||
|
||||
# pip3 install tabulate
|
||||
|
|
@ -70,7 +69,7 @@ backend_test.exclude('test_resize_downsample_scales_linear_align_corners_cpu')
|
|||
backend_test.exclude('test_resize_downsample_scales_cubic_align_corners_cpu')
|
||||
|
||||
# about different dtypes
|
||||
if not is_dtype_supported(dtypes.float64):
|
||||
if dtypes.float64 not in Device[Device.DEFAULT].renderer.supported_dtypes():
|
||||
backend_test.exclude('float64')
|
||||
backend_test.exclude('DOUBLE')
|
||||
# these have float64 inputs
|
||||
|
|
@ -80,7 +79,7 @@ if not is_dtype_supported(dtypes.float64):
|
|||
backend_test.exclude('test_einsum_*')
|
||||
backend_test.exclude('test_cumsum_*')
|
||||
|
||||
if not is_dtype_supported(dtypes.float16):
|
||||
if dtypes.float16 not in Device[Device.DEFAULT].renderer.supported_dtypes():
|
||||
backend_test.exclude('float16')
|
||||
backend_test.exclude('FLOAT16')
|
||||
|
||||
|
|
|
|||
2
test/external/external_test_onnx_runner.py
vendored
2
test/external/external_test_onnx_runner.py
vendored
|
|
@ -2,7 +2,6 @@ import unittest, onnx, tempfile, pathlib
|
|||
import numpy as np
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from typing import Any
|
||||
from tinygrad.nn.onnx import OnnxRunner, OnnxPBParser, OnnxDataType
|
||||
from hypothesis import given, strategies as st
|
||||
|
|
@ -89,7 +88,6 @@ class TestOnnxRunner(unittest.TestCase):
|
|||
np.testing.assert_equal(output.numpy(), weights + 1)
|
||||
|
||||
all_dtypes = list(OnnxDataType)
|
||||
device_supported_dtypes = {odt for odt in OnnxDataType if is_dtype_supported(odt.to_dtype())}
|
||||
|
||||
class TestOnnxRunnerDtypes(unittest.TestCase):
|
||||
"""
|
||||
|
|
|
|||
3
test/external/fuzz_fast_idiv.py
vendored
3
test/external/fuzz_fast_idiv.py
vendored
|
|
@ -1,7 +1,6 @@
|
|||
import random
|
||||
import z3
|
||||
from tinygrad import dtypes, Device
|
||||
from tinygrad.helpers import DEV
|
||||
from tinygrad.uop.validate import uops_to_z3, z3_cdiv
|
||||
from tinygrad.uop.ops import UOp
|
||||
from tinygrad.uop.decompositions import fast_idiv
|
||||
|
|
@ -16,7 +15,7 @@ if __name__ == "__main__":
|
|||
u = UOp.variable('x', random.randint(dt.min, 0), random.randint(1, dt.max), dtype=dt)
|
||||
d = random.randint(1, max(1, u.arg[2])*2)
|
||||
if d in powers_of_two: continue
|
||||
expr = fast_idiv(DEV.target(Device.DEFAULT), u, d)
|
||||
expr = fast_idiv(Device[Device.DEFAULT].renderer, u, d)
|
||||
if expr is None: continue
|
||||
|
||||
solver = z3.Solver()
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ import examples.mlperf.metrics as metrics
|
|||
from tinygrad.helpers import fetch
|
||||
from test.helpers import slow
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.device import is_dtype_supported
|
||||
import numpy as np
|
||||
|
||||
# Audio generated with the command on MacOS:
|
||||
|
|
@ -53,7 +52,6 @@ def wer_helper(result: str, reference: str)->float:
|
|||
return wer
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT in ["CPU"], "slow")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support")
|
||||
# TODO: WEBGPU GPU dispatch dimensions limit
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU GPU dispatch dimensions limit")
|
||||
class TestWhisper(unittest.TestCase):
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import unittest, io
|
|||
from contextlib import redirect_stdout
|
||||
from tinygrad import Tensor, dtypes, Device
|
||||
from tinygrad.helpers import OSX, DEV
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.engine.realize import compile_linear
|
||||
from tinygrad.codegen import to_program
|
||||
|
||||
|
|
@ -10,7 +9,6 @@ class TestCompileFailures(unittest.TestCase):
|
|||
def compile(self, out:Tensor):
|
||||
compile_linear(out.schedule_linear())
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}")
|
||||
def test_interpolate_atari(self):
|
||||
self.compile(Tensor.empty(210, 160, dtype='uint8').interpolate((64, 64)))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import unittest, math, struct, operator
|
||||
from tinygrad.tensor import Tensor, dtypes
|
||||
from tinygrad.dtype import DTYPES_DICT, truncate, float_to_fp16, float_to_bf16, _to_np_dtype, least_upper_dtype, least_upper_float
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad import Tensor, Device
|
||||
from tinygrad.dtype import DTYPES_DICT, dtypes, truncate, float_to_fp16, float_to_bf16, _to_np_dtype, least_upper_dtype, least_upper_float
|
||||
|
||||
from tinygrad.helpers import getenv
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
|
|
@ -12,8 +11,8 @@ settings.register_profile("my_profile", max_examples=50, deadline=None, derandom
|
|||
settings.load_profile("my_profile")
|
||||
|
||||
core_dtypes = list(DTYPES_DICT.values())
|
||||
dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)]
|
||||
dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]
|
||||
dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and dt in Device[Device.DEFAULT].renderer.supported_dtypes()]
|
||||
dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and dt in Device[Device.DEFAULT].renderer.supported_dtypes()]
|
||||
|
||||
FP8E4M3_MAX = 448.0
|
||||
FP8E5M2_MAX = 57344.0
|
||||
|
|
|
|||
|
|
@ -1,13 +0,0 @@
|
|||
import unittest
|
||||
from tinygrad import dtypes, Device
|
||||
from tinygrad.device import is_dtype_supported
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT=="NULL", "Don't run when testing non-NULL backends")
|
||||
class TestNULLSupportsDTypes(unittest.TestCase):
|
||||
def test_null_supports_ints_floats_bool(self):
|
||||
dts = dtypes.ints + dtypes.floats + (dtypes.bool,)
|
||||
not_supported = [dt for dt in dts if not is_dtype_supported(dt)]
|
||||
self.assertFalse(not_supported, msg=f"expected these dtypes to be supported by NULL: {not_supported}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
import unittest, time, gc
|
||||
import numpy as np
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
|
|
@ -41,6 +40,8 @@ def helper_test(nm, gen, model, max_memory_allowed, max_kernels_allowed, all_jit
|
|||
if all_jitted:
|
||||
assert kernels_used > 0 and kernels_used == GlobalCounters.kernel_count or (kernels_used <= GlobalCounters.kernel_count and getattr(Device[Device.DEFAULT], "graph", None)), f"only {kernels_used} out of {GlobalCounters.kernel_count} were jitted" # noqa: E501
|
||||
|
||||
supported_dtypes = Device[Device.DEFAULT].renderer.supported_dtypes()
|
||||
|
||||
class TestRealWorld(unittest.TestCase):
|
||||
def setUp(self):
|
||||
gc.collect()
|
||||
|
|
@ -53,7 +54,7 @@ class TestRealWorld(unittest.TestCase):
|
|||
dtypes.default_float = self.old_float
|
||||
|
||||
@slow
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
|
||||
@unittest.skipUnless(dtypes.float16 in supported_dtypes, "need dtypes.float16")
|
||||
def test_stable_diffusion(self):
|
||||
params = unet_params
|
||||
params["model_ch"] = 8
|
||||
|
|
@ -78,7 +79,7 @@ class TestRealWorld(unittest.TestCase):
|
|||
exp_mem = 0.00037 if Device.DEFAULT == "CL" else 0.0002
|
||||
helper_test("test_unet_resblock", lambda: (Tensor.empty(4, 16, 8, 8), Tensor.empty(1, 24)), test, exp_mem, 37)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
|
||||
@unittest.skipUnless(dtypes.float16 in supported_dtypes, "need dtypes.float16")
|
||||
def test_llama(self):
|
||||
dtypes.default_float = dtypes.float16
|
||||
|
||||
|
|
@ -90,7 +91,7 @@ class TestRealWorld(unittest.TestCase):
|
|||
# TODO: test first token vs rest properly
|
||||
helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.23, 118, all_jitted=True)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
|
||||
@unittest.skipUnless(dtypes.float16 in supported_dtypes, "need dtypes.float16")
|
||||
def test_gpt2(self):
|
||||
dtypes.default_float = dtypes.float16
|
||||
|
||||
|
|
@ -147,7 +148,7 @@ class TestRealWorld(unittest.TestCase):
|
|||
|
||||
helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, 0.12, 126)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
|
||||
@unittest.skipUnless(dtypes.float16 in supported_dtypes, "need dtypes.float16")
|
||||
def test_train_cifar_hyp(self):
|
||||
dtypes.default_float = dtypes.float16
|
||||
with Tensor.train():
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
import numpy as np
|
||||
import unittest
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
from tinygrad.renderer.nir import NIRRenderer
|
||||
|
|
@ -86,12 +85,12 @@ class TestIdxUpcast(unittest.TestCase):
|
|||
def do_op_then_assert(self, dtype: DType, dim1, dim2, dim3):
|
||||
self._assert(dtype, Tensor.empty(dim1, dim2, 1).expand(-1, -1, dim3).contiguous())
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.long), "int64 is supported")
|
||||
@unittest.skipUnless(dtypes.long in Device[Device.DEFAULT].renderer.supported_dtypes(), "int64 is supported")
|
||||
def test_overflow(self):
|
||||
# 2**11, 2**11, 2**11 -> 2**33 will overflow when indexed
|
||||
self.do_op_then_assert(dtypes.long, 2048, 2048, 2048)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.long), "int64 is supported")
|
||||
@unittest.skipUnless(dtypes.long in Device[Device.DEFAULT].renderer.supported_dtypes(), "int64 is supported")
|
||||
def test_overflow_sym(self):
|
||||
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32))
|
||||
|
||||
|
|
@ -108,12 +107,12 @@ class TestIdxUpcast(unittest.TestCase):
|
|||
uops = self._schedule_render(a)
|
||||
assert all(uop.dtype is not dtypes.long for uop in uops)
|
||||
|
||||
@unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
|
||||
@unittest.skipIf(dtypes.long in Device[Device.DEFAULT].renderer.supported_dtypes(), "int64 is supported")
|
||||
def test_int64_unsupported_overflow_sym(self):
|
||||
with self.assertRaises((KeyError, RuntimeError)):
|
||||
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32))
|
||||
|
||||
@unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
|
||||
@unittest.skipIf(dtypes.long in Device[Device.DEFAULT].renderer.supported_dtypes(), "int64 is supported")
|
||||
@unittest.expectedFailure # bug in gpu dims limiting
|
||||
def test_int64_unsupported_overflow(self):
|
||||
with self.assertRaises((KeyError, RuntimeError)):
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from tinygrad import Device, Tensor, dtypes
|
|||
from tinygrad.tensor import _to_np_dtype
|
||||
from tinygrad.uop.ops import Ops, UOp, buffers
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.device import Buffer, is_dtype_supported
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.helpers import DEV, Context
|
||||
from test.helpers import slow, replace_opts
|
||||
from tinygrad.engine.realize import run_linear
|
||||
|
|
@ -69,7 +69,6 @@ class TestTensorCores(unittest.TestCase):
|
|||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
def test_tensor_cores(self):
|
||||
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
|
||||
if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue
|
||||
# for AMX, tc.dims[2] == 1 so reduceop is None thus tensor_cores are not triggered
|
||||
helper_tc_allclose(tc.dims[0], tc.dims[1], 2 if AMX else tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0)
|
||||
|
||||
|
|
@ -78,7 +77,6 @@ class TestTensorCores(unittest.TestCase):
|
|||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
def test_tensor_cores_codegen(self):
|
||||
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
|
||||
if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue
|
||||
n, m, k = tc.dims[0], tc.dims[1], 2 if AMX else tc.dims[2]
|
||||
a, b = Tensor.rand(m, k, dtype=tc.dtype_in), Tensor.rand(k, n, dtype=tc.dtype_in)
|
||||
r = a.matmul(b, dtype=tc.dtype_out)
|
||||
|
|
@ -98,7 +96,6 @@ class TestTensorCores(unittest.TestCase):
|
|||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
def test_tensor_cores_padded(self):
|
||||
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
|
||||
if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue
|
||||
helper_tc_allclose(tc.dims[0]+(pad:=1), tc.dims[1]+pad, tc.dims[2]+pad, tc.dtype_in, tc.dtype_out, tc_opt=2)
|
||||
|
||||
# AMD compiler bug: AMD miscompiles non-zero padded tc kernels with -O3, producing wrong results, nans or hang (see #9606)
|
||||
|
|
@ -109,7 +106,6 @@ class TestTensorCores(unittest.TestCase):
|
|||
@unittest.skip("warp elements not duplicated properly across lanes")
|
||||
def test_tensor_cores_padded_amd(self):
|
||||
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
|
||||
if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue
|
||||
helper_tc_allclose(tc.dims[0]+(pad:=1), tc.dims[1]+pad, tc.dims[2]+pad, tc.dtype_in, tc.dtype_out, tc_opt=2)
|
||||
|
||||
@Context(ALLOW_TF32=1)
|
||||
|
|
@ -140,7 +136,6 @@ class TestTensorCores(unittest.TestCase):
|
|||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
def test_tensor_cores_multi_reduce(self):
|
||||
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
|
||||
if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue
|
||||
if tc.dtype_in is dtypes.bfloat16: continue # <-- broken with numpy
|
||||
# this will be a M=G16, N=G32, M=G16, M=G16, K=R16, K=R16, K=R16 with 9 choices of TC MNK axes
|
||||
golden_result = None
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import unittest
|
||||
from tinygrad.helpers import CI
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.device import is_dtype_supported
|
||||
# similar to test/external/external_test_gpu_ast.py, but universal
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT in {"CUDA", "NV"} and CI, "slow on CUDA CI")
|
||||
|
|
@ -20,7 +19,7 @@ class TestSpecific(unittest.TestCase):
|
|||
w = Tensor.randn(2048, 512)
|
||||
(x @ w).reshape(1, 128, 4).contiguous().realize()
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support")
|
||||
@unittest.skipUnless(dtypes.float16 in Device[Device.DEFAULT].renderer.supported_dtypes(), "need float16 support")
|
||||
def test_big_vec_mul(self):
|
||||
# from LLaMA
|
||||
# 0 buffer<4096, dtypes.float> [View((1024, 1, 1, 4), (4, 0, 0, 1), 0, None)]
|
||||
|
|
|
|||
|
|
@ -1,12 +1,10 @@
|
|||
import unittest
|
||||
from extra.f16_decompress import u32_to_f16
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad import dtypes
|
||||
import numpy as np
|
||||
|
||||
class TestF16Decompression(unittest.TestCase):
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16")
|
||||
def test_u32_to_f16(self):
|
||||
a = Tensor.randn(50, dtype=dtypes.float16)
|
||||
f16_as_u32 = a.bitcast(dtypes.uint32)
|
||||
|
|
|
|||
|
|
@ -4,12 +4,11 @@ import numpy as np
|
|||
from tinygrad import Tensor, dtypes, Device
|
||||
from tinygrad.nn import Linear
|
||||
from extra.fp8.fp8_linear import FP8Linear, convert_to_float8_training
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from test.helpers import not_support_multi_device, needs_second_gpu
|
||||
|
||||
BS, T, in_dim, out_dim = 16, 4, 128, 128
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), f"no fp8e4m3 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.fp8e4m3 in Device[Device.DEFAULT].renderer.supported_dtypes(), f"no fp8e4m3 on {Device.DEFAULT}")
|
||||
class TestFP8Linear(unittest.TestCase):
|
||||
def setUp(self):
|
||||
Tensor.manual_seed(42)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import unittest
|
|||
import numpy as np
|
||||
from tinygrad import dtypes, Tensor, TinyJit, GlobalCounters, Variable
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import temp, CI, DEV, Context
|
||||
|
||||
N = 200 # has to be bigger than the cache to fail
|
||||
|
|
@ -468,7 +467,6 @@ class TestAssign(unittest.TestCase):
|
|||
self.assertEqual(GlobalCounters.kernel_count, 2) # currently conservative, forces contiguous
|
||||
np.testing.assert_allclose(a.numpy(), expected)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
def test_setitem_half(self):
|
||||
a = Tensor.full((8,), 1.0, dtype=dtypes.half).contiguous().realize()
|
||||
b = Tensor.full((4,), 2.0, dtype=dtypes.half).contiguous().realize()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import os, pathlib, tempfile, unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.dtype import DType, DTYPES_DICT
|
||||
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
|
||||
from tinygrad.helpers import Timing, fetch, OSX, dedup
|
||||
|
|
@ -36,7 +35,6 @@ class TestTorchLoad(TempDirTestCase):
|
|||
# pytorch zip format
|
||||
def test_load_convnext(self):
  # ConvNeXt-tiny checkpoint URL, handed to the shared weight-comparison helper
  compare_weights_both('https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth')
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support")
|
||||
def test_load_llama2bfloat(self):
  # bf16 checkpoint URL (per its path), handed to the shared weight-comparison helper
  compare_weights_both("https://huggingface.co/qazalin/bf16-lightweight/resolve/main/consolidated.00.pth?download=true")
|
||||
|
||||
# pytorch tar format
|
||||
|
|
@ -94,7 +92,6 @@ class TestRawDiskBuffer(unittest.TestCase):
|
|||
|
||||
pathlib.Path(tmp).unlink()
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uint8), "need uint8")
|
||||
class TestSafetensors(TempDirTestCase):
|
||||
def test_real_safetensors(self):
|
||||
import torch
|
||||
|
|
@ -184,7 +181,6 @@ class TestSafetensors(TempDirTestCase):
|
|||
def test_save_all_dtypes(self):
|
||||
for dtype in dedup(DTYPES_DICT.values()):
|
||||
if dtype in [dtypes.bfloat16]: continue # not supported in numpy
|
||||
if not is_dtype_supported(dtype): continue
|
||||
path = self.tmp(f"ones.{dtype}.safetensors")
|
||||
ones = Tensor(np.random.rand(10,10), dtype=dtype)
|
||||
safe_save(get_state_dict(ones), path)
|
||||
|
|
@ -384,7 +380,6 @@ class TestDiskTensor(TempDirTestCase):
|
|||
assert ret.tolist() == [2827, 3341, 3855, 4369]
|
||||
|
||||
@unittest.skipIf(OSX or Device.DEFAULT == "CL", "new LLVM has an issue on OSX, DEV=CL gives the wrong output")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
|
||||
def test_bf16_disk_write_read(self):
|
||||
t = Tensor([10000, -1, -1000, -10000, 20], dtype=dtypes.float32)
|
||||
t.to(f"disk:{self.tmp('dt_bf16_disk_write_read_f32')}").realize()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import unittest, math, subprocess
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.dtype import dtypes, DType, DTYPES_DICT
|
||||
from tinygrad.device import Device, is_dtype_supported
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.helpers import getenv, DEBUG, EMULATED_DTYPES
|
||||
from test.helpers import slow
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
|
|
@ -11,9 +11,10 @@ import torch
|
|||
settings.register_profile("my_profile", max_examples=50, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
|
||||
settings.load_profile("my_profile")
|
||||
|
||||
supported_dtypes = Device[Device.DEFAULT].renderer.supported_dtypes()
|
||||
core_dtypes = list(DTYPES_DICT.values())
|
||||
dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)]
|
||||
dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]
|
||||
dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and dt in supported_dtypes]
|
||||
dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and dt in supported_dtypes]
|
||||
|
||||
FP8E4M3_MAX = 448.0
|
||||
FP8E5M2_MAX = 57344.0
|
||||
|
|
@ -25,7 +26,7 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float
|
|||
try:
|
||||
assert tensor.dtype == target_dtype
|
||||
# denormals are zero
|
||||
if target_dtype in dtypes.floats and (not is_dtype_supported(target_dtype) or target_dtype in EMULATED_DTYPES.tolist(dtypes)):
|
||||
if target_dtype in dtypes.floats and (target_dtype not in supported_dtypes or target_dtype in EMULATED_DTYPES.tolist(dtypes)):
|
||||
fe, fm = dtypes.finfo(target_dtype)
|
||||
kwargs = {"atol":2 ** (2 - (1 << (fe - 1))), "rtol": 2 ** (-fm)}
|
||||
else: kwargs = {"rtol": {dtypes.float16:1e-3, dtypes.bfloat16:1e-2, dtypes.fp8e4m3:1e-1, dtypes.fp8e5m2:5e-1,
|
||||
|
|
@ -58,7 +59,7 @@ class TestTypeSpec(unittest.TestCase):
|
|||
subprocess.run(['DEFAULT_FLOAT=TYPO python3 -c "from tinygrad import dtypes"'],
|
||||
shell=True, check=True)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.int8), f"no int8 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(dtypes.int8 in supported_dtypes, f"no int8 on {Device.DEFAULT}")
|
||||
def test_dtype_str_arg(self):
|
||||
n = np.random.normal(0, 1, (10, 10)).astype(np.float32)
|
||||
tested = 0
|
||||
|
|
@ -91,7 +92,7 @@ class TestTypeSpec(unittest.TestCase):
|
|||
_assert_eq(Tensor.eye(0), dtypes.default_float, np.eye(0))
|
||||
_assert_eq(Tensor.eye(3), dtypes.default_float, np.eye(3))
|
||||
_assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
if dtypes.float16 in supported_dtypes:
|
||||
_assert_eq(Tensor.eye(3, dtype=dtypes.float16), dtypes.float16, np.eye(3))
|
||||
|
||||
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
|
||||
|
|
@ -100,12 +101,12 @@ class TestTypeSpec(unittest.TestCase):
|
|||
|
||||
_assert_eq(Tensor.zeros((2, 3)), dtypes.default_float, np.zeros((2, 3)))
|
||||
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.int64), dtypes.int64, np.zeros((2, 3)))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
if dtypes.float16 in supported_dtypes:
|
||||
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.float16), dtypes.float16, np.zeros((2, 3)))
|
||||
|
||||
_assert_eq(Tensor.ones((2, 3)), dtypes.default_float, np.ones((2, 3)))
|
||||
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.int64), dtypes.int64, np.ones((2, 3)))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
if dtypes.float16 in supported_dtypes:
|
||||
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.float16), dtypes.float16, np.ones((2, 3)))
|
||||
|
||||
_assert_eq(Tensor.full((2, 3), 3.0), dtypes.default_float, np.full((2, 3), 3.0))
|
||||
|
|
@ -113,7 +114,7 @@ class TestTypeSpec(unittest.TestCase):
|
|||
_assert_eq(Tensor.full((2, 3), True), dtypes.bool, np.full((2, 3), True))
|
||||
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
if dtypes.float16 in supported_dtypes:
|
||||
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
|
||||
|
||||
|
|
@ -132,10 +133,10 @@ class TestTypeSpec(unittest.TestCase):
|
|||
_assert_eq(Tensor.arange(5), dtypes.default_int, np.arange(5))
|
||||
_assert_eq(Tensor.arange(120), dtypes.default_int, np.arange(120))
|
||||
_assert_eq(Tensor.arange(5.0), dtypes.default_float, np.arange(5))
|
||||
if is_dtype_supported(dtypes.int16):
|
||||
if dtypes.int16 in supported_dtypes:
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.int16), dtypes.int16, np.arange(5))
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.int64), dtypes.int64, np.arange(5))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
if dtypes.float16 in supported_dtypes:
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.float16), dtypes.float16, np.arange(5))
|
||||
_assert_eq(Tensor.arange(3, 9, 0.7), dtypes.default_float, np.arange(3, 9, 0.7), 1e-6 if Device.DEFAULT == "WEBGPU" else 1e-7)
|
||||
_assert_eq(Tensor.arange(3, 8.5, 3), dtypes.default_float, np.arange(3, 8.5, 3))
|
||||
|
|
@ -149,7 +150,7 @@ class TestAutoCastType(unittest.TestCase):
|
|||
def tearDown(self):
|
||||
dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float
|
||||
|
||||
@given(strat.sampled_from([d for d in core_dtypes if dtypes.is_int(d) and is_dtype_supported(d)]))
|
||||
@given(strat.sampled_from([d for d in core_dtypes if dtypes.is_int(d) and d in supported_dtypes]))
|
||||
def test_int_to_float_unary_func(self, dtype):
|
||||
for func in [
|
||||
lambda t: t.exp(),
|
||||
|
|
@ -167,7 +168,7 @@ class TestAutoCastType(unittest.TestCase):
|
|||
# float16 can have larger precision errors
|
||||
np.testing.assert_allclose(func(Tensor(a, dtype=dtype)).numpy(), func(torch.tensor(a)), rtol=1e-3, atol=1e-3)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16")
|
||||
@unittest.skipUnless(dtypes.float16 in supported_dtypes, "need float16")
|
||||
def test_sum_dtype_arg(self):
|
||||
t = Tensor([40000, 40000], dtype=dtypes.float16)
|
||||
# default float16 sum returns in float16, overflowed in this case
|
||||
|
|
@ -188,10 +189,10 @@ class TestAutoCastType(unittest.TestCase):
|
|||
old_default_float = dtypes.default_float
|
||||
|
||||
for default_dtype in dtypes.floats:
|
||||
if not is_dtype_supported(default_dtype): continue
|
||||
if default_dtype not in supported_dtypes: continue
|
||||
dtypes.default_float = default_dtype
|
||||
for dtype in dtypes.floats:
|
||||
if not is_dtype_supported(dtype): continue
|
||||
for dtype in dtypes.floats:
|
||||
if dtype not in supported_dtypes: continue
|
||||
if DEBUG >= 2:
|
||||
print(f"testing {default_dtype=}, {dtype=}")
|
||||
a = Tensor([1, 2, 3], dtype=dtype)
|
||||
|
|
@ -205,7 +206,7 @@ class TestAutoCastType(unittest.TestCase):
|
|||
@unittest.skipIf(Device.DEFAULT == "PYTHON", "very slow")
|
||||
@slow
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Binding size is larger than the maximum storage buffer binding size")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
@unittest.skipUnless(dtypes.half in supported_dtypes, "need half")
|
||||
def test_mean_half_precision_underflow(self):
|
||||
N = 10000
|
||||
x = 0.001
|
||||
|
|
@ -213,7 +214,7 @@ class TestAutoCastType(unittest.TestCase):
|
|||
np.testing.assert_allclose(t.mean(axis=1).numpy(), np.array([x] * N, dtype=np.float16), rtol=1e-3)
|
||||
|
||||
@unittest.skip("this test only works with SPLIT_REDUCEOP=1")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
@unittest.skipUnless(dtypes.half in supported_dtypes, "need half")
|
||||
def test_mean_half_precision_overflow(self):
|
||||
N = 256
|
||||
t = Tensor([60000] * N*N, dtype=dtypes.half).reshape(N, N)
|
||||
|
|
@ -222,7 +223,7 @@ class TestAutoCastType(unittest.TestCase):
|
|||
np.testing.assert_allclose(t.grad.numpy().flatten(), [60000 * 2 / (N*N)] * N*N)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Precision error")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
@unittest.skipUnless(dtypes.half in supported_dtypes, "need half")
|
||||
def test_softmax_dtype(self):
|
||||
data = [1, 2, 3]
|
||||
t = Tensor(data, dtype=dtypes.half)
|
||||
|
|
|
|||
|
|
@ -3,12 +3,12 @@ from tinygrad import dtypes, Tensor, fetch, Device
|
|||
from tinygrad.helpers import disable_gc
|
||||
from tinygrad.llm.gguf import _ggml_iq_grid, ggml_data_to_tensor, gguf_load
|
||||
from tinygrad.runtime.autogen import ggml_common as _ggml
|
||||
from tinygrad.device import is_dtype_supported
|
||||
import numpy as np
|
||||
from gguf import GGUFReader, GGUFValueType, GGMLQuantizationType, GGML_QUANT_SIZES, dequantize, quantize
|
||||
from gguf.quants import IQ2_S, IQ3_S, IQ3_XXS
|
||||
|
||||
ggml_test_block_count = 4
|
||||
supported_dtypes = Device[Device.DEFAULT].renderer.supported_dtypes()
|
||||
|
||||
class TestGGUFTables(unittest.TestCase):
|
||||
def test_iq2_s_grid_matches_gguf_py(self):
|
||||
|
|
@ -26,7 +26,7 @@ class TestGGUFTables(unittest.TestCase):
|
|||
grid = _ggml_iq_grid(Device.DEFAULT, _ggml.iq3s_grid, (512, 4)).numpy()
|
||||
np.testing.assert_equal(grid, IQ3_S.grid.reshape(512, 4))
|
||||
|
||||
@unittest.skipIf(any(not is_dtype_supported(t) for t in [ dtypes.uint8, dtypes.half ]), "Backend must support uint8 and half")
|
||||
@unittest.skipUnless(dtypes.uint8 in supported_dtypes and dtypes.half in supported_dtypes, "Backend must support uint8 and half")
|
||||
class TestGGUF(unittest.TestCase):
|
||||
def test_load_tinyllama_q8_0(self):
  # q8_0 variant of the stories15M GGUF file, checked through the shared load helper
  self._test_gguf_load("https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q8_0.gguf?download=true")
|
||||
def test_load_tinyllama_q4_0(self):
  # q4_0 variant of the stories15M GGUF file, checked through the shared load helper
  self._test_gguf_load("https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf?download=true")
|
||||
|
|
@ -60,7 +60,7 @@ class TestGGUF(unittest.TestCase):
|
|||
def test_dequantization_iq2_s(self):
  # delegates to the shared dequantization check with the IQ2_S quantization type
  self._test_dequantization(GGMLQuantizationType.IQ2_S)
|
||||
def test_dequantization_iq4_xs(self):
  # delegates to the shared dequantization check with the IQ4_XS quantization type
  self._test_dequantization(GGMLQuantizationType.IQ4_XS)
|
||||
def test_dequantization_mxfp4(self):
  # delegates to the shared dequantization check with the MXFP4 quantization type
  self._test_dequantization(GGMLQuantizationType.MXFP4)
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "Backend must support bfloat16")
|
||||
@unittest.skipUnless(dtypes.bfloat16 in supported_dtypes, "Backend must support bfloat16")
|
||||
def test_dequantization_bf16(self):
  # delegates to the shared dequantization check with the BF16 quantization type
  self._test_dequantization(GGMLQuantizationType.BF16)
|
||||
def test_dequantization_mxfp4_old(self):
|
||||
def encode(nibbles, E):
|
||||
|
|
@ -229,7 +229,7 @@ class TestGGUFGEMV(unittest.TestCase):
|
|||
x = rng.standard_normal(cols).astype(np.float32)
|
||||
with np.errstate(all='ignore'):
|
||||
np.testing.assert_allclose((tensors["weight"] @ Tensor(x)).numpy(), ref @ x, atol=1e-2, rtol=1e-2)
|
||||
if qtype == GGMLQuantizationType.BF16 or is_dtype_supported(dtypes.half): np.testing.assert_equal(tensors["weight"].numpy(), ref)
|
||||
if qtype == GGMLQuantizationType.BF16 or dtypes.half in supported_dtypes: np.testing.assert_equal(tensors["weight"].numpy(), ref)
|
||||
assert np.isfinite(ref).all() and np.isfinite(tensors["weight"].numpy()).all(), f"{qtype.name} has NaN/Inf"
|
||||
|
||||
def test_gguf_gemv_q8_0(self):
  # delegates to the shared GEMV check with the Q8_0 quantization type
  self._test_gguf_gemv(GGMLQuantizationType.Q8_0)
|
||||
|
|
@ -243,7 +243,7 @@ class TestGGUFGEMV(unittest.TestCase):
|
|||
def test_gguf_gemv_iq2_s(self):
  # delegates to the shared GEMV check with the IQ2_S quantization type
  self._test_gguf_gemv(GGMLQuantizationType.IQ2_S)
|
||||
def test_gguf_gemv_iq4_xs(self):
  # delegates to the shared GEMV check with the IQ4_XS quantization type
  self._test_gguf_gemv(GGMLQuantizationType.IQ4_XS)
|
||||
def test_gguf_gemv_mxfp4(self):
  # delegates to the shared GEMV check with the MXFP4 quantization type
  self._test_gguf_gemv(GGMLQuantizationType.MXFP4)
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "Backend must support bfloat16")
|
||||
@unittest.skipUnless(dtypes.bfloat16 in supported_dtypes, "Backend must support bfloat16")
|
||||
def test_gguf_gemv_bf16(self):
  # delegates to the shared GEMV check with the BF16 quantization type
  self._test_gguf_gemv(GGMLQuantizationType.BF16)
|
||||
|
||||
class TestGGUFGC(unittest.TestCase):
|
||||
|
|
|
|||
|
|
@ -3,11 +3,12 @@ import hashlib, random, unittest
|
|||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.helpers import DEV
|
||||
from test.helpers import slow
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.uop.ops import UOp
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uint8) and is_dtype_supported(dtypes.uint64), "Device must support uint8 and uint64")
|
||||
supported_dtypes = Device[Device.DEFAULT].renderer.supported_dtypes()
|
||||
|
||||
@unittest.skipUnless(dtypes.uint8 in supported_dtypes and dtypes.uint64 in supported_dtypes, "Device must support uint8 and uint64")
|
||||
@unittest.skipIf(DEV.interface.startswith("MOCK") and Device.DEFAULT == "NV", "crashes in NV CI")
|
||||
class TestHashing(unittest.TestCase):
|
||||
def _python_hash_1mb(self, data:bytes):
|
||||
|
|
@ -21,7 +22,7 @@ class TestHashing(unittest.TestCase):
|
|||
out = Tensor(b"abc").hash()
|
||||
self.assertEqual(bytes(out.data()), expected)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uint8) and is_dtype_supported(dtypes.uint64), "Device must support uint8 and uint64")
|
||||
@unittest.skipUnless(dtypes.uint8 in supported_dtypes and dtypes.uint64 in supported_dtypes, "Device must support uint8 and uint64")
|
||||
@unittest.skipIf(DEV.interface.startswith("MOCK") and Device.DEFAULT == "NV", "crashes in NV CI")
|
||||
class TestKeccak(unittest.TestCase):
|
||||
def setUp(self) -> None:
  # reseed the global `random` generator with a fixed value before each test
  random.seed(1337)
|
||||
|
|
|
|||
|
|
@ -89,8 +89,8 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
supported_ops = tuple(ren.code_for_op.keys())
|
||||
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))
|
||||
pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
|
||||
sink = graph_rewrite(sink, pm_decomp, ctx=ren.target, name="decompositions")
|
||||
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren.target), name="decomp dtypes")
|
||||
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="decompositions")
|
||||
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
|
||||
sink = graph_rewrite(sink, pm_transcendental, name="transcendental")
|
||||
|
||||
# move gates from unrenderable INVALID where
|
||||
|
|
@ -99,7 +99,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
# final rules for the renderer (without sym)
|
||||
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
|
||||
pm_final_rewrite = pm_decomp+pm_render+extra_matcher+pm_split_ends
|
||||
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren.target, name="final rewrite")
|
||||
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite")
|
||||
|
||||
# this was the linearizer
|
||||
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
|
||||
|
|
|
|||
|
|
@ -2,12 +2,11 @@ from __future__ import annotations
|
|||
from dataclasses import dataclass, replace
|
||||
from collections import defaultdict
|
||||
from typing import Any, Generic, TypeVar, Iterator, Generator, TYPE_CHECKING
|
||||
import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal
|
||||
from tinygrad.helpers import BENCHMARKS, CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored
|
||||
import importlib, inspect, functools, pathlib, os, contextlib, re, atexit, pickle, decimal
|
||||
from tinygrad.helpers import LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored
|
||||
from tinygrad.helpers import Context, CCACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, suppress_finalizing
|
||||
from tinygrad.helpers import select_by_name, select_first_inited, DEV, EMULATED_DTYPES, IMAGE, FLOAT16, TracingKey, size_to_str, Target
|
||||
from tinygrad.helpers import pluralize
|
||||
from tinygrad.dtype import DType, PtrDType, dtypes, _to_np_dtype
|
||||
from tinygrad.helpers import select_by_name, select_first_inited, DEV, TracingKey, size_to_str, pluralize
|
||||
from tinygrad.dtype import DType, PtrDType, _to_np_dtype
|
||||
if TYPE_CHECKING: from tinygrad.renderer import Renderer
|
||||
|
||||
# **************** Device ****************
|
||||
|
|
@ -336,47 +335,6 @@ class Compiled:
|
|||
"""
|
||||
# override this in your device implementation
|
||||
|
||||
# TODO: move this to each Device
|
||||
# this only tracks if the dtype is natively supported, it may be supported in the frontend using decomps
|
||||
def is_dtype_supported(dtype:DType, target:Target|None=None) -> bool:
|
||||
target = target or DEV.target(Device.DEFAULT)
|
||||
if dtype == dtypes.bfloat16:
|
||||
match target.device:
|
||||
case "METAL": target.arch.startswith("Apple") and int(target.arch[5:]) >= 6
|
||||
case "CUDA": return (not CI or BENCHMARKS) and target.renderer != "PTX"
|
||||
case "NV": return (not CI or BENCHMARKS) and target.renderer not in ("PTX", "NAK")
|
||||
case "CPU": return platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} and target.renderer not in ("LVP", "X86")
|
||||
case "AMD" | "CL" | "PYTHON" | "NULL": return True
|
||||
case _: return False
|
||||
if dtype in dtypes.fp8_ocp:
|
||||
match target.device:
|
||||
case "CUDA": return (not CI or BENCHMARKS) and target.renderer != "PTX"
|
||||
case "NV": return (not CI or BENCHMARKS) and target.renderer not in ("PTX", "NAK")
|
||||
case "AMD": return (not CI or BENCHMARKS) and target.arch == "gfx950"
|
||||
case "PYTHON" | "NULL": return True
|
||||
case _: return False
|
||||
if dtype in dtypes.fp8_fnuz: return target.device in {"PYTHON", "NULL"}
|
||||
if target.device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,
|
||||
dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half]
|
||||
# for CI GPU and OSX, cl_khr_fp16 isn't supported
|
||||
# for CI LLVM, it segfaults because it can't link to the casting function
|
||||
# CI CUDA architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
|
||||
# PYTHON supports half memoryview in 3.12+ https://github.com/python/cpython/issues/90751
|
||||
if dtype == dtypes.half:
|
||||
match target.device:
|
||||
case "CL": return "cl_khr_fp16" in target.arch
|
||||
case "QCOM": return bool(IMAGE) and bool(FLOAT16) # QCOM compiler is flaky with half
|
||||
case "CUDA" | "NV": return not CI or BENCHMARKS or target.renderer == "PYTHON"
|
||||
case "CPU" if target.renderer == "LLVM": return OSX
|
||||
case "PYTHON": return sys.version_info >= (3, 12)
|
||||
if dtype == dtypes.float64:
|
||||
match target.device:
|
||||
case _ if dtypes.long in EMULATED_DTYPES.tolist(dtypes): return False # double can't be bitcast to anything without long support
|
||||
case "CL": return "cl_khr_fp64" in target.arch
|
||||
case "NULL": return target.renderer not in ("IR3", "QCOMCL")
|
||||
case "METAL" | "QCOM": return False
|
||||
return True
|
||||
|
||||
if PROFILE:
|
||||
@atexit.register
|
||||
def finalize_profile():
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ def prod(x:Iterable[T]) -> T|int: return functools.reduce(operator.mul, x, 1)
|
|||
|
||||
# NOTE: helpers is not allowed to import from anything else in tinygrad
|
||||
OSX, WIN = platform.system() == "Darwin", sys.platform == "win32"
|
||||
CI, BENCHMARKS = os.getenv("CI", "") != "", os.getenv("RUNNER_ENVIRONMENT", "") == "self-hosted"
|
||||
CI = os.getenv("CI", "") != ""
|
||||
ARCH_X86 = any(x in platform.processor() for x in ("Intel", "i386", "x86_64"))
|
||||
BASEDIR = pathlib.Path(__file__).parent
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from __future__ import annotations
|
||||
from typing import Callable, cast
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import prod, Target
|
||||
from tinygrad.helpers import prod, Target, EMULATED_DTYPES
|
||||
from tinygrad.uop.ops import Ops, UOp, sint, ssimplify, smin, GroupOp, PatternMatcher
|
||||
from tinygrad.dtype import AddrSpace, PtrDType
|
||||
from tinygrad.dtype import AddrSpace, PtrDType, DType, dtypes
|
||||
from tinygrad.codegen.opt.tc import TensorCore
|
||||
from tinygrad.device import Compiler
|
||||
|
||||
|
|
@ -85,3 +85,6 @@ class Renderer:
|
|||
def render(self, uops:list[UOp]) -> str:
  """Render a list of UOps to a string. Not implemented here; always raises NotImplementedError."""
  raise NotImplementedError("needs a renderer")
|
||||
def asm(self, prg:UOp, lin:UOp) -> bytes:
  """Assemble to bytes from the two given UOps. Not implemented here; always raises NotImplementedError."""
  raise NotImplementedError("needs an assembler")
|
||||
def aux(self, uops:list[UOp]) -> dict:
  """Return a dict of auxiliary data for a list of UOps. Not implemented here; always raises NotImplementedError."""
  raise NotImplementedError("needs aux")
|
||||
def supported_dtypes(self) -> set[DType]:
  """Return the default set of supported dtypes: everything in `dtypes.all` except `weakint`, and except `double` when `long` is emulated."""
  unsupported = {dtypes.weakint}
  # double can't be bitcast to anything without long support
  if dtypes.long in EMULATED_DTYPES.tolist(dtypes): unsupported.add(dtypes.double)
  return set(dtypes.all) - unsupported
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import math, sys, struct
|
|||
from collections import defaultdict, Counter
|
||||
from tinygrad.codegen.opt import tc
|
||||
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str, axis_letters
|
||||
from tinygrad.helpers import strip_parens, getenv, prod, dedup, Target, CPU_COUNT
|
||||
from tinygrad.helpers import strip_parens, getenv, prod, dedup, Target, CPU_COUNT, IMAGE, FLOAT16
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, truncate, float_to_bf16
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.codegen.late.devectorizer import no_vectorized_alu
|
||||
|
|
@ -272,6 +272,9 @@ class ClangRenderer(CStyleLanguage):
|
|||
defines = '\n'.join(self._render_defines(uops))
|
||||
return defines + "\n" + self._render_body(function_name, kernel, bufs, uops, prefix) + "\n" + self._render_entry(function_name, bufs)
|
||||
|
||||
def supported_dtypes(self):
  """Inherited dtypes without any fp8 type; bfloat16 is kept only when the target arch starts with "x86" or "arm"."""
  kept = set()
  for dt in super().supported_dtypes():
    if dt in dtypes.fp8s: continue
    # bfloat16 is the only dtype gated on the arch string
    if dt != dtypes.bfloat16 or self.target.arch.startswith(("x86", "arm")): kept.add(dt)
  return kept
|
||||
|
||||
class ClangJITRenderer(ClangRenderer):
|
||||
def __init__(self, target:Target):
|
||||
super().__init__(target)
|
||||
|
|
@ -321,6 +324,10 @@ class OpenCLRenderer(CStyleLanguage):
|
|||
arg_dtypes[u.arg].append((i, u.dtype))
|
||||
return tuple(tuple(a) for a in arg_dtypes),
|
||||
|
||||
def supported_dtypes(self):
  """Inherited dtypes without fp8; half and double are kept only when the matching extension name appears in the target arch string."""
  kept = set()
  for dt in super().supported_dtypes():
    if dt in dtypes.fp8s: continue
    # half requires "cl_khr_fp16" and double requires "cl_khr_fp64" in target.arch
    half_ok = dt != dtypes.half or "cl_khr_fp16" in self.target.arch
    double_ok = dt != dtypes.double or "cl_khr_fp64" in self.target.arch
    if half_ok and double_ok: kept.add(dt)
  return kept
|
||||
|
||||
class IntelRenderer(OpenCLRenderer):
|
||||
suffix, kernel_typedef = "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel void"
|
||||
tensor_cores = tc.intel
|
||||
|
|
@ -382,6 +389,10 @@ class MetalRenderer(CStyleLanguage):
|
|||
simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);\n return {dstr_out}(mat_c.thread_elements()[0], mat_c.thread_elements()[1]);\n}}""")
|
||||
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
|
||||
|
||||
def supported_dtypes(self):
  """Inherited dtypes without fp8 or double; bfloat16 is kept only for an "Apple<N>" arch with N >= 6."""
  never = dtypes.fp8s+(dtypes.double,)
  def bf16_ok() -> bool:
    # the characters after the 5-character "Apple" prefix are parsed as an integer
    arch = self.target.arch
    return arch.startswith("Apple") and int(arch[5:]) >= 6
  return {dt for dt in super().supported_dtypes() if dt not in never and (dt != dtypes.bfloat16 or bf16_ok())}
|
||||
|
||||
_nms = list("xyzwabcdefghijkl") + [f'v{i}' for i in range(16, 32)]
|
||||
|
||||
class CUDARenderer(CStyleLanguage):
|
||||
|
|
@ -456,6 +467,11 @@ class CUDARenderer(CStyleLanguage):
|
|||
|
||||
return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
|
||||
|
||||
def supported_dtypes(self):
  """Inherited dtypes filtered by the integer parsed from `target.arch[3:]`: half needs >= 53, bfloat16 >= 80, fp8_ocp >= 89; fp8_fnuz is always dropped."""
  # NOTE(review): arch presumably looks like "sm_XX", so this is the compute capability; confirm
  cc = int(self.target.arch[3:])
  kept = set()
  for dt in super().supported_dtypes():
    half_ok = dt != dtypes.half or cc >= 53
    bf16_ok = dt != dtypes.bfloat16 or cc >= 80
    fp8_ok = dt not in dtypes.fp8_ocp or cc >= 89
    if half_ok and bf16_ok and fp8_ok and dt not in dtypes.fp8_fnuz: kept.add(dt)
  return kept
|
||||
|
||||
class NVCCRenderer(CUDARenderer):
|
||||
def __init__(self, target:Target):
  # same construction as the parent renderer, with the nvcc flag switched on
  super().__init__(target, use_nvcc=True)
|
||||
|
||||
|
|
@ -559,6 +575,9 @@ class HIPRenderer(CStyleLanguage):
|
|||
for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
|
||||
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
|
||||
|
||||
def supported_dtypes(self):
  """Inherited dtypes without fp8_fnuz; fp8_ocp types are kept only when the target arch is exactly "gfx950"."""
  inherited = super().supported_dtypes()
  return {dt for dt in inherited if dt not in dtypes.fp8_fnuz and (dt not in dtypes.fp8_ocp or self.target.arch == "gfx950")}
|
||||
|
||||
class HIPCCRenderer(HIPRenderer):
|
||||
def __init__(self, target:Target):
  # same construction as the parent renderer, with the hipcc flag switched on
  super().__init__(target, use_hipcc=True)
|
||||
|
||||
|
|
@ -567,3 +586,8 @@ class QCOMCLRenderer(OpenCLRenderer):
|
|||
super().__init__(target)
|
||||
from tinygrad.runtime.support.compiler_qcom import QCOMCompiler
|
||||
self.compiler = QCOMCompiler(target.arch)
|
||||
|
||||
# QCOM compiler is flaky with half
|
||||
def supported_dtypes(self):
  """Base `Renderer` dtypes without fp8, bfloat16 or double; float16 is kept only when both IMAGE and FLOAT16 are truthy."""
  excluded = dtypes.fp8s+(dtypes.bfloat16,dtypes.double)
  kept = set()
  # Renderer.supported_dtypes is called directly, so the parent class's own filter is not applied
  for dt in Renderer.supported_dtypes(self):
    if dt in excluded: continue
    if dt != dtypes.float16 or (bool(IMAGE) and bool(FLOAT16)): kept.add(dt)
  return kept
|
||||
|
|
|
|||
|
|
@ -894,3 +894,5 @@ class X86Renderer(ISARenderer):
|
|||
for u in uops:
|
||||
if (t:=jumps.get(u)) is not None: binary[t-4:t] = (targets[u.tag] - t).to_bytes(4, 'little', signed=True)
|
||||
return binary.hex()
|
||||
|
||||
def supported_dtypes(self):
  """Inherited dtypes with every fp8 type and bfloat16 removed."""
  dropped = dtypes.fp8s+(dtypes.bfloat16,)
  return {dt for dt in super().supported_dtypes() if dt not in dropped}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from tinygrad.renderer.cstyle import HIPRenderer, create_non_native_float_pats,
|
|||
from tinygrad.uop.decompositions import xexp2, xlog2
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp, range_str
|
||||
from tinygrad.dtype import dtypes, float_to_fp8, DType, PtrDType, truncate
|
||||
from tinygrad.helpers import prod, Target, CPU_COUNT, getenv
|
||||
from tinygrad.helpers import prod, Target, CPU_COUNT, getenv, OSX
|
||||
|
||||
def ldt(dt:DType):
|
||||
if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
|
||||
|
|
@ -206,6 +206,11 @@ class CPULLVMRenderer(LLVMRenderer):
|
|||
if "AMX" in target.arch: self.tensor_cores = tc.amx
|
||||
self.compiler = CPULLVMCompiler([x for x in target.arch.split(",") if x != "AMX"])
|
||||
|
||||
# FIXME: fp16 works on non-osx, but only if the cpu supports it
|
||||
def supported_dtypes(self):
  """Inherited dtypes without fp8; bfloat16 needs an arch starting with "x86" or "arm", and half is kept only when OSX is truthy."""
  kept = set()
  for dt in super().supported_dtypes():
    if dt in dtypes.fp8s: continue
    bf16_ok = dt != dtypes.bfloat16 or self.target.arch.startswith(("x86", "arm"))
    half_ok = dt != dtypes.half or OSX
    if bf16_ok and half_ok: kept.add(dt)
  return kept
|
||||
|
||||
barrier = 'fence syncscope("workgroup") release\ntail call void @llvm.amdgcn.s.barrier()\nfence syncscope("workgroup") acquire\n'
|
||||
code_for_workitem = {"g": lambda x: f"tail call i32 @llvm.amdgcn.workgroup.id.{chr(120+int(x))}()",
|
||||
"l": lambda x: f"tail call i32 @llvm.amdgcn.workitem.id.{chr(120+int(x))}()"}
|
||||
|
|
@ -291,3 +296,6 @@ exit: %packed = phi i32 [%packed_bf8, %do_bf8], [%packed_fp8, %do_fp8]\n %trunc
|
|||
lambda x: UOp(Ops.WMMA, dtypes.float.vec(8), (x.src[0].bitcast(dtypes.uint16.vec(8)), x.src[1].bitcast(dtypes.uint16.vec(8)),
|
||||
x.src[2]), (*x.arg,)) if x.src[0].dtype == dtypes.bfloat16.vec(8) else None)
|
||||
])
|
||||
|
||||
def supported_dtypes(self):
  """Inherited dtypes without fp8_fnuz; fp8_ocp types are kept only when the target arch is exactly "gfx950"."""
  kept = set()
  for dt in super().supported_dtypes():
    if dt in dtypes.fp8_fnuz: continue
    if dt not in dtypes.fp8_ocp or self.target.arch == "gfx950": kept.add(dt)
  return kept
|
||||
|
|
|
|||
|
|
@ -226,11 +226,15 @@ class NIRRenderer(Renderer):
|
|||
|
||||
return ret
|
||||
|
||||
def supported_dtypes(self):
  """Return `Renderer.supported_dtypes` without any fp8 dtype and without bfloat16."""
  # the base class is called explicitly (not via super()), as in the original
  base = Renderer.supported_dtypes(self)
  excluded = dtypes.fp8s + (dtypes.bfloat16,)
  return {d for d in base if d not in excluded}
|
||||
|
||||
class NAKRenderer(NIRRenderer):
|
||||
param = nir_instr(nc=1, num_components=1, bs=lambda sz:sz*8, also=lambda self,sz: setattr(self, "param_idx", self.param_idx + sz),
|
||||
intrins={"ALIGN_MUL":lambda sz:sz}, srcs=lambda self,b: [nsrc(nimm(b, 0, dtypes.int)), nsrc(nimm(b, self.param_idx, dtypes.int))])(
|
||||
lambda self, b, x, sz: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_ldc_nv))
|
||||
|
||||
def supported_dtypes(self):
  """Return the parent's supported dtypes, keeping half only on sufficiently new targets.

  half is kept when the integer parsed from `self.target.arch[3:]` is at least 53.
  NOTE(review): this presumably expects arch strings shaped like "sm_NN" -- confirm against callers.
  """
  kept = set()
  for d in super().supported_dtypes():
    # the arch string is parsed only when the candidate is half, as in the original short-circuit
    if d == dtypes.half and not int(self.target.arch[3:]) >= 53: continue
    kept.add(d)
  return kept
|
||||
|
||||
class LVPRenderer(NIRRenderer):
|
||||
has_local = False
|
||||
has_shared = False
|
||||
|
|
@ -294,3 +298,5 @@ class IR3Renderer(NIRRenderer, OpenCLRenderer):
|
|||
|
||||
self.b.shader.contents.info.num_ubos = len([u for u in bufs if not isinstance(u.dtype, ImageDType)])
|
||||
self.b.shader.contents.info.num_images = texs() + imgs()
|
||||
|
||||
def supported_dtypes(self):
  """Return `NIRRenderer.supported_dtypes` with double removed."""
  # NIRRenderer is named explicitly (not super()), as in the original
  kept = set()
  for d in NIRRenderer.supported_dtypes(self):
    if d != dtypes.double: kept.add(d)
  return kept
|
||||
|
|
|
|||
|
|
@ -239,3 +239,6 @@ class PTXRenderer(Renderer):
|
|||
|
||||
if u.op is Ops.SPECIAL: kernel = [f".reg .u32 %{u.arg};"] + kernel
|
||||
return self.render_kernel(kernel, name, bufs, c.items(), uops)
|
||||
|
||||
def supported_dtypes(self):
  """Return the parent's supported dtypes minus those this renderer rejects.

  half is kept only when the integer parsed from `self.target.arch[3:]` is at least 53;
  every fp8 dtype and bfloat16 are always dropped.
  NOTE(review): this presumably expects arch strings shaped like "sm_NN" -- confirm against callers.
  """
  kept = set()
  for d in super().supported_dtypes():
    # same check order as the original: the half/arch gate first, then the exclusion list
    if d == dtypes.half and not int(self.target.arch[3:]) >= 53: continue
    if d in dtypes.fp8s + (dtypes.bfloat16,): continue
    kept.add(d)
  return kept
|
||||
|
|
|
|||
|
|
@ -112,3 +112,6 @@ class WGSLRenderer(CStyleLanguage):
|
|||
f"{name}:{f'array<{self.buf_map(dtype.base)}>' if isinstance(dtype,PtrDType) else self.buf_map(dtype)};" for name,(dtype,_) in bufs])
|
||||
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
|
||||
return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"
|
||||
|
||||
def supported_dtypes(self):
  """Return a fixed set of supported dtypes; the parent's set is not consulted."""
  allowed = (
    dtypes.bool,
    dtypes.char, dtypes.uchar,
    dtypes.short, dtypes.ushort,
    dtypes.float,
    dtypes.int32, dtypes.uint32,
    dtypes.half,
  )
  # a fresh set is built on every call, as the original set literal did
  return set(allowed)
|
||||
|
|
|
|||
|
|
@ -68,6 +68,8 @@ class DSPRenderer(ClangRenderer):
|
|||
msrc += ["return 0; }"]
|
||||
return '\n'.join(msrc)
|
||||
|
||||
def supported_dtypes(self):
  """Return the parent's supported dtypes without any fp8 dtype and without bfloat16."""
  base = super().supported_dtypes()
  excluded = dtypes.fp8s + (dtypes.bfloat16,)
  return {d for d in base if d not in excluded}
|
||||
|
||||
def rpc_sc(method=0, ins=0, outs=0, fds=0):
  """Pack four integers into one word by shifting and OR-ing them together.

  Layout, from most to least significant: `method` shifted left by 24, `ins` by 16,
  `outs` by 8, and `fds` unshifted. No field is masked, so a value wider than its
  slot spills into the neighbouring field.
  """
  word = 0
  for value, shift in ((method, 24), (ins, 16), (outs, 8), (fds, 0)):
    word |= value << shift
  return word
|
||||
def rpc_prep_args(ins=None, outs=None, in_fds=None):
|
||||
ins, outs, in_fds = ins or list(), outs or list(), in_fds or list()
|
||||
|
|
|
|||
|
|
@ -229,6 +229,8 @@ class PythonRenderer(Renderer):
|
|||
lops = [(u.op, u.dtype, [uops.index(v) for v in u.src if u.op is not Ops.SPECIAL], u.arg) for u in uops]
|
||||
return base64.b64encode(pickle.dumps(lops)).decode()
|
||||
|
||||
def supported_dtypes(self):
  """Return the parent's supported dtypes, dropping half when running on Python older than 3.12."""
  kept = set()
  for d in super().supported_dtypes():
    # half is the only dtype gated on the interpreter version
    if d == dtypes.half and sys.version_info < (3, 12): continue
    kept.add(d)
  return kept
|
||||
|
||||
class PythonAllocator(Allocator['PythonDevice']):
|
||||
def _alloc(self, size, options): return memoryview(bytearray(size))
|
||||
def _copyin(self, dest, src:memoryview): dest[:] = src
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from typing import Callable
|
||||
import math, functools
|
||||
from tinygrad.dtype import dtypes, DType, promo_lattice, truncate
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import flatten, polyN, Target, EMULATED_DTYPES
|
||||
from tinygrad.helpers import flatten, polyN, EMULATED_DTYPES
|
||||
from tinygrad.uop import GroupOp
|
||||
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, graph_rewrite
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
TRANSCENDENTAL_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64)
|
||||
|
||||
|
|
@ -279,9 +279,10 @@ def magicgu(vmax:int, d:int) -> tuple[int,int]:
|
|||
return m, s
|
||||
assert False
|
||||
|
||||
def fast_idiv(target: Target, x: UOp, d: int, dont_cast=False) -> UOp|None:
|
||||
def fast_idiv(ren: Renderer, x: UOp, d: int, dont_cast=False) -> UOp|None:
|
||||
from tinygrad.renderer.cstyle import MetalRenderer
|
||||
# NOTE: disable for METAL due to compiler bug. keccak with -O0 works but not with optimization
|
||||
if target.device.startswith("METAL"): return None
|
||||
if isinstance(ren, MetalRenderer): return None
|
||||
# If d is a power of two this is not valid for signed ints!
|
||||
is_unsigned = x.vmin>=0 or x.dtype in dtypes.uints
|
||||
assert d>0, "Sign should have been taken out of divisor"
|
||||
|
|
@ -293,11 +294,11 @@ def fast_idiv(target: Target, x: UOp, d: int, dont_cast=False) -> UOp|None:
|
|||
# before we try casting to a larger dtype (slow), we see if there are powers of two in d we can shift to make x smaller
|
||||
# use explicit Ops.CDIV (trunc) since the recursion assumes trunc semantics throughout
|
||||
if (largest_factor_of_two_in_d := (d & -d)) > 1:
|
||||
if (ret:=fast_idiv(target, x.alu(Ops.CDIV, x.const_like(largest_factor_of_two_in_d)),
|
||||
if (ret:=fast_idiv(ren, x.alu(Ops.CDIV, x.const_like(largest_factor_of_two_in_d)),
|
||||
d//largest_factor_of_two_in_d, dont_cast=True)) is not None: return ret
|
||||
if dont_cast: return None
|
||||
# promo_lattice needs to return an unsigned type if the type is unsigned
|
||||
if dtypes.is_int(next_dtype := promo_lattice[x.dtype.scalar()][-1]) and is_dtype_supported(next_dtype, target):
|
||||
if dtypes.is_int(next_dtype := promo_lattice[x.dtype.scalar()][-1]) and next_dtype in ren.supported_dtypes():
|
||||
if m*vmin >= next_dtype.min and m*vmax <= next_dtype.max:
|
||||
return ((x.cast(next_dtype)*m) >> s).cast(x.dtype) if is_unsigned else ((x.cast(next_dtype)*m) >> s).cast(x.dtype) + (x<0).where(x.ufix(1), 0)
|
||||
return None
|
||||
|
|
@ -553,8 +554,8 @@ pm_float_decomp = PatternMatcher([
|
|||
f2f_store(st, idx, val, *ctx) if val.dtype.scalar() == ctx[1] and (idx:=idx.src[0] if idx.op == Ops.CAST else idx).tag == ctx[0] else None),
|
||||
])
|
||||
|
||||
def do_dtype_decomps(sink:UOp, ctx:tuple[set[DType], Target]) -> UOp:
|
||||
def _should_emulate(dt): return dt in EMULATED_DTYPES.tolist(dtypes) or not is_dtype_supported(dt, ctx[1])
|
||||
def do_dtype_decomps(sink:UOp, ctx:tuple[set[DType], Renderer]) -> UOp:
|
||||
def _should_emulate(dt): return dt in EMULATED_DTYPES.tolist(dtypes) or dt not in ctx[1].supported_dtypes()
|
||||
for fr in sorted(filter(_should_emulate, ctx[0])):
|
||||
if fr in dtypes.floats:
|
||||
to = dtypes.half if not _should_emulate(dtypes.half) and fr in dtypes.fp8s else dtypes.float
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue