move is_dtype_supported to renderer (#16226)

This commit is contained in:
Christopher Milan 2026-05-20 18:19:37 -07:00 committed by GitHub
commit 172f9493e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
49 changed files with 211 additions and 245 deletions

View file

@ -3,7 +3,7 @@ os.environ["WQKV"] = "1"
import unittest
import numpy as np
from tinygrad import Tensor, nn, dtypes
from tinygrad.device import is_dtype_supported, Device
from tinygrad.device import Device
from examples.mlperf.models.llama import Transformer
from examples.mlperf.models.flat_llama import FlatTransformer
@ -111,7 +111,7 @@ class TestFlatLlama(unittest.TestCase):
self.assertEqual(ref_logits.shape, flat_logits.shape)
np.testing.assert_allclose(flat_logits, ref_logits, atol=1e-4, rtol=1e-4)
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), "fp8 not supported on this device")
@unittest.skipUnless(dtypes.fp8e4m3 in Device[Device.DEFAULT].renderer.supported_dtypes(), "fp8 not supported on this device")
def test_forward_fp8(self):
import examples.mlperf.models.flat_llama as flat_llama_mod
old_fp8 = flat_llama_mod.FP8

View file

@ -1,5 +1,4 @@
from tinygrad import Tensor, dtypes, nn
from tinygrad.device import is_dtype_supported
from tinygrad import Tensor, Device, dtypes, nn
from typing import Optional, Union, List, Any, Tuple, Callable
import math
@ -13,7 +12,7 @@ def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000):
freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp()
args = timesteps.unsqueeze(1) * freqs.unsqueeze(0)
out = Tensor.cat(args.cos(), args.sin(), dim=-1)
return out.cast(mixed_precision_dtype) if is_dtype_supported(mixed_precision_dtype) else out
return out.cast(mixed_precision_dtype) if mixed_precision_dtype in Device[Device.DEFAULT].renderer.supported_dtypes() else out
class ResBlock:
def __init__(self, channels:int, emb_channels:int, out_channels:int, num_groups:int=32):
@ -238,7 +237,7 @@ class UNetModel:
assert y.shape[0] == x.shape[0]
emb = emb + y.sequential(self.label_emb[0])
if is_dtype_supported(mixed_precision_dtype):
if mixed_precision_dtype in Device[Device.DEFAULT].renderer.supported_dtypes():
emb = emb.cast(mixed_precision_dtype)
ctx = ctx.cast(mixed_precision_dtype)
x = x .cast(mixed_precision_dtype)

View file

@ -1,6 +1,5 @@
import unittest
from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import getenv, system, DEV
from extra.gemm.cdna_asm_gemm import asm_gemm
from test.helpers import needs_second_gpu
@ -85,7 +84,7 @@ def verify_asm_gemm_k_sharded_3d(batch:int, M:int, N:int, K:int, dtype=dtypes.fl
# 128x smaller than usual
# uses the UOp GEMM, runs on non CDNA4 and CI
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
@unittest.skipUnless(dtypes.half in Device[Device.DEFAULT].renderer.supported_dtypes(), "need half")
class TestGemm(unittest.TestCase):
def setUp(self):
if is_cdna4(): self.skipTest("shapes are too small for the assembly GEMM")

View file

@ -2,7 +2,6 @@ import unittest, math
from tinygrad import Tensor, Device, dtypes
from tinygrad.dtype import DTYPES_DICT
from tinygrad.uop.ops import Ops, UOp
from tinygrad.device import is_dtype_supported
import numpy as np
from test.helpers import not_support_multi_device
@ -39,12 +38,10 @@ class TestMovedConstFolding(unittest.TestCase):
def test_cast_padded(self):
# NOTE: it's always 1 kernel when calling .numpy, limitation of _check_ast_count
if is_dtype_supported(dtypes.int16):
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
if is_dtype_supported(dtypes.uint16):
_check_ast_count(1, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
_check_ast_count(1, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
# folded
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0])
@ -114,7 +111,7 @@ class TestReduceOpsConstFolding(unittest.TestCase):
def test_sum_output_dtype(self):
# sum output dtype can be different from input
for dt in DTYPES_DICT.values():
if is_dtype_supported(dt):
if dt in Device[Device.DEFAULT].renderer.supported_dtypes():
t = Tensor.ones(16, dtype=dt).reshape(4, 4)
assert t.sum().dtype == t.contiguous().sum().dtype

View file

@ -2,7 +2,6 @@ import contextlib, unittest, math
import numpy as np
import torch
from typing import Any, List
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import getenv, DEBUG, CI, EMULATED_DTYPES
from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8, _to_np_dtype, _to_torch_dtype, truncate
from tinygrad.renderer.ptx import PTXRenderer
@ -17,11 +16,13 @@ pytestmark = pytest.mark.filterwarnings("ignore")
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")
supported_dtypes = Device[Device.DEFAULT].renderer.supported_dtypes()
def get_available_cast_dtypes(dtype: DType) -> List[DType]:
dts = [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) or v in dtypes.fp8s+(dtypes.half,dtypes.bfloat16,dtypes.long)]
if dtype in (dtypes.long, dtypes.ulong) and (not is_dtype_supported(dtype) or dtypes.long in EMULATED_DTYPES.tolist(dtypes)):
dts = [v for k, v in DTYPES_DICT.items() if v != dtype and v in supported_dtypes or v in dtypes.fp8s+(dtypes.half,dtypes.bfloat16,dtypes.long)]
if dtype in (dtypes.long, dtypes.ulong) and (dtype not in supported_dtypes or dtypes.long in EMULATED_DTYPES.tolist(dtypes)):
return [dt for dt in dts if dt != dtypes.double] # can't bitcast with no 64-bit support
if not is_dtype_supported(dtype) and dtype not in dtypes.fp8s+(dtypes.half,dtypes.bfloat16): return []
if dtype not in supported_dtypes and dtype not in dtypes.fp8s+(dtypes.half,dtypes.bfloat16): return []
return dts
def _to_torch_storage_type(dtype:DType):
@ -60,7 +61,7 @@ class TestDType(unittest.TestCase):
@classmethod
def setUpClass(cls):
if cls.DTYPE is None: raise unittest.SkipTest("base class")
cls.DATA = rand_for_dtype(cls.DTYPE, 0x10, allow_subnormal=is_dtype_supported(cls.DTYPE))
cls.DATA = rand_for_dtype(cls.DTYPE, 0x10, allow_subnormal=cls.DTYPE in supported_dtypes)
def test_to_np(self):
_test_to_np(Tensor(self.DATA, dtype=self.DTYPE), _to_np_dtype(self.DTYPE), np.array(self.DATA, dtype=_to_np_dtype(self.DTYPE)))
@ -115,7 +116,7 @@ class TestDType(unittest.TestCase):
for dt in dtypes:
arr = np.asarray(data).astype(dt)
tensor = Tensor(arr)
if not is_dtype_supported(tensor.dtype): continue
if tensor.dtype not in supported_dtypes: continue
tin = tensor.numpy()
tor = torch.as_tensor(arr).detach().numpy()
assert dt == tin.dtype == tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}"
@ -271,7 +272,7 @@ class TestFloatDType(TestDType):
_test_op(lambda: Tensor([-0.9, -0.3, 1.2], dtype=dtypes.float32).cast(dtypes.uint32), dtypes.uint32,
[0, 0, 1])
@unittest.skipUnless(is_dtype_supported(dtypes.double), f"no double on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.double in supported_dtypes, f"no double on {Device.DEFAULT}")
class TestDoubleDType(TestDType):
DTYPE = dtypes.double
@unittest.skipIf((CI and Device.DEFAULT in {"CUDA", "NV"}) or \
@ -478,7 +479,7 @@ class TestImplicitFunctionTypeChange(unittest.TestCase):
class TestTensorMethod(unittest.TestCase):
@given(strat.sampled_from(core_dtypes))
def test_abs_diff(self, dt):
if dt == dtypes.bool or not is_dtype_supported(dt): return
if dt == dtypes.bool or dt not in supported_dtypes: return
a, b = Tensor([2], dtype=dt), Tensor([1], dtype=dt)
ret = (a - b).abs()
np.testing.assert_allclose(ret.numpy(), np.abs(a.numpy()-b.numpy()))
@ -486,11 +487,11 @@ class TestTensorMethod(unittest.TestCase):
class TestDtypeUsage(unittest.TestCase):
def test_max_w_alu(self):
for d in dtypes.ints:
if is_dtype_supported(d):
if d in supported_dtypes:
t = Tensor([[1, 2], [3, 4]], dtype=d)
(t*t).max().item()
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.bfloat16 in supported_dtypes, f"no bfloat16 on {Device.DEFAULT}")
class TestOpsBFloat16(unittest.TestCase):
def test_cast(self):
# TODO: helper_test_op breaks in unrelated part

View file

@ -3,7 +3,6 @@ from tinygrad import Context, Tensor, dtypes, Device
from tinygrad.dtype import DType, truncate, fp8_to_float
from tinygrad.helpers import CI, EMULATED_DTYPES, DEV, getenv
from tinygrad.tensor import _to_np_dtype
from tinygrad.device import is_dtype_supported
from tinygrad.runtime.ops_python import from_storage_scalar
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer
@ -41,6 +40,8 @@ if ((DEV.interface.startswith("MOCK") and Device.DEFAULT in {"NV", "CUDA"})
# transcendental isn't accurate enough
if Ops.SQRT not in Device[Device.DEFAULT].renderer.code_for_op: unary_operations.remove((Tensor.sqrt, np.sqrt))
supported_dtypes = Device[Device.DEFAULT].renderer.supported_dtypes()
class ht:
float64 = strat.floats(width=64, allow_subnormal=False)
float32 = strat.floats(width=32, allow_subnormal=False)
@ -71,7 +72,7 @@ def universal_test(a, b, dtype, op):
numpy_value = truncate[dtype](op[1](ta.numpy(), tb.numpy()).item())
else: tensor_value, numpy_value = (op[0](ta, tb)).numpy(), op[1](ta.numpy(), tb.numpy())
if dtype in dtypes.floats:
if not is_dtype_supported(dtype) or dtype in EMULATED_DTYPES.tolist(dtypes): # denormals are zero
if dtype not in supported_dtypes or dtype in EMULATED_DTYPES.tolist(dtypes): # denormals are zero
fe, fm = dtypes.finfo(dtype)
atol, rtol = 2 ** (2 - (1 << (fe - 1))), 2 ** (-fm)
else: atol, rtol = {dtypes.bfloat16:(1e-3, 1e-2), dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2:(1.0, 5e-1),
@ -87,7 +88,8 @@ def universal_test_unary(a, dtype, op):
if op[0] == Tensor.log and a <= 0: return
if dtype in dtypes.fp8s:
# denormals are zero
if dtype in EMULATED_DTYPES.tolist(dtypes) or not is_dtype_supported(dtype) and abs(ta.numpy().item()) < 0.015625: return
if (dtype in EMULATED_DTYPES.tolist(dtypes) or dtype not in supported_dtypes
and abs(ta.numpy().item()) < 0.015625): return
tensor_value = fp8_to_float(op[0](ta.realize()).bitcast(dtypes.uint8).item(), dtype)
numpy_value = truncate[dtype](v:=op[1](ta.numpy()).item())
# cuda cast f32 inf to f8 MAX, amd cast it to nan(E4M3)/inf(E5M2)
@ -119,14 +121,14 @@ def universal_test_midcast(a, b, c, op1, op2, d1:DType, d2:DType):
np.testing.assert_allclose(tensor_value, numpy_value, rtol=1e-6 if isinstance(Device[Device.DEFAULT].renderer, PTXRenderer) else 1e-7)
class TestDTypeALU(unittest.TestCase):
@unittest.skipUnless(is_dtype_supported(dtypes.float64), f"no float64 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.float64 in supported_dtypes, f"no float64 on {Device.DEFAULT}")
@given(ht.float64, ht.float64, strat.sampled_from(binary_operations))
def test_float64(self, a, b, op): universal_test(a, b, dtypes.float64, op)
@given(ht.float32, ht.float32, strat.sampled_from(binary_operations))
def test_float32(self, a, b, op): universal_test(a, b, dtypes.float32, op)
@unittest.skipUnless(is_dtype_supported(dtypes.float16), f"no float16 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.float16 in supported_dtypes, f"no float16 on {Device.DEFAULT}")
@given(ht.float16, ht.float16, strat.sampled_from(binary_operations))
def test_float16(self, a, b, op): universal_test(a, b, dtypes.float16, op)
@ -134,7 +136,7 @@ class TestDTypeALU(unittest.TestCase):
@Context(EMULATED_DTYPES="half")
def test_emulated_float16(self, a, b, op): universal_test(a, b, dtypes.float16, op)
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.bfloat16 in supported_dtypes, f"no bfloat16 on {Device.DEFAULT}")
@given(ht.bfloat16, ht.bfloat16, strat.sampled_from(binary_operations))
def test_bfloat16(self, a, b, op):
universal_test(from_storage_scalar(a, dtypes.bfloat16), from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
@ -144,7 +146,7 @@ class TestDTypeALU(unittest.TestCase):
def test_emulated_bfloat16(self, a, b, op):
universal_test(from_storage_scalar(a, dtypes.bfloat16), from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), f"no fp8e4m3 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.fp8e4m3 in supported_dtypes, f"no fp8e4m3 on {Device.DEFAULT}")
@given(ht.fp8e4m3, ht.fp8e4m3, strat.sampled_from(binary_operations))
def test_fp8e4m3(self, a, b, op):
universal_test(from_storage_scalar(a, dtypes.fp8e4m3), from_storage_scalar(b, dtypes.fp8e4m3), dtypes.fp8e4m3, op)
@ -154,7 +156,7 @@ class TestDTypeALU(unittest.TestCase):
def test_emulated_fp8e4m3(self, a, b, op):
universal_test(from_storage_scalar(a, dtypes.fp8e4m3), from_storage_scalar(b, dtypes.fp8e4m3), dtypes.fp8e4m3, op)
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), f"no fp8e5m2 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.fp8e5m2 in supported_dtypes, f"no fp8e5m2 on {Device.DEFAULT}")
@given(ht.fp8e5m2, ht.fp8e5m2, strat.sampled_from(binary_operations))
def test_fp8e5m2(self, a, b, op):
universal_test(from_storage_scalar(a, dtypes.fp8e5m2), from_storage_scalar(b, dtypes.fp8e5m2), dtypes.fp8e5m2, op)
@ -164,12 +166,12 @@ class TestDTypeALU(unittest.TestCase):
def test_emulated_fp8e5m2(self, a, b, op):
universal_test(from_storage_scalar(a, dtypes.fp8e5m2), from_storage_scalar(b, dtypes.fp8e5m2), dtypes.fp8e5m2, op)
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3fnuz), f"no fp8e4m3fnuz on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.fp8e4m3fnuz in supported_dtypes, f"no fp8e4m3fnuz on {Device.DEFAULT}")
@given(ht.fp8e4m3fnuz, ht.fp8e4m3fnuz, strat.sampled_from(binary_operations))
def test_fp8e4m3fnuz(self, a, b, op):
universal_test(from_storage_scalar(a, dtypes.fp8e4m3fnuz), from_storage_scalar(b, dtypes.fp8e4m3fnuz), dtypes.fp8e4m3fnuz, op)
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2fnuz), f"no fp8e5m2fnuz on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.fp8e5m2fnuz in supported_dtypes, f"no fp8e5m2fnuz on {Device.DEFAULT}")
@given(ht.fp8e5m2fnuz, ht.fp8e5m2fnuz, strat.sampled_from(binary_operations))
def test_fp8e5m2fnuz(self, a, b, op):
universal_test(from_storage_scalar(a, dtypes.fp8e5m2fnuz), from_storage_scalar(b, dtypes.fp8e5m2fnuz), dtypes.fp8e5m2fnuz, op)
@ -187,7 +189,7 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.float32, strat.sampled_from(unary_operations))
def test_float32_unary(self, a, op): universal_test_unary(a, dtypes.float32, op)
@unittest.skipUnless(is_dtype_supported(dtypes.float16), f"no float16 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.float16 in supported_dtypes, f"no float16 on {Device.DEFAULT}")
@given(ht.float16, strat.sampled_from(unary_operations))
def test_float16_unary(self, a, op): universal_test_unary(a, dtypes.float16, op)
@ -195,7 +197,7 @@ class TestDTypeALU(unittest.TestCase):
@Context(EMULATED_DTYPES="half")
def test_emulated_float16_unary(self, a, op): universal_test_unary(a, dtypes.float16, op)
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.bfloat16 in supported_dtypes, f"no bfloat16 on {Device.DEFAULT}")
@given(ht.bfloat16, strat.sampled_from(unary_operations))
def test_bfloat16_unary(self, a, op): universal_test_unary(from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
@ -203,7 +205,7 @@ class TestDTypeALU(unittest.TestCase):
@Context(EMULATED_DTYPES="bfloat16")
def test_emulated_bfloat16_unary(self, a, op): universal_test_unary(from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), f"no fp8e4m3 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.fp8e4m3 in supported_dtypes, f"no fp8e4m3 on {Device.DEFAULT}")
@given(ht.fp8e4m3, strat.sampled_from(unary_operations))
def test_fp8e4m3_unary(self, a, op):
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e4m3) != 0.0)
@ -215,7 +217,7 @@ class TestDTypeALU(unittest.TestCase):
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e4m3) != 0.0)
universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e4m3), dtypes.fp8e4m3, op)
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), f"no fp8e5m2 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.fp8e5m2 in supported_dtypes, f"no fp8e5m2 on {Device.DEFAULT}")
@given(ht.fp8e5m2, strat.sampled_from(unary_operations))
def test_fp8e5m2_unary(self, a, op):
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e5m2) != 0.0)
@ -227,13 +229,13 @@ class TestDTypeALU(unittest.TestCase):
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e5m2) != 0.0)
universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e5m2), dtypes.fp8e5m2, op)
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3fnuz), f"no fp8e4m3fnuz on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.fp8e4m3fnuz in supported_dtypes, f"no fp8e4m3fnuz on {Device.DEFAULT}")
@given(ht.fp8e4m3fnuz, strat.sampled_from(unary_operations))
def test_fp8e4m3fnuz_unary(self, a, op):
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e4m3fnuz) != 0.0)
universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e4m3fnuz), dtypes.fp8e4m3fnuz, op)
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2fnuz), f"no fp8e5m2fnuz on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.fp8e5m2fnuz in supported_dtypes, f"no fp8e5m2fnuz on {Device.DEFAULT}")
@given(ht.fp8e5m2fnuz, strat.sampled_from(unary_operations))
def test_fp8e5m2fnuz_unary(self, a, op):
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e5m2fnuz) != 0.0)
@ -254,15 +256,15 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.uint8, ht.uint8, strat.sampled_from(integer_binary_operations))
def test_uint8(self, a, b, op): universal_test(a, b, dtypes.uint8, op)
@unittest.skipUnless(is_dtype_supported(dtypes.uint16), f"no uint16 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.uint16 in supported_dtypes, f"no uint16 on {Device.DEFAULT}")
@given(ht.uint16, ht.uint16, strat.sampled_from(integer_binary_operations))
def test_uint16(self, a, b, op): universal_test(a, b, dtypes.uint16, op)
@unittest.skipUnless(is_dtype_supported(dtypes.uint32), f"no uint32 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.uint32 in supported_dtypes, f"no uint32 on {Device.DEFAULT}")
@given(ht.uint32, ht.uint32, strat.sampled_from(integer_binary_operations))
def test_uint32(self, a, b, op): universal_test(a, b, dtypes.uint32, op)
@unittest.skipUnless(is_dtype_supported(dtypes.uint64), f"no uint64 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.uint64 in supported_dtypes, f"no uint64 on {Device.DEFAULT}")
@given(ht.uint64, ht.uint64, strat.sampled_from(integer_binary_operations))
def test_uint64(self, a, b, op): universal_test(a, b, dtypes.uint64, op)
@ -291,15 +293,15 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.uint8, strat.sampled_from(integer_unary_operations))
def test_uint8_unary(self, a, op): universal_test_unary(a, dtypes.uint8, op)
@unittest.skipUnless(is_dtype_supported(dtypes.uint16), f"no uint16 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.uint16 in supported_dtypes, f"no uint16 on {Device.DEFAULT}")
@given(ht.uint16, strat.sampled_from(integer_unary_operations))
def test_uint16_unary(self, a, op): universal_test_unary(a, dtypes.uint16, op)
@unittest.skipUnless(is_dtype_supported(dtypes.uint32), f"no uint32 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.uint32 in supported_dtypes, f"no uint32 on {Device.DEFAULT}")
@given(ht.uint32, strat.sampled_from(integer_unary_operations))
def test_uint32_unary(self, a, op): universal_test_unary(a, dtypes.uint32, op)
@unittest.skipUnless(is_dtype_supported(dtypes.uint64), f"no uint64 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.uint64 in supported_dtypes, f"no uint64 on {Device.DEFAULT}")
@given(ht.uint64, strat.sampled_from(integer_unary_operations))
def test_uint64_unary(self, a, op): universal_test_unary(a, dtypes.uint64, op)
@ -352,21 +354,21 @@ class TestDTypeALU(unittest.TestCase):
@given(strat.floats(width=32, min_value=1.0, max_value=254.0, allow_subnormal=False),
strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
def test_float_cast_to_unsigned(self, a, float_dtype, unsigned_dtype):
if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32
if float_dtype not in supported_dtypes: float_dtype = dtypes.float32
universal_test_cast(a, float_dtype, unsigned_dtype)
@unittest.skip("relied on hacks")
@given(strat.floats(width=32, min_value=256.0, max_value=65000.0, allow_subnormal=False),
strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
def test_float_cast_to_unsigned_overflow(self, a, float_dtype, unsigned_dtype):
if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32
if float_dtype not in supported_dtypes: float_dtype = dtypes.float32
universal_test_cast(a, float_dtype, unsigned_dtype)
@unittest.skip("relied on hacks")
@given(strat.floats(width=32, min_value=-65000.0, max_value=-1.0, allow_subnormal=False),
strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
def test_float_cast_to_unsigned_underflow(self, a, float_dtype, unsigned_dtype):
if not is_dtype_supported(float_dtype): float_dtype = dtypes.float32
if float_dtype not in supported_dtypes: float_dtype = dtypes.float32
universal_test_cast(a, float_dtype, unsigned_dtype)
@unittest.expectedFailure

View file

@ -3,7 +3,7 @@ import unittest
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType, buffers
from tinygrad.device import Device, Buffer, is_dtype_supported
from tinygrad.device import Device, Buffer
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.realize import run_linear
from tinygrad.codegen import to_program
@ -206,7 +206,7 @@ class TestLinearizer(unittest.TestCase):
def test_sum_acc_dtype(self):
for tensor_dtype, acc_dtype in (
(dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype):
if tensor_dtype in (dts:=Device[Device.DEFAULT].renderer.supported_dtypes()) and acc_dtype in dts:
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
realized_ast = a.schedule_linear().src[-1].src[0]
program = to_program(replace_opts(realized_ast, []), renderer=Device[Device.DEFAULT].renderer)
@ -229,7 +229,7 @@ class TestLinearizer(unittest.TestCase):
(dtypes.float, dtypes.float16, dtypes.float16),
)
for tensor_dtype, acc_dtype, expected_dtype in tests:
if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype) and is_dtype_supported(expected_dtype):
if tensor_dtype in (dts:=Device[Device.DEFAULT].renderer.supported_dtypes()) and acc_dtype in dts and expected_dtype in dts:
a, b = Tensor.rand(8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, dtype=tensor_dtype)
helper_arg_acc_dtype(a.sum(dtype=acc_dtype), expected_dtype)
helper_arg_acc_dtype(a.matmul(b, dtype=acc_dtype), expected_dtype)

View file

@ -1,6 +1,5 @@
import unittest
from tinygrad import Tensor, dtypes, Context
from tinygrad.device import is_dtype_supported
from tinygrad import Tensor, Device, dtypes, Context
from extra.llama_kernels.fused_ce import fused_ce_loss
def run_fused_ce(bs:int, seqlen:int, vocab:int, label_smoothing:float=0.0) -> None:
@ -24,7 +23,7 @@ def run_fused_ce(bs:int, seqlen:int, vocab:int, label_smoothing:float=0.0) -> No
assert loss.allclose(ref, atol=2e-3, rtol=2e-3).item(), "forward mismatch"
assert logits.grad.allclose(logits_ref.grad, atol=2e-3, rtol=2e-3).item(), "grad mismatch"
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "need bfloat16")
@unittest.skipUnless(dtypes.bfloat16 in Device[Device.DEFAULT].renderer.supported_dtypes(), "need bfloat16")
class TestFusedCE(unittest.TestCase):
def test_fused_ce_1_2_16(self): run_fused_ce(1, 2, 16, label_smoothing=0.2)
def test_fused_ce_2_16_128(self): run_fused_ce(2, 16, 128)

View file

@ -1,6 +1,5 @@
import unittest, random
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes, Variable
from tinygrad.device import is_dtype_supported
from tinygrad.uop.ops import Ops, UOp
from tinygrad.helpers import getenv, prod, Context
from tinygrad.nn.state import get_parameters, get_state_dict
@ -911,7 +910,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
@given(strat.sampled_from([dtypes.float, dtypes.int, dtypes.int64, dtypes.int16]))
def test_ops(self, dtype):
if not is_dtype_supported(dtype): return
if dtype not in Device[Device.DEFAULT].renderer.supported_dtypes(): return
t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)
for i in range(4):

View file

@ -5,7 +5,6 @@ import torch
from tinygrad.helpers import getenv, CI, DEBUG, DEV, IMAGE, Context
from tinygrad import Tensor, Device, dtypes
from tinygrad.tensor import _to_np_dtype
from tinygrad.device import is_dtype_supported
from tinygrad.renderer.cstyle import QCOMCLRenderer
from tinygrad.renderer.nir import NIRRenderer
@ -1508,7 +1507,8 @@ class TestOps(unittest.TestCase):
def test_sum_dtype_arg(self):
helper_test_op([(45,3)], lambda x: x.sum(), lambda x: x.sum(dtype=dtypes.float32))
if is_dtype_supported(dtypes.float64): helper_test_op([(45,3)], lambda x: x.sum(dtype=torch.float64), lambda x: x.sum(dtype=dtypes.float64))
if dtypes.float64 in Device[Device.DEFAULT].renderer.supported_dtypes():
helper_test_op([(45,3)], lambda x: x.sum(dtype=torch.float64), lambda x: x.sum(dtype=dtypes.float64))
with self.assertRaises(AttributeError): Tensor([1.0, 2.0]).sum(dtype="")
@ -3379,7 +3379,6 @@ class TestOps(unittest.TestCase):
t = (Tensor([0], dtype='int') | 0xFFFFFFFF).item()
if not COMPILE_ONLY: assert t == -1
@unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}")
class TestOpsUint8(unittest.TestCase):
def test_cast(self):
helper_test_op([(2,3,64,64)], lambda x: x.type(torch.uint8), lambda x: x.cast('uint8'), forward_only=True, low=0, high=255)

View file

@ -3,7 +3,6 @@ import torch
import unittest
from tinygrad import Tensor, Device, dtypes
from tinygrad.nn.optim import Adam, SGD, AdamW, Muon, LAMB
from tinygrad.device import is_dtype_supported
from test.helpers import needs_second_gpu, slow
np.random.seed(1337)
@ -142,7 +141,7 @@ class TestOptim(unittest.TestCase):
np.testing.assert_allclose(losses[0], losses[1], atol=1e-4, rtol=0)
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
@unittest.skipUnless(dtypes.half in Device[Device.DEFAULT].renderer.supported_dtypes(), "need half")
def test_mixed_precision(self):
old_default_float, dtypes.default_float = dtypes.default_float, dtypes.half
# weight update would overflow without upcasting

View file

@ -3,7 +3,6 @@ from functools import partial
from tinygrad import nn, dtypes, Tensor, Device, TinyJit, Variable
from tinygrad.helpers import getenv, CI, OSX
from tinygrad.device import is_dtype_supported
from tinygrad.codegen import to_program
from tinygrad.uop.ops import Ops
@ -86,7 +85,7 @@ class TestRandomness(unittest.TestCase):
self.assertTrue(r1.uop.is_realized, "tensor should be realized after .realize()")
self.assertTrue(r2.uop.is_realized, "tensor should be realized after .realize()")
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support")
@unittest.skipUnless(dtypes.float16 in Device[Device.DEFAULT].renderer.supported_dtypes(), "need float16 support")
def test_rand_float16(self):
N = 128
x = Tensor.rand((2, N, N), dtype=dtypes.float16)
@ -211,7 +210,7 @@ class TestRandomness(unittest.TestCase):
if not (x.src[0] == y.src[0]):
print(f"{x.src[0]} != {y.src[0]}")
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "need bfloat16 support")
@unittest.skipUnless(dtypes.bfloat16 in Device[Device.DEFAULT].renderer.supported_dtypes(), "need bfloat16 support")
def test_rand_bfloat16(self):
N = 128
x = Tensor.rand((2, N, N), dtype=dtypes.bfloat16)
@ -285,7 +284,7 @@ class TestRandomness(unittest.TestCase):
@given(strat.sampled_from([dtypes.float, dtypes.float16, dtypes.bfloat16]))
def test_randn_finite(self, default_float):
if not is_dtype_supported(default_float): return
if default_float not in Device[Device.DEFAULT].renderer.supported_dtypes(): return
old_default_float = dtypes.default_float
# low precision can result in inf from randn
dtypes.default_float = default_float

View file

@ -1,6 +1,6 @@
import unittest
import numpy as np
from tinygrad.device import Device, is_dtype_supported
from tinygrad.device import Device
from tinygrad.dtype import dtypes, ConstType
from tinygrad.engine.realize import run_linear
from tinygrad.codegen import to_program
@ -96,7 +96,7 @@ class TestPTXFailures(unittest.TestCase):
ret = _test_uop_result([], sink, local_size=[4, 1, 1])[0]
np.testing.assert_equal(ret, [0, 1, 1, 1])
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
@unittest.skipUnless(dtypes.half in Device[Device.DEFAULT].renderer.supported_dtypes(), "need half")
def test_gated_define_acc_with_half_dtype(self):
a = Tensor.randn(32, 32, dtype=dtypes.half).realize()
b = Tensor.randn(34, 32, dtype=dtypes.half).realize()

View file

@ -8,12 +8,13 @@ from typing import cast
from hypothesis import assume, given, strategies as strat
from tinygrad import nn, dtypes, Device, Tensor, Variable
from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType
from tinygrad.uop.ops import UOp, Ops, UPat
from tinygrad.helpers import CI, DEBUG, OSX, GlobalCounters, Context, getenv, all_same, temp
from tinygrad.engine.realize import compile_linear, run_linear
supported_dtypes = Device[Device.DEFAULT].renderer.supported_dtypes()
class KernelCountException(Exception): pass
def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True):
if to_prerealize:
@ -105,7 +106,7 @@ class TestSchedule(unittest.TestCase):
run_linear(*check_schedule(a, 1))
self.assertListEqual(a.tolist(), [[15]])
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
@unittest.skipUnless(dtypes.half in supported_dtypes, "need half")
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and OSX, "WEBGPU Metal backend is not accurate enough")
def test_expand_buffer_before_cast(self):
a = Tensor.randn(4, 2, 1).realize().permute((1, 0, 2))
@ -715,7 +716,7 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_allclose(dx.numpy(), [[[[0.,3.,9.],[0,1.,3.],[0.,0.,0.]]]*3]*3)
# TODO like openpilot with imagef
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
@unittest.skipUnless(dtypes.half in supported_dtypes, "need half")
def test_base_change_expand_expand(self):
a = Tensor.ones(4, 4).contiguous().realize()
b = a.cast(dtypes.half).expand(2, 4, 4)
@ -761,9 +762,9 @@ class TestSchedule(unittest.TestCase):
def test_conv2d(self): _test_conv2d(4)
def test_conv2d_fused(self): _test_conv2d(4)
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
@unittest.skipUnless(dtypes.half in supported_dtypes, "need half")
def test_conv2d_half(self): _test_conv2d(4, dtype=dtypes.half)
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
@unittest.skipUnless(dtypes.half in supported_dtypes, "need half")
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Causes other tests to fail")
def test_conv2d_fused_half(self): _test_conv2d(4, dtype=dtypes.half)
@ -870,7 +871,7 @@ class TestSchedule(unittest.TestCase):
@given(strat.sampled_from(dtypes.all), strat.sampled_from(dtypes.all))
@unittest.skip("kernel count depends on input")
def test_cast_padded_const(self, dt1, dt2):
assume(is_dtype_supported(dt1) and is_dtype_supported(dt2))
assume(dt1 in supported_dtypes and dt2 in supported_dtypes)
a = Tensor(1, dtype=dt1).reshape(1, 1).pad(((1, 1), None))
casted_view = a.cast(dt2)
run_linear(*check_schedule(casted_view, 0))
@ -994,7 +995,7 @@ class TestSchedule(unittest.TestCase):
self.assertIs(sched[1].ast.op, Ops.BUFFER_VIEW)
np.testing.assert_equal(a.numpy(), [[4, 5]])
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
@unittest.skipUnless(dtypes.half in supported_dtypes, "need half")
def test_precompute_freqs_cis(self):
from extra.models.llama import precompute_freqs_cis
args = {"dim":32, "end":2048, "theta":10000}

View file

@ -4,7 +4,6 @@ from tinygrad import Tensor, GlobalCounters, Context, Device
from tinygrad.dtype import DTypeLike, dtypes
from tinygrad.engine.realize import run_linear
from tinygrad.helpers import DEBUG, get_single_element
from tinygrad.device import is_dtype_supported
def single_kernel_softmax(x_in:Tensor, axis=-1, dtype:DTypeLike|None=None) -> Tensor:
# only support axis =-1
@ -62,7 +61,7 @@ class TestFuse(unittest.TestCase):
b = Tensor.rand(50,50).realize()
self._test_fuse(lambda a,b: ((a@b).relu()+a).contiguous().softmax(axis=-1), a,b, allow_multiple=True)
@unittest.skipUnless(is_dtype_supported(dtypes.float16), f"no float16 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.float16 in Device[Device.DEFAULT].renderer.supported_dtypes(), f"no float16 on {Device.DEFAULT}")
@unittest.skip("needs RANGEIFY>1")
def test_fuse_softmax_dtype(self):
a = Tensor.rand(50,50).realize()

View file

@ -5,7 +5,6 @@ from tinygrad import Tensor, Device, dtypes, nn
from tinygrad.helpers import getenv, temp, mv_address
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
from hypothesis import given, settings, strategies as strat
from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DTYPES_DICT
from tinygrad.uop.ops import UOp
@ -454,7 +453,7 @@ class TestTinygrad(unittest.TestCase):
np.testing.assert_equal(Tensor(data, dtype=dtypes.float).numpy(), torch.tensor(data, dtype=torch.float).numpy())
def test_tensor_list_special_values(self):
if is_dtype_supported(dtypes.float16):
if dtypes.float16 in Device[Device.DEFAULT].renderer.supported_dtypes():
data = [math.nan, -math.inf, 65504, 65519, 65519.999, 65520, 65520.1]
data = data + [-x for x in data]
with np.errstate(over='ignore'): np.testing.assert_allclose(Tensor(data, dtype=dtypes.float16).numpy(), np.array(data).astype(np.float16))
@ -571,7 +570,7 @@ class TestMoveTensor(unittest.TestCase):
@given(strat.sampled_from([d0, d1]), strat.sampled_from([d0, d1]),
strat.sampled_from([dtypes.float16, dtypes.float32]), strat.sampled_from([True, False, None]))
def test_to_preserves(self, src, dest, dtype, requires_grad):
if not is_dtype_supported(dtype):
if dtype not in Device[Device.DEFAULT].renderer.supported_dtypes():
return
s = Tensor([1, 2, 3], device=src, dtype=dtype, requires_grad=requires_grad)
if requires_grad: s.sum().backward()

View file

@ -4,7 +4,6 @@ from tinygrad.tensor import _to_np_dtype
from tinygrad.helpers import Context, getenv, CI, DEV, OSX
from test.backend.test_schedule import check_schedule
from test.backend.test_dtype_alu import ht, dtypes_float
from tinygrad.device import is_dtype_supported
import numpy as np
import math
from hypothesis import given, settings, strategies as strat
@ -12,8 +11,10 @@ from hypothesis import given, settings, strategies as strat
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")
supported_dtypes = Device[Device.DEFAULT].renderer.supported_dtypes()
class TestTranscendentalMath(unittest.TestCase):
@unittest.skipUnless(is_dtype_supported(dtypes.float64), f"no float64 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.float64 in supported_dtypes, f"no float64 on {Device.DEFAULT}")
@unittest.skipIf(DEV.interface.startswith("MOCK") and Device.DEFAULT in {"NV", "CUDA"}, "crashed")
@given(ht.float64, strat.sampled_from([(Tensor.exp, np.exp), (Tensor.log, np.log), (Tensor.sin, np.sin)]))
def test_float64(self, x, op):
@ -27,7 +28,7 @@ class TestTranscendentalMath(unittest.TestCase):
@unittest.skipIf(DEV.interface.startswith("MOCK") and Device.DEFAULT in {"NV", "CUDA"}, "crashed")
@given(ht.float32, strat.sampled_from([(Tensor.exp, np.exp),(Tensor.log, np.log)] +
([(Tensor.sin, np.sin)] if is_dtype_supported(dtypes.ulong) else [])))
([(Tensor.sin, np.sin)] if dtypes.ulong in supported_dtypes else [])))
def test_float32(self, x, op):
# wrong nan behavior on Vulkan
if (math.isnan(x) or (x < 0 and op[0] == Tensor.log)) and CI and Device.DEFAULT == "WEBGPU" and not OSX: return
@ -36,9 +37,9 @@ class TestTranscendentalMath(unittest.TestCase):
op[1](np.array([x], dtype=_to_np_dtype(dtypes.float32))),
atol=2e-5, rtol=1e-5)
@unittest.skipUnless(is_dtype_supported(dtypes.float16), f"no float16 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.float16 in supported_dtypes, f"no float16 on {Device.DEFAULT}")
@given(ht.float16, strat.sampled_from([(Tensor.exp, np.exp),(Tensor.log, np.log)] +
([(Tensor.sin, np.sin)] if is_dtype_supported(dtypes.ulong) else [])))
([(Tensor.sin, np.sin)] if dtypes.ulong in supported_dtypes else [])))
def test_float16(self, x, op):
# wrong nan behavior on Vulkan
if (math.isnan(x) or (x < 0 and op[0] == Tensor.log)) and CI and Device.DEFAULT == "WEBGPU" and not OSX: return
@ -53,7 +54,7 @@ class TestTranscendentalMath(unittest.TestCase):
def test_exp_near_inf(self, dtype_x):
# reordering compute might return inf
dtype, x = dtype_x
if not is_dtype_supported(dtype): return
if dtype not in supported_dtypes: return
with Context(TRANSCENDENTAL=2):
y = Tensor([x], dtype=dtype).exp().numpy()
expected = np.exp(np.array([x], dtype=_to_np_dtype(dtype)))
@ -61,9 +62,9 @@ class TestTranscendentalMath(unittest.TestCase):
class TestFromFuzzer(unittest.TestCase):
@given(strat.sampled_from(dtypes_float))
@unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
@unittest.skipUnless(dtypes.ulong in supported_dtypes, "Needs ulong")
def test_sin(self, dtype):
if not is_dtype_supported(dtype): return
if dtype not in supported_dtypes: return
if dtype == dtypes.float64:
# crashes in CI CUDA
if DEV.interface.startswith("MOCK") and Device.DEFAULT in {"NV", "CUDA"}: return
@ -85,7 +86,7 @@ class TestFromFuzzer(unittest.TestCase):
@given(strat.sampled_from(dtypes_float))
def test_log2(self, dtype):
if not is_dtype_supported(dtype): return
if dtype not in supported_dtypes: return
if dtype == dtypes.float64:
# crashes in CI CUDA
if DEV.interface.startswith("MOCK") and Device.DEFAULT in {"NV", "CUDA"}: return
@ -104,7 +105,7 @@ class TestFromFuzzer(unittest.TestCase):
class TestFloat16Log2(unittest.TestCase):
"""Tests for native float16 log2 implementation (no float32 cast)"""
@unittest.skipUnless(is_dtype_supported(dtypes.float16), f"no float16 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.float16 in supported_dtypes, f"no float16 on {Device.DEFAULT}")
def test_float16_log2_basic(self):
# basic values
test_values = [1.0, 2.0, 4.0, 0.5, 0.25, 10.0, 100.0, 1000.0]
@ -114,7 +115,7 @@ class TestFloat16Log2(unittest.TestCase):
expected = np.log2(np.float16(val))
np.testing.assert_allclose(result, expected, rtol=1e-3, err_msg=f"log2({val})")
@unittest.skipUnless(is_dtype_supported(dtypes.float16), f"no float16 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.float16 in supported_dtypes, f"no float16 on {Device.DEFAULT}")
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and CI, "Nan handling differs on Vulkan")
def test_float16_log2_special(self):
# special values: inf, -inf, nan, 0, negative
@ -128,7 +129,7 @@ class TestFloat16Log2(unittest.TestCase):
# log2(nan) = nan
assert np.isnan(Tensor([np.nan], dtype=dtypes.float16).log2().numpy()[0])
@unittest.skipUnless(is_dtype_supported(dtypes.float16), f"no float16 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.float16 in supported_dtypes, f"no float16 on {Device.DEFAULT}")
def test_float16_log2_denormal(self):
# test values near and below float16 min normal (6.1e-5)
# these exercise the denormal handling path with 2^10 scaling
@ -141,7 +142,7 @@ class TestFloat16Log2(unittest.TestCase):
np.testing.assert_allclose(result, expected, rtol=5e-2, err_msg=f"log2({val})")
class TestTranscendentalSchedule(unittest.TestCase):
@unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
@unittest.skipUnless(dtypes.ulong in supported_dtypes, "Needs ulong")
def test_transcendental_sin_fusion(self):
with Context(TRANSCENDENTAL=2):
a = Tensor.empty(10)

View file

@ -9,7 +9,6 @@ from tinygrad.uop.ops import Ops, UOp, KernelInfo, AxisType, buffers
from tinygrad.renderer.cstyle import CStyleLanguage
from tinygrad.engine.realize import run_linear
from tinygrad.codegen import to_program
from tinygrad.device import is_dtype_supported
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.renderer.ptx import PTXRenderer
from test.helpers import to_uops_list
@ -141,9 +140,7 @@ class TestNonFloatUOps(TestUOps):
lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], (dtypes.int32, dtypes.int32), no_b_zero=True)
def test_cmplt_int32(self): self._test_bop_fxn(Ops.CMPLT, lambda a,b: int(a)<int(b), (dtypes.int32, dtypes.int32))
def test_cmpne_int32(self): self._test_bop_fxn(Ops.CMPNE, lambda a,b: int(a)!=int(b), (dtypes.int32, dtypes.int32))
@unittest.skipUnless(is_dtype_supported(dtypes.bool), "dtype not supported")
def test_mul_bool(self): self._test_bop_fxn(Ops.MUL, lambda a,b: bool(a) and bool(b), (dtypes.bool, dtypes.bool))
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "dtype not supported")
def test_where_float16(self):
self._test_top_fxn(Ops.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float16, dtypes.float16))

View file

@ -5,7 +5,6 @@ import onnx.backend.test
import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.helpers import getenv, OSX
from tinygrad.device import is_dtype_supported
from tinygrad.nn.onnx import OnnxRunner
# pip3 install tabulate
@ -70,7 +69,7 @@ backend_test.exclude('test_resize_downsample_scales_linear_align_corners_cpu')
backend_test.exclude('test_resize_downsample_scales_cubic_align_corners_cpu')
# about different dtypes
if not is_dtype_supported(dtypes.float64):
if dtypes.float64 not in Device[Device.DEFAULT].renderer.supported_dtypes():
backend_test.exclude('float64')
backend_test.exclude('DOUBLE')
# these have float64 inputs
@ -80,7 +79,7 @@ if not is_dtype_supported(dtypes.float64):
backend_test.exclude('test_einsum_*')
backend_test.exclude('test_cumsum_*')
if not is_dtype_supported(dtypes.float16):
if dtypes.float16 not in Device[Device.DEFAULT].renderer.supported_dtypes():
backend_test.exclude('float16')
backend_test.exclude('FLOAT16')

View file

@ -2,7 +2,6 @@ import unittest, onnx, tempfile, pathlib
import numpy as np
from tinygrad import Tensor
from tinygrad.uop.ops import Ops
from tinygrad.device import is_dtype_supported
from typing import Any
from tinygrad.nn.onnx import OnnxRunner, OnnxPBParser, OnnxDataType
from hypothesis import given, strategies as st
@ -89,7 +88,6 @@ class TestOnnxRunner(unittest.TestCase):
np.testing.assert_equal(output.numpy(), weights + 1)
all_dtypes = list(OnnxDataType)
device_supported_dtypes = {odt for odt in OnnxDataType if is_dtype_supported(odt.to_dtype())}
class TestOnnxRunnerDtypes(unittest.TestCase):
"""

View file

@ -1,7 +1,6 @@
import random
import z3
from tinygrad import dtypes, Device
from tinygrad.helpers import DEV
from tinygrad.uop.validate import uops_to_z3, z3_cdiv
from tinygrad.uop.ops import UOp
from tinygrad.uop.decompositions import fast_idiv
@ -16,7 +15,7 @@ if __name__ == "__main__":
u = UOp.variable('x', random.randint(dt.min, 0), random.randint(1, dt.max), dtype=dt)
d = random.randint(1, max(1, u.arg[2])*2)
if d in powers_of_two: continue
expr = fast_idiv(DEV.target(Device.DEFAULT), u, d)
expr = fast_idiv(Device[Device.DEFAULT].renderer, u, d)
if expr is None: continue
solver = z3.Solver()

View file

@ -6,7 +6,6 @@ import examples.mlperf.metrics as metrics
from tinygrad.helpers import fetch
from test.helpers import slow
from tinygrad import Tensor, Device, dtypes
from tinygrad.device import is_dtype_supported
import numpy as np
# Audio generated with the command on MacOS:
@ -53,7 +52,6 @@ def wer_helper(result: str, reference: str)->float:
return wer
@unittest.skipIf(Device.DEFAULT in ["CPU"], "slow")
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support")
# TODO: WEBGPU GPU dispatch dimensions limit
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU GPU dispatch dimensions limit")
class TestWhisper(unittest.TestCase):

View file

@ -2,7 +2,6 @@ import unittest, io
from contextlib import redirect_stdout
from tinygrad import Tensor, dtypes, Device
from tinygrad.helpers import OSX, DEV
from tinygrad.device import is_dtype_supported
from tinygrad.engine.realize import compile_linear
from tinygrad.codegen import to_program
@ -10,7 +9,6 @@ class TestCompileFailures(unittest.TestCase):
def compile(self, out:Tensor):
compile_linear(out.schedule_linear())
@unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}")
def test_interpolate_atari(self):
self.compile(Tensor.empty(210, 160, dtype='uint8').interpolate((64, 64)))

View file

@ -1,7 +1,6 @@
import unittest, math, struct, operator
from tinygrad.tensor import Tensor, dtypes
from tinygrad.dtype import DTYPES_DICT, truncate, float_to_fp16, float_to_bf16, _to_np_dtype, least_upper_dtype, least_upper_float
from tinygrad.device import is_dtype_supported
from tinygrad import Tensor, Device
from tinygrad.dtype import DTYPES_DICT, dtypes, truncate, float_to_fp16, float_to_bf16, _to_np_dtype, least_upper_dtype, least_upper_float
from tinygrad.helpers import getenv
from hypothesis import given, settings, strategies as strat
@ -12,8 +11,8 @@ settings.register_profile("my_profile", max_examples=50, deadline=None, derandom
settings.load_profile("my_profile")
core_dtypes = list(DTYPES_DICT.values())
dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)]
dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]
dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and dt in Device[Device.DEFAULT].renderer.supported_dtypes()]
dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and dt in Device[Device.DEFAULT].renderer.supported_dtypes()]
FP8E4M3_MAX = 448.0
FP8E5M2_MAX = 57344.0

View file

@ -1,13 +0,0 @@
import unittest
from tinygrad import dtypes, Device
from tinygrad.device import is_dtype_supported
@unittest.skipUnless(Device.DEFAULT=="NULL", "Don't run when testing non-NULL backends")
class TestNULLSupportsDTypes(unittest.TestCase):
def test_null_supports_ints_floats_bool(self):
dts = dtypes.ints + dtypes.floats + (dtypes.bool,)
not_supported = [dt for dt in dts if not is_dtype_supported(dt)]
self.assertFalse(not_supported, msg=f"expected these dtypes to be supported by NULL: {not_supported}")
if __name__ == "__main__":
unittest.main()

View file

@ -1,6 +1,5 @@
import unittest, time, gc
import numpy as np
from tinygrad.device import is_dtype_supported
from tinygrad.nn import optim
from tinygrad.nn.state import get_parameters
from tinygrad.engine.jit import TinyJit
@ -41,6 +40,8 @@ def helper_test(nm, gen, model, max_memory_allowed, max_kernels_allowed, all_jit
if all_jitted:
assert kernels_used > 0 and kernels_used == GlobalCounters.kernel_count or (kernels_used <= GlobalCounters.kernel_count and getattr(Device[Device.DEFAULT], "graph", None)), f"only {kernels_used} out of {GlobalCounters.kernel_count} were jitted" # noqa: E501
supported_dtypes = Device[Device.DEFAULT].renderer.supported_dtypes()
class TestRealWorld(unittest.TestCase):
def setUp(self):
gc.collect()
@ -53,7 +54,7 @@ class TestRealWorld(unittest.TestCase):
dtypes.default_float = self.old_float
@slow
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
@unittest.skipUnless(dtypes.float16 in supported_dtypes, "need dtypes.float16")
def test_stable_diffusion(self):
params = unet_params
params["model_ch"] = 8
@ -78,7 +79,7 @@ class TestRealWorld(unittest.TestCase):
exp_mem = 0.00037 if Device.DEFAULT == "CL" else 0.0002
helper_test("test_unet_resblock", lambda: (Tensor.empty(4, 16, 8, 8), Tensor.empty(1, 24)), test, exp_mem, 37)
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
@unittest.skipUnless(dtypes.float16 in supported_dtypes, "need dtypes.float16")
def test_llama(self):
dtypes.default_float = dtypes.float16
@ -90,7 +91,7 @@ class TestRealWorld(unittest.TestCase):
# TODO: test first token vs rest properly
helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.23, 118, all_jitted=True)
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
@unittest.skipUnless(dtypes.float16 in supported_dtypes, "need dtypes.float16")
def test_gpt2(self):
dtypes.default_float = dtypes.float16
@ -147,7 +148,7 @@ class TestRealWorld(unittest.TestCase):
helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, 0.12, 126)
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
@unittest.skipUnless(dtypes.float16 in supported_dtypes, "need dtypes.float16")
def test_train_cifar_hyp(self):
dtypes.default_float = dtypes.float16
with Tensor.train():

View file

@ -2,7 +2,6 @@
import numpy as np
import unittest
from tinygrad import Tensor, Device, dtypes
from tinygrad.device import is_dtype_supported
from tinygrad.uop.ops import Ops, UOp
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer
@ -86,12 +85,12 @@ class TestIdxUpcast(unittest.TestCase):
def do_op_then_assert(self, dtype: DType, dim1, dim2, dim3):
self._assert(dtype, Tensor.empty(dim1, dim2, 1).expand(-1, -1, dim3).contiguous())
@unittest.skipUnless(is_dtype_supported(dtypes.long), "int64 is supported")
@unittest.skipUnless(dtypes.long in Device[Device.DEFAULT].renderer.supported_dtypes(), "int64 is supported")
def test_overflow(self):
# 2**11, 2**11, 2**11 -> 2**33 will overflow when indexed
self.do_op_then_assert(dtypes.long, 2048, 2048, 2048)
@unittest.skipUnless(is_dtype_supported(dtypes.long), "int64 is supported")
@unittest.skipUnless(dtypes.long in Device[Device.DEFAULT].renderer.supported_dtypes(), "int64 is supported")
def test_overflow_sym(self):
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32))
@ -108,12 +107,12 @@ class TestIdxUpcast(unittest.TestCase):
uops = self._schedule_render(a)
assert all(uop.dtype is not dtypes.long for uop in uops)
@unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
@unittest.skipIf(dtypes.long in Device[Device.DEFAULT].renderer.supported_dtypes(), "int64 is supported")
def test_int64_unsupported_overflow_sym(self):
with self.assertRaises((KeyError, RuntimeError)):
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32))
@unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
@unittest.skipIf(dtypes.long in Device[Device.DEFAULT].renderer.supported_dtypes(), "int64 is supported")
@unittest.expectedFailure # bug in gpu dims limiting
def test_int64_unsupported_overflow(self):
with self.assertRaises((KeyError, RuntimeError)):

View file

@ -5,7 +5,7 @@ from tinygrad import Device, Tensor, dtypes
from tinygrad.tensor import _to_np_dtype
from tinygrad.uop.ops import Ops, UOp, buffers
from tinygrad.dtype import DType
from tinygrad.device import Buffer, is_dtype_supported
from tinygrad.device import Buffer
from tinygrad.helpers import DEV, Context
from test.helpers import slow, replace_opts
from tinygrad.engine.realize import run_linear
@ -69,7 +69,6 @@ class TestTensorCores(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores(self):
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue
# for AMX, tc.dims[2] == 1 so reduceop is None thus tensor_cores are not triggered
helper_tc_allclose(tc.dims[0], tc.dims[1], 2 if AMX else tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0)
@ -78,7 +77,6 @@ class TestTensorCores(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_codegen(self):
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue
n, m, k = tc.dims[0], tc.dims[1], 2 if AMX else tc.dims[2]
a, b = Tensor.rand(m, k, dtype=tc.dtype_in), Tensor.rand(k, n, dtype=tc.dtype_in)
r = a.matmul(b, dtype=tc.dtype_out)
@ -98,7 +96,6 @@ class TestTensorCores(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_padded(self):
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue
helper_tc_allclose(tc.dims[0]+(pad:=1), tc.dims[1]+pad, tc.dims[2]+pad, tc.dtype_in, tc.dtype_out, tc_opt=2)
# AMD compiler bug: AMD miscompiles non-zero padded tc kernels with -O3, producing wrong results, nans or hang (see #9606)
@ -109,7 +106,6 @@ class TestTensorCores(unittest.TestCase):
@unittest.skip("warp elements not duplicated properly across lanes")
def test_tensor_cores_padded_amd(self):
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue
helper_tc_allclose(tc.dims[0]+(pad:=1), tc.dims[1]+pad, tc.dims[2]+pad, tc.dtype_in, tc.dtype_out, tc_opt=2)
@Context(ALLOW_TF32=1)
@ -140,7 +136,6 @@ class TestTensorCores(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_multi_reduce(self):
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue
if tc.dtype_in is dtypes.bfloat16: continue # <-- broken with numpy
# this will be a M=G16, N=G32, M=G16, M=G16, K=R16, K=R16, K=R16 with 9 choices of TC MNK axes
golden_result = None

View file

@ -1,7 +1,6 @@
import unittest
from tinygrad.helpers import CI
from tinygrad import Tensor, Device, dtypes
from tinygrad.device import is_dtype_supported
# similar to test/external/external_test_gpu_ast.py, but universal
@unittest.skipIf(Device.DEFAULT in {"CUDA", "NV"} and CI, "slow on CUDA CI")
@ -20,7 +19,7 @@ class TestSpecific(unittest.TestCase):
w = Tensor.randn(2048, 512)
(x @ w).reshape(1, 128, 4).contiguous().realize()
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support")
@unittest.skipUnless(dtypes.float16 in Device[Device.DEFAULT].renderer.supported_dtypes(), "need float16 support")
def test_big_vec_mul(self):
# from LLaMA
# 0 buffer<4096, dtypes.float> [View((1024, 1, 1, 4), (4, 0, 0, 1), 0, None)]

View file

@ -1,12 +1,10 @@
import unittest
from extra.f16_decompress import u32_to_f16
from tinygrad.tensor import Tensor
from tinygrad.device import is_dtype_supported
from tinygrad import dtypes
import numpy as np
class TestF16Decompression(unittest.TestCase):
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16")
def test_u32_to_f16(self):
a = Tensor.randn(50, dtype=dtypes.float16)
f16_as_u32 = a.bitcast(dtypes.uint32)

View file

@ -4,12 +4,11 @@ import numpy as np
from tinygrad import Tensor, dtypes, Device
from tinygrad.nn import Linear
from extra.fp8.fp8_linear import FP8Linear, convert_to_float8_training
from tinygrad.device import is_dtype_supported
from test.helpers import not_support_multi_device, needs_second_gpu
BS, T, in_dim, out_dim = 16, 4, 128, 128
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), f"no fp8e4m3 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.fp8e4m3 in Device[Device.DEFAULT].renderer.supported_dtypes(), f"no fp8e4m3 on {Device.DEFAULT}")
class TestFP8Linear(unittest.TestCase):
def setUp(self):
Tensor.manual_seed(42)

View file

@ -3,7 +3,6 @@ import unittest
import numpy as np
from tinygrad import dtypes, Tensor, TinyJit, GlobalCounters, Variable
from tinygrad.uop.ops import Ops, UOp
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import temp, CI, DEV, Context
N = 200 # has to be bigger than the cache to fail
@ -468,7 +467,6 @@ class TestAssign(unittest.TestCase):
self.assertEqual(GlobalCounters.kernel_count, 2) # currently conservative, forces contiguous
np.testing.assert_allclose(a.numpy(), expected)
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_setitem_half(self):
a = Tensor.full((8,), 1.0, dtype=dtypes.half).contiguous().realize()
b = Tensor.full((4,), 2.0, dtype=dtypes.half).contiguous().realize()

View file

@ -1,7 +1,6 @@
import os, pathlib, tempfile, unittest
import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType, DTYPES_DICT
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
from tinygrad.helpers import Timing, fetch, OSX, dedup
@ -36,7 +35,6 @@ class TestTorchLoad(TempDirTestCase):
# pytorch zip format
def test_load_convnext(self): compare_weights_both('https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth')
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support")
def test_load_llama2bfloat(self): compare_weights_both("https://huggingface.co/qazalin/bf16-lightweight/resolve/main/consolidated.00.pth?download=true")
# pytorch tar format
@ -94,7 +92,6 @@ class TestRawDiskBuffer(unittest.TestCase):
pathlib.Path(tmp).unlink()
@unittest.skipUnless(is_dtype_supported(dtypes.uint8), "need uint8")
class TestSafetensors(TempDirTestCase):
def test_real_safetensors(self):
import torch
@ -184,7 +181,6 @@ class TestSafetensors(TempDirTestCase):
def test_save_all_dtypes(self):
for dtype in dedup(DTYPES_DICT.values()):
if dtype in [dtypes.bfloat16]: continue # not supported in numpy
if not is_dtype_supported(dtype): continue
path = self.tmp(f"ones.{dtype}.safetensors")
ones = Tensor(np.random.rand(10,10), dtype=dtype)
safe_save(get_state_dict(ones), path)
@ -384,7 +380,6 @@ class TestDiskTensor(TempDirTestCase):
assert ret.tolist() == [2827, 3341, 3855, 4369]
@unittest.skipIf(OSX or Device.DEFAULT == "CL", "new LLVM has an issue on OSX, DEV=CL gives the wrong output")
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
def test_bf16_disk_write_read(self):
t = Tensor([10000, -1, -1000, -10000, 20], dtype=dtypes.float32)
t.to(f"disk:{self.tmp('dt_bf16_disk_write_read_f32')}").realize()

View file

@ -1,7 +1,7 @@
import unittest, math, subprocess
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes, DType, DTYPES_DICT
from tinygrad.device import Device, is_dtype_supported
from tinygrad.device import Device
from tinygrad.helpers import getenv, DEBUG, EMULATED_DTYPES
from test.helpers import slow
from hypothesis import given, settings, strategies as strat
@ -11,9 +11,10 @@ import torch
settings.register_profile("my_profile", max_examples=50, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")
supported_dtypes = Device[Device.DEFAULT].renderer.supported_dtypes()
core_dtypes = list(DTYPES_DICT.values())
dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)]
dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]
dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and dt in supported_dtypes]
dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and dt in supported_dtypes]
FP8E4M3_MAX = 448.0
FP8E5M2_MAX = 57344.0
@ -25,7 +26,7 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float
try:
assert tensor.dtype == target_dtype
# denormals are zero
if target_dtype in dtypes.floats and (not is_dtype_supported(target_dtype) or target_dtype in EMULATED_DTYPES.tolist(dtypes)):
if target_dtype in dtypes.floats and (target_dtype not in supported_dtypes or target_dtype in EMULATED_DTYPES.tolist(dtypes)):
fe, fm = dtypes.finfo(target_dtype)
kwargs = {"atol":2 ** (2 - (1 << (fe - 1))), "rtol": 2 ** (-fm)}
else: kwargs = {"rtol": {dtypes.float16:1e-3, dtypes.bfloat16:1e-2, dtypes.fp8e4m3:1e-1, dtypes.fp8e5m2:5e-1,
@ -58,7 +59,7 @@ class TestTypeSpec(unittest.TestCase):
subprocess.run(['DEFAULT_FLOAT=TYPO python3 -c "from tinygrad import dtypes"'],
shell=True, check=True)
@unittest.skipUnless(is_dtype_supported(dtypes.int8), f"no int8 on {Device.DEFAULT}")
@unittest.skipUnless(dtypes.int8 in supported_dtypes, f"no int8 on {Device.DEFAULT}")
def test_dtype_str_arg(self):
n = np.random.normal(0, 1, (10, 10)).astype(np.float32)
tested = 0
@ -91,7 +92,7 @@ class TestTypeSpec(unittest.TestCase):
_assert_eq(Tensor.eye(0), dtypes.default_float, np.eye(0))
_assert_eq(Tensor.eye(3), dtypes.default_float, np.eye(3))
_assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3))
if is_dtype_supported(dtypes.float16):
if dtypes.float16 in supported_dtypes:
_assert_eq(Tensor.eye(3, dtype=dtypes.float16), dtypes.float16, np.eye(3))
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
@ -100,12 +101,12 @@ class TestTypeSpec(unittest.TestCase):
_assert_eq(Tensor.zeros((2, 3)), dtypes.default_float, np.zeros((2, 3)))
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.int64), dtypes.int64, np.zeros((2, 3)))
if is_dtype_supported(dtypes.float16):
if dtypes.float16 in supported_dtypes:
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.float16), dtypes.float16, np.zeros((2, 3)))
_assert_eq(Tensor.ones((2, 3)), dtypes.default_float, np.ones((2, 3)))
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.int64), dtypes.int64, np.ones((2, 3)))
if is_dtype_supported(dtypes.float16):
if dtypes.float16 in supported_dtypes:
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.float16), dtypes.float16, np.ones((2, 3)))
_assert_eq(Tensor.full((2, 3), 3.0), dtypes.default_float, np.full((2, 3), 3.0))
@ -113,7 +114,7 @@ class TestTypeSpec(unittest.TestCase):
_assert_eq(Tensor.full((2, 3), True), dtypes.bool, np.full((2, 3), True))
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
if is_dtype_supported(dtypes.float16):
if dtypes.float16 in supported_dtypes:
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
@ -132,10 +133,10 @@ class TestTypeSpec(unittest.TestCase):
_assert_eq(Tensor.arange(5), dtypes.default_int, np.arange(5))
_assert_eq(Tensor.arange(120), dtypes.default_int, np.arange(120))
_assert_eq(Tensor.arange(5.0), dtypes.default_float, np.arange(5))
if is_dtype_supported(dtypes.int16):
if dtypes.int16 in supported_dtypes:
_assert_eq(Tensor.arange(5, dtype=dtypes.int16), dtypes.int16, np.arange(5))
_assert_eq(Tensor.arange(5, dtype=dtypes.int64), dtypes.int64, np.arange(5))
if is_dtype_supported(dtypes.float16):
if dtypes.float16 in supported_dtypes:
_assert_eq(Tensor.arange(5, dtype=dtypes.float16), dtypes.float16, np.arange(5))
_assert_eq(Tensor.arange(3, 9, 0.7), dtypes.default_float, np.arange(3, 9, 0.7), 1e-6 if Device.DEFAULT == "WEBGPU" else 1e-7)
_assert_eq(Tensor.arange(3, 8.5, 3), dtypes.default_float, np.arange(3, 8.5, 3))
@ -149,7 +150,7 @@ class TestAutoCastType(unittest.TestCase):
def tearDown(self):
dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float
@given(strat.sampled_from([d for d in core_dtypes if dtypes.is_int(d) and is_dtype_supported(d)]))
@given(strat.sampled_from([d for d in core_dtypes if dtypes.is_int(d) and d in supported_dtypes]))
def test_int_to_float_unary_func(self, dtype):
for func in [
lambda t: t.exp(),
@ -167,7 +168,7 @@ class TestAutoCastType(unittest.TestCase):
# float16 can have larger precision errors
np.testing.assert_allclose(func(Tensor(a, dtype=dtype)).numpy(), func(torch.tensor(a)), rtol=1e-3, atol=1e-3)
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16")
@unittest.skipUnless(dtypes.float16 in supported_dtypes, "need float16")
def test_sum_dtype_arg(self):
t = Tensor([40000, 40000], dtype=dtypes.float16)
# default float16 sum returns in float16, overflowed in this case
@ -188,10 +189,10 @@ class TestAutoCastType(unittest.TestCase):
old_default_float = dtypes.default_float
for default_dtype in dtypes.floats:
if not is_dtype_supported(default_dtype): continue
if default_dtype not in supported_dtypes: continue
dtypes.default_float = default_dtype
for dtype in dtypes.floats:
if not is_dtype_supported(dtype): continue
for dtype in dtypes.floats:
if dtype not in supported_dtypes: continue
if DEBUG >= 2:
print(f"testing {default_dtype=}, {dtype=}")
a = Tensor([1, 2, 3], dtype=dtype)
@ -205,7 +206,7 @@ class TestAutoCastType(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT == "PYTHON", "very slow")
@slow
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Binding size is larger than the maximum storage buffer binding size")
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
@unittest.skipUnless(dtypes.half in supported_dtypes, "need half")
def test_mean_half_precision_underflow(self):
N = 10000
x = 0.001
@ -213,7 +214,7 @@ class TestAutoCastType(unittest.TestCase):
np.testing.assert_allclose(t.mean(axis=1).numpy(), np.array([x] * N, dtype=np.float16), rtol=1e-3)
@unittest.skip("this test only works with SPLIT_REDUCEOP=1")
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
@unittest.skipUnless(dtypes.half in supported_dtypes, "need half")
def test_mean_half_precision_overflow(self):
N = 256
t = Tensor([60000] * N*N, dtype=dtypes.half).reshape(N, N)
@ -222,7 +223,7 @@ class TestAutoCastType(unittest.TestCase):
np.testing.assert_allclose(t.grad.numpy().flatten(), [60000 * 2 / (N*N)] * N*N)
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Precision error")
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
@unittest.skipUnless(dtypes.half in supported_dtypes, "need half")
def test_softmax_dtype(self):
data = [1, 2, 3]
t = Tensor(data, dtype=dtypes.half)

View file

@ -3,12 +3,12 @@ from tinygrad import dtypes, Tensor, fetch, Device
from tinygrad.helpers import disable_gc
from tinygrad.llm.gguf import _ggml_iq_grid, ggml_data_to_tensor, gguf_load
from tinygrad.runtime.autogen import ggml_common as _ggml
from tinygrad.device import is_dtype_supported
import numpy as np
from gguf import GGUFReader, GGUFValueType, GGMLQuantizationType, GGML_QUANT_SIZES, dequantize, quantize
from gguf.quants import IQ2_S, IQ3_S, IQ3_XXS
ggml_test_block_count = 4
supported_dtypes = Device[Device.DEFAULT].renderer.supported_dtypes()
class TestGGUFTables(unittest.TestCase):
def test_iq2_s_grid_matches_gguf_py(self):
@ -26,7 +26,7 @@ class TestGGUFTables(unittest.TestCase):
grid = _ggml_iq_grid(Device.DEFAULT, _ggml.iq3s_grid, (512, 4)).numpy()
np.testing.assert_equal(grid, IQ3_S.grid.reshape(512, 4))
@unittest.skipIf(any(not is_dtype_supported(t) for t in [ dtypes.uint8, dtypes.half ]), "Backend must support uint8 and half")
@unittest.skipUnless(dtypes.uint8 in supported_dtypes and dtypes.half in supported_dtypes, "Backend must support uint8 and half")
class TestGGUF(unittest.TestCase):
def test_load_tinyllama_q8_0(self): self._test_gguf_load("https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q8_0.gguf?download=true")
def test_load_tinyllama_q4_0(self): self._test_gguf_load("https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf?download=true")
@ -60,7 +60,7 @@ class TestGGUF(unittest.TestCase):
def test_dequantization_iq2_s(self): self._test_dequantization(GGMLQuantizationType.IQ2_S)
def test_dequantization_iq4_xs(self): self._test_dequantization(GGMLQuantizationType.IQ4_XS)
def test_dequantization_mxfp4(self): self._test_dequantization(GGMLQuantizationType.MXFP4)
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "Backend must support bfloat16")
@unittest.skipUnless(dtypes.bfloat16 in supported_dtypes, "Backend must support bfloat16")
def test_dequantization_bf16(self): self._test_dequantization(GGMLQuantizationType.BF16)
def test_dequantization_mxfp4_old(self):
def encode(nibbles, E):
@ -229,7 +229,7 @@ class TestGGUFGEMV(unittest.TestCase):
x = rng.standard_normal(cols).astype(np.float32)
with np.errstate(all='ignore'):
np.testing.assert_allclose((tensors["weight"] @ Tensor(x)).numpy(), ref @ x, atol=1e-2, rtol=1e-2)
if qtype == GGMLQuantizationType.BF16 or is_dtype_supported(dtypes.half): np.testing.assert_equal(tensors["weight"].numpy(), ref)
if qtype == GGMLQuantizationType.BF16 or dtypes.half in supported_dtypes: np.testing.assert_equal(tensors["weight"].numpy(), ref)
assert np.isfinite(ref).all() and np.isfinite(tensors["weight"].numpy()).all(), f"{qtype.name} has NaN/Inf"
def test_gguf_gemv_q8_0(self): self._test_gguf_gemv(GGMLQuantizationType.Q8_0)
@ -243,7 +243,7 @@ class TestGGUFGEMV(unittest.TestCase):
def test_gguf_gemv_iq2_s(self): self._test_gguf_gemv(GGMLQuantizationType.IQ2_S)
def test_gguf_gemv_iq4_xs(self): self._test_gguf_gemv(GGMLQuantizationType.IQ4_XS)
def test_gguf_gemv_mxfp4(self): self._test_gguf_gemv(GGMLQuantizationType.MXFP4)
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "Backend must support bfloat16")
@unittest.skipUnless(dtypes.bfloat16 in supported_dtypes, "Backend must support bfloat16")
def test_gguf_gemv_bf16(self): self._test_gguf_gemv(GGMLQuantizationType.BF16)
class TestGGUFGC(unittest.TestCase):

View file

@ -3,11 +3,12 @@ import hashlib, random, unittest
from tinygrad import Tensor, Device, dtypes
from tinygrad.helpers import DEV
from test.helpers import slow
from tinygrad.device import is_dtype_supported
from tinygrad.uop.ops import UOp
from tinygrad.engine.jit import TinyJit
@unittest.skipUnless(is_dtype_supported(dtypes.uint8) and is_dtype_supported(dtypes.uint64), "Device must support uint8 and uint64")
supported_dtypes = Device[Device.DEFAULT].renderer.supported_dtypes()
@unittest.skipUnless(dtypes.uint8 in supported_dtypes and dtypes.uint64 in supported_dtypes, "Device must support uint8 and uint64")
@unittest.skipIf(DEV.interface.startswith("MOCK") and Device.DEFAULT == "NV", "crashes in NV CI")
class TestHashing(unittest.TestCase):
def _python_hash_1mb(self, data:bytes):
@ -21,7 +22,7 @@ class TestHashing(unittest.TestCase):
out = Tensor(b"abc").hash()
self.assertEqual(bytes(out.data()), expected)
@unittest.skipUnless(is_dtype_supported(dtypes.uint8) and is_dtype_supported(dtypes.uint64), "Device must support uint8 and uint64")
@unittest.skipUnless(dtypes.uint8 in supported_dtypes and dtypes.uint64 in supported_dtypes, "Device must support uint8 and uint64")
@unittest.skipIf(DEV.interface.startswith("MOCK") and Device.DEFAULT == "NV", "crashes in NV CI")
class TestKeccak(unittest.TestCase):
def setUp(self) -> None: random.seed(1337)

View file

@ -89,8 +89,8 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
supported_ops = tuple(ren.code_for_op.keys())
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))
pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
sink = graph_rewrite(sink, pm_decomp, ctx=ren.target, name="decompositions")
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren.target), name="decomp dtypes")
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="decompositions")
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
sink = graph_rewrite(sink, pm_transcendental, name="transcendental")
# move gates from unrenderable INVALID where
@ -99,7 +99,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# final rules for the renderer (without sym)
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
pm_final_rewrite = pm_decomp+pm_render+extra_matcher+pm_split_ends
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren.target, name="final rewrite")
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite")
# this was the linearizer
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)

View file

@ -2,12 +2,11 @@ from __future__ import annotations
from dataclasses import dataclass, replace
from collections import defaultdict
from typing import Any, Generic, TypeVar, Iterator, Generator, TYPE_CHECKING
import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal
from tinygrad.helpers import BENCHMARKS, CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored
import importlib, inspect, functools, pathlib, os, contextlib, re, atexit, pickle, decimal
from tinygrad.helpers import LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored
from tinygrad.helpers import Context, CCACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, suppress_finalizing
from tinygrad.helpers import select_by_name, select_first_inited, DEV, EMULATED_DTYPES, IMAGE, FLOAT16, TracingKey, size_to_str, Target
from tinygrad.helpers import pluralize
from tinygrad.dtype import DType, PtrDType, dtypes, _to_np_dtype
from tinygrad.helpers import select_by_name, select_first_inited, DEV, TracingKey, size_to_str, pluralize
from tinygrad.dtype import DType, PtrDType, _to_np_dtype
if TYPE_CHECKING: from tinygrad.renderer import Renderer
# **************** Device ****************
@ -336,47 +335,6 @@ class Compiled:
"""
# override this in your device implementation
# TODO: move this to each Device
# this only tracks if the dtype is natively supported, it may be supported in the frontend using decomps
def is_dtype_supported(dtype:DType, target:Target|None=None) -> bool:
target = target or DEV.target(Device.DEFAULT)
if dtype == dtypes.bfloat16:
match target.device:
case "METAL": target.arch.startswith("Apple") and int(target.arch[5:]) >= 6
case "CUDA": return (not CI or BENCHMARKS) and target.renderer != "PTX"
case "NV": return (not CI or BENCHMARKS) and target.renderer not in ("PTX", "NAK")
case "CPU": return platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} and target.renderer not in ("LVP", "X86")
case "AMD" | "CL" | "PYTHON" | "NULL": return True
case _: return False
if dtype in dtypes.fp8_ocp:
match target.device:
case "CUDA": return (not CI or BENCHMARKS) and target.renderer != "PTX"
case "NV": return (not CI or BENCHMARKS) and target.renderer not in ("PTX", "NAK")
case "AMD": return (not CI or BENCHMARKS) and target.arch == "gfx950"
case "PYTHON" | "NULL": return True
case _: return False
if dtype in dtypes.fp8_fnuz: return target.device in {"PYTHON", "NULL"}
if target.device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,
dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half]
# for CI GPU and OSX, cl_khr_fp16 isn't supported
# for CI LLVM, it segfaults because it can't link to the casting function
# CI CUDA architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
# PYTHON supports half memoryview in 3.12+ https://github.com/python/cpython/issues/90751
if dtype == dtypes.half:
match target.device:
case "CL": return "cl_khr_fp16" in target.arch
case "QCOM": return bool(IMAGE) and bool(FLOAT16) # QCOM compiler is flaky with half
case "CUDA" | "NV": return not CI or BENCHMARKS or target.renderer == "PYTHON"
case "CPU" if target.renderer == "LLVM": return OSX
case "PYTHON": return sys.version_info >= (3, 12)
if dtype == dtypes.float64:
match target.device:
case _ if dtypes.long in EMULATED_DTYPES.tolist(dtypes): return False # double can't be bitcast to anything without long support
case "CL": return "cl_khr_fp64" in target.arch
case "NULL": return target.renderer not in ("IR3", "QCOMCL")
case "METAL" | "QCOM": return False
return True
if PROFILE:
@atexit.register
def finalize_profile():

View file

@ -14,7 +14,7 @@ def prod(x:Iterable[T]) -> T|int: return functools.reduce(operator.mul, x, 1)
# NOTE: helpers is not allowed to import from anything else in tinygrad
OSX, WIN = platform.system() == "Darwin", sys.platform == "win32"
CI, BENCHMARKS = os.getenv("CI", "") != "", os.getenv("RUNNER_ENVIRONMENT", "") == "self-hosted"
CI = os.getenv("CI", "") != ""
ARCH_X86 = any(x in platform.processor() for x in ("Intel", "i386", "x86_64"))
BASEDIR = pathlib.Path(__file__).parent

View file

@ -1,9 +1,9 @@
from __future__ import annotations
from typing import Callable, cast
from dataclasses import dataclass
from tinygrad.helpers import prod, Target
from tinygrad.helpers import prod, Target, EMULATED_DTYPES
from tinygrad.uop.ops import Ops, UOp, sint, ssimplify, smin, GroupOp, PatternMatcher
from tinygrad.dtype import AddrSpace, PtrDType
from tinygrad.dtype import AddrSpace, PtrDType, DType, dtypes
from tinygrad.codegen.opt.tc import TensorCore
from tinygrad.device import Compiler
@ -85,3 +85,6 @@ class Renderer:
def render(self, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer")
def asm(self, prg:UOp, lin:UOp) -> bytes: raise NotImplementedError("needs an assembler")
def aux(self, uops:list[UOp]) -> dict: raise NotImplementedError("needs aux")
def supported_dtypes(self) -> set[DType]:
# double can't be bitcast to anything without long support
return set(dtypes.all) - {dtypes.weakint} - ({dtypes.double} if dtypes.long in EMULATED_DTYPES.tolist(dtypes) else set())

View file

@ -3,7 +3,7 @@ import math, sys, struct
from collections import defaultdict, Counter
from tinygrad.codegen.opt import tc
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str, axis_letters
from tinygrad.helpers import strip_parens, getenv, prod, dedup, Target, CPU_COUNT
from tinygrad.helpers import strip_parens, getenv, prod, dedup, Target, CPU_COUNT, IMAGE, FLOAT16
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, truncate, float_to_bf16
from tinygrad.renderer import Renderer
from tinygrad.codegen.late.devectorizer import no_vectorized_alu
@ -272,6 +272,9 @@ class ClangRenderer(CStyleLanguage):
defines = '\n'.join(self._render_defines(uops))
return defines + "\n" + self._render_body(function_name, kernel, bufs, uops, prefix) + "\n" + self._render_entry(function_name, bufs)
def supported_dtypes(self):
return {d for d in super().supported_dtypes() if (d != dtypes.bfloat16 or self.target.arch.startswith(("x86", "arm"))) and d not in dtypes.fp8s}
class ClangJITRenderer(ClangRenderer):
def __init__(self, target:Target):
super().__init__(target)
@ -321,6 +324,10 @@ class OpenCLRenderer(CStyleLanguage):
arg_dtypes[u.arg].append((i, u.dtype))
return tuple(tuple(a) for a in arg_dtypes),
def supported_dtypes(self): return {d for d in super().supported_dtypes()
if (d != dtypes.half or "cl_khr_fp16" in self.target.arch) and
(d != dtypes.double or "cl_khr_fp64" in self.target.arch) and d not in dtypes.fp8s}
class IntelRenderer(OpenCLRenderer):
suffix, kernel_typedef = "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel void"
tensor_cores = tc.intel
@ -382,6 +389,10 @@ class MetalRenderer(CStyleLanguage):
simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);\n return {dstr_out}(mat_c.thread_elements()[0], mat_c.thread_elements()[1]);\n}}""")
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
def supported_dtypes(self):
return {d for d in super().supported_dtypes() if (d != dtypes.bfloat16 or ((arch:=self.target.arch).startswith("Apple") and int(arch[5:]) >= 6))
and d not in dtypes.fp8s+(dtypes.double,)}
_nms = list("xyzwabcdefghijkl") + [f'v{i}' for i in range(16, 32)]
class CUDARenderer(CStyleLanguage):
@ -456,6 +467,11 @@ class CUDARenderer(CStyleLanguage):
return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
def supported_dtypes(self):
ver = int(self.target.arch[3:])
return {d for d in super().supported_dtypes() if (d != dtypes.half or ver >= 53) and (d != dtypes.bfloat16 or ver >= 80)
and (d not in dtypes.fp8_ocp or ver >= 89) and d not in dtypes.fp8_fnuz}
class NVCCRenderer(CUDARenderer):
def __init__(self, target:Target): super().__init__(target, use_nvcc=True)
@ -559,6 +575,9 @@ class HIPRenderer(CStyleLanguage):
for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
def supported_dtypes(self): return {d for d in super().supported_dtypes()
if (d not in dtypes.fp8_ocp or self.target.arch == "gfx950") and d not in dtypes.fp8_fnuz}
class HIPCCRenderer(HIPRenderer):
def __init__(self, target:Target): super().__init__(target, use_hipcc=True)
@ -567,3 +586,8 @@ class QCOMCLRenderer(OpenCLRenderer):
super().__init__(target)
from tinygrad.runtime.support.compiler_qcom import QCOMCompiler
self.compiler = QCOMCompiler(target.arch)
# QCOM compiler is flaky with half
def supported_dtypes(self):
return {d for d in Renderer.supported_dtypes(self)
if (d != dtypes.float16 or (bool(IMAGE) and bool(FLOAT16))) and d not in dtypes.fp8s+(dtypes.bfloat16,dtypes.double)}

View file

@ -894,3 +894,5 @@ class X86Renderer(ISARenderer):
for u in uops:
if (t:=jumps.get(u)) is not None: binary[t-4:t] = (targets[u.tag] - t).to_bytes(4, 'little', signed=True)
return binary.hex()
def supported_dtypes(self): return {d for d in super().supported_dtypes() if d not in dtypes.fp8s+(dtypes.bfloat16,)}

View file

@ -6,7 +6,7 @@ from tinygrad.renderer.cstyle import HIPRenderer, create_non_native_float_pats,
from tinygrad.uop.decompositions import xexp2, xlog2
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp, range_str
from tinygrad.dtype import dtypes, float_to_fp8, DType, PtrDType, truncate
from tinygrad.helpers import prod, Target, CPU_COUNT, getenv
from tinygrad.helpers import prod, Target, CPU_COUNT, getenv, OSX
def ldt(dt:DType):
if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
@ -206,6 +206,11 @@ class CPULLVMRenderer(LLVMRenderer):
if "AMX" in target.arch: self.tensor_cores = tc.amx
self.compiler = CPULLVMCompiler([x for x in target.arch.split(",") if x != "AMX"])
# FIXME: fp16 works on non-osx, but only if the cpu supports it
def supported_dtypes(self):
return {d for d in super().supported_dtypes() if
(d != dtypes.bfloat16 or self.target.arch.startswith(("x86", "arm"))) and (d != dtypes.half or OSX) and d not in dtypes.fp8s}
barrier = 'fence syncscope("workgroup") release\ntail call void @llvm.amdgcn.s.barrier()\nfence syncscope("workgroup") acquire\n'
code_for_workitem = {"g": lambda x: f"tail call i32 @llvm.amdgcn.workgroup.id.{chr(120+int(x))}()",
"l": lambda x: f"tail call i32 @llvm.amdgcn.workitem.id.{chr(120+int(x))}()"}
@ -291,3 +296,6 @@ exit: %packed = phi i32 [%packed_bf8, %do_bf8], [%packed_fp8, %do_fp8]\n %trunc
lambda x: UOp(Ops.WMMA, dtypes.float.vec(8), (x.src[0].bitcast(dtypes.uint16.vec(8)), x.src[1].bitcast(dtypes.uint16.vec(8)),
x.src[2]), (*x.arg,)) if x.src[0].dtype == dtypes.bfloat16.vec(8) else None)
])
def supported_dtypes(self): return {d for d in super().supported_dtypes()
if (d not in dtypes.fp8_ocp or self.target.arch == "gfx950") and d not in dtypes.fp8_fnuz}

View file

@ -226,11 +226,15 @@ class NIRRenderer(Renderer):
return ret
def supported_dtypes(self): return {d for d in Renderer.supported_dtypes(self) if d not in dtypes.fp8s+(dtypes.bfloat16,)}
class NAKRenderer(NIRRenderer):
param = nir_instr(nc=1, num_components=1, bs=lambda sz:sz*8, also=lambda self,sz: setattr(self, "param_idx", self.param_idx + sz),
intrins={"ALIGN_MUL":lambda sz:sz}, srcs=lambda self,b: [nsrc(nimm(b, 0, dtypes.int)), nsrc(nimm(b, self.param_idx, dtypes.int))])(
lambda self, b, x, sz: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_ldc_nv))
def supported_dtypes(self): return {d for d in super().supported_dtypes() if (d != dtypes.half or int(self.target.arch[3:]) >= 53)}
class LVPRenderer(NIRRenderer):
has_local = False
has_shared = False
@ -294,3 +298,5 @@ class IR3Renderer(NIRRenderer, OpenCLRenderer):
self.b.shader.contents.info.num_ubos = len([u for u in bufs if not isinstance(u.dtype, ImageDType)])
self.b.shader.contents.info.num_images = texs() + imgs()
def supported_dtypes(self): return {d for d in NIRRenderer.supported_dtypes(self) if d != dtypes.double}

View file

@ -239,3 +239,6 @@ class PTXRenderer(Renderer):
if u.op is Ops.SPECIAL: kernel = [f".reg .u32 %{u.arg};"] + kernel
return self.render_kernel(kernel, name, bufs, c.items(), uops)
def supported_dtypes(self): return {d for d in super().supported_dtypes()
if (d != dtypes.half or int(self.target.arch[3:]) >= 53) and d not in dtypes.fp8s+(dtypes.bfloat16,)}

View file

@ -112,3 +112,6 @@ class WGSLRenderer(CStyleLanguage):
f"{name}:{f'array<{self.buf_map(dtype.base)}>' if isinstance(dtype,PtrDType) else self.buf_map(dtype)};" for name,(dtype,_) in bufs])
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"
def supported_dtypes(self):
return {dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short, dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half}

View file

@ -68,6 +68,8 @@ class DSPRenderer(ClangRenderer):
msrc += ["return 0; }"]
return '\n'.join(msrc)
def supported_dtypes(self): return {d for d in super().supported_dtypes() if d not in dtypes.fp8s+(dtypes.bfloat16,)}
def rpc_sc(method=0, ins=0, outs=0, fds=0): return (method << 24) | (ins << 16) | (outs << 8) | fds
def rpc_prep_args(ins=None, outs=None, in_fds=None):
ins, outs, in_fds = ins or list(), outs or list(), in_fds or list()

View file

@ -229,6 +229,8 @@ class PythonRenderer(Renderer):
lops = [(u.op, u.dtype, [uops.index(v) for v in u.src if u.op is not Ops.SPECIAL], u.arg) for u in uops]
return base64.b64encode(pickle.dumps(lops)).decode()
def supported_dtypes(self): return {d for d in super().supported_dtypes() if d != dtypes.half or sys.version_info >= (3, 12)}
class PythonAllocator(Allocator['PythonDevice']):
def _alloc(self, size, options): return memoryview(bytearray(size))
def _copyin(self, dest, src:memoryview): dest[:] = src

View file

@ -1,10 +1,10 @@
from typing import Callable
import math, functools
from tinygrad.dtype import dtypes, DType, promo_lattice, truncate
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import flatten, polyN, Target, EMULATED_DTYPES
from tinygrad.helpers import flatten, polyN, EMULATED_DTYPES
from tinygrad.uop import GroupOp
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, graph_rewrite
from tinygrad.renderer import Renderer
TRANSCENDENTAL_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64)
@ -279,9 +279,10 @@ def magicgu(vmax:int, d:int) -> tuple[int,int]:
return m, s
assert False
def fast_idiv(target: Target, x: UOp, d: int, dont_cast=False) -> UOp|None:
def fast_idiv(ren: Renderer, x: UOp, d: int, dont_cast=False) -> UOp|None:
from tinygrad.renderer.cstyle import MetalRenderer
# NOTE: disable for METAL due to compiler bug. keccak with -O0 works but not with optimization
if target.device.startswith("METAL"): return None
if isinstance(ren, MetalRenderer): return None
# If d is a power of two this is not valid for signed ints!
is_unsigned = x.vmin>=0 or x.dtype in dtypes.uints
assert d>0, "Sign should have been taken out of divisor"
@ -293,11 +294,11 @@ def fast_idiv(target: Target, x: UOp, d: int, dont_cast=False) -> UOp|None:
# before we try casting to a larger dtype (slow), we see if there are powers of two in d we can shift to make x smaller
# use explicit Ops.CDIV (trunc) since the recursion assumes trunc semantics throughout
if (largest_factor_of_two_in_d := (d & -d)) > 1:
if (ret:=fast_idiv(target, x.alu(Ops.CDIV, x.const_like(largest_factor_of_two_in_d)),
if (ret:=fast_idiv(ren, x.alu(Ops.CDIV, x.const_like(largest_factor_of_two_in_d)),
d//largest_factor_of_two_in_d, dont_cast=True)) is not None: return ret
if dont_cast: return None
# promo_lattice needs to return an unsigned type if the type is unsigned
if dtypes.is_int(next_dtype := promo_lattice[x.dtype.scalar()][-1]) and is_dtype_supported(next_dtype, target):
if dtypes.is_int(next_dtype := promo_lattice[x.dtype.scalar()][-1]) and next_dtype in ren.supported_dtypes():
if m*vmin >= next_dtype.min and m*vmax <= next_dtype.max:
return ((x.cast(next_dtype)*m) >> s).cast(x.dtype) if is_unsigned else ((x.cast(next_dtype)*m) >> s).cast(x.dtype) + (x<0).where(x.ufix(1), 0)
return None
@ -553,8 +554,8 @@ pm_float_decomp = PatternMatcher([
f2f_store(st, idx, val, *ctx) if val.dtype.scalar() == ctx[1] and (idx:=idx.src[0] if idx.op == Ops.CAST else idx).tag == ctx[0] else None),
])
def do_dtype_decomps(sink:UOp, ctx:tuple[set[DType], Target]) -> UOp:
def _should_emulate(dt): return dt in EMULATED_DTYPES.tolist(dtypes) or not is_dtype_supported(dt, ctx[1])
def do_dtype_decomps(sink:UOp, ctx:tuple[set[DType], Renderer]) -> UOp:
def _should_emulate(dt): return dt in EMULATED_DTYPES.tolist(dtypes) or dt not in ctx[1].supported_dtypes()
for fr in sorted(filter(_should_emulate, ctx[0])):
if fr in dtypes.floats:
to = dtypes.half if not _should_emulate(dtypes.half) and fr in dtypes.fp8s else dtypes.float