dtype mixin (#14763)

* dtype mixin

* dtype mixin methods
This commit is contained in:
George Hotz 2026-02-15 16:03:48 +08:00 committed by GitHub
commit 9759fd6193
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 112 additions and 95 deletions

View file

@ -203,7 +203,7 @@ class TestPatternMatcher(unittest.TestCase):
def _assert_eq_upat(self, a:UPat, b:UPat):
assert (sorted(map(str,a.op)) if a.op else [] == (sorted(map(str,b.op)) if b.op else []))
assert (sorted(a.dtype) if a.dtype else [] == (sorted(b.dtype) if b.dtype else []))
assert (sorted(a.match_dtype) if a.match_dtype else [] == (sorted(b.match_dtype) if b.match_dtype else []))
assert (a.name, type(a.src)) == (b.name, type(b.src))
def simple_src(u:UPat):
if u.src is None: return []

81
tinygrad/mixin/dtype.py Normal file
View file

@ -0,0 +1,81 @@
from typing import Self
from tinygrad.dtype import DType, dtypes
class DTypeMixin:
  """
  Mixin that adds dtype helper methods to any class exposing a `dtype` property and a `cast` method.

  Subclasses must override `dtype` and `cast`; every other method here is implemented on top of those two.
  """
  # required to implement
  @property
  def dtype(self) -> DType:
    """The `DType` of this object. Must be provided by the subclass."""
    raise NotImplementedError
  def cast(self, dtype:DType) -> Self:
    """Returns `self` cast to `dtype`. Must be provided by the subclass."""
    raise NotImplementedError

  def element_size(self) -> int:
    """Returns the number of bytes of a single element in the tensor."""
    return self.dtype.itemsize

  def is_floating_point(self) -> bool:
    """
    Returns `True` if `self.dtype` is a floating point type (e.g. `float16`, `bfloat16`, `float32`, `float64`),
    as decided by `dtypes.is_float`.
    """
    return dtypes.is_float(self.dtype)

  def float(self) -> Self:
    """
    Convenience method to cast `self` to a `float32` Tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor([-1, 2, 3], dtype=dtypes.int32)
    print(t.dtype, t.numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    t = t.float()
    print(t.dtype, t.numpy())
    ```
    """
    return self.cast(dtypes.float32)

  def half(self) -> Self:
    """
    Convenience method to cast `self` to a `float16` Tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor([-1, 2, 3], dtype=dtypes.int32)
    print(t.dtype, t.numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    t = t.half()
    print(t.dtype, t.numpy())
    ```
    """
    return self.cast(dtypes.float16)

  def int(self) -> Self:
    """
    Convenience method to cast `self` to an `int32` Tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor([-1.5, -0.5, 0.0, 0.5, 1.5])
    print(t.dtype, t.numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    t = t.int()
    print(t.dtype, t.numpy())
    ```
    """
    return self.cast(dtypes.int32)

  def bool(self) -> Self:
    """
    Convenience method to cast `self` to a `bool` Tensor.

    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor([-1, 0, 1])
    print(t.dtype, t.numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    t = t.bool()
    print(t.dtype, t.numpy())
    ```
    """
    return self.cast(dtypes.bool)

  def bfloat16(self) -> Self:
    """Convenience method to cast `self` to `dtypes.bfloat16`."""
    return self.cast(dtypes.bfloat16)
  def double(self) -> Self:
    """Convenience method to cast `self` to `dtypes.double`."""
    return self.cast(dtypes.double)
  def long(self) -> Self:
    """Convenience method to cast `self` to `dtypes.long`."""
    return self.cast(dtypes.long)
  def short(self) -> Self:
    """Convenience method to cast `self` to `dtypes.short`."""
    return self.cast(dtypes.short)

View file

@ -2,9 +2,10 @@ import math
from typing import Self
from tinygrad.uop import Ops
from tinygrad.dtype import dtypes, ConstType
from tinygrad.mixin.dtype import DTypeMixin
class MathMixin:
class MathMixin(DTypeMixin):
# required to implement
def alu(self, op: Ops, *src: Self) -> Self:
raise NotImplementedError
@ -23,16 +24,11 @@ class MathMixin:
return self.ne(True)
def neg(self) -> Self:
if (dtype := getattr(self, "dtype")) is None:
raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")
return self.logical_not() if dtype.scalar() == dtypes.bool else self * (-1)
return self.logical_not() if self.dtype.scalar() == dtypes.bool else self * (-1)
def _check_dtype(self) -> None:
if (dtype := getattr(self, "dtype")) is not None:
if isinstance(dtype, tuple):
dtype = dtype[0]
if not (dtypes.is_bool(dtype) or dtypes.is_int(dtype)):
raise RuntimeError(f"{dtype} is not supported")
if not (dtypes.is_bool(self.dtype) or dtypes.is_int(self.dtype)):
raise RuntimeError(f"{self.dtype} is not supported")
def add(self, x: Self | ConstType, reverse: bool = False) -> Self:
"""

View file

@ -3843,71 +3843,6 @@ class Tensor(OpMixin):
return Tensor.stack(*(tmp>>8*i*ns for i in range(os//ns)), dim=-1).flatten(-2).cast(new_uint).bitcast(dtype)
return self._apply_uop(UOp.bitcast, dtype=dt) if self.dtype != dt else self
def float(self) -> Tensor:
"""
Convenience method to cast `self` to a `float32` Tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([-1, 2, 3], dtype=dtypes.int32)
print(t.dtype, t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.float()
print(t.dtype, t.numpy())
```
"""
return self.cast(dtypes.float32)
def half(self) -> Tensor:
"""
Convenience method to cast `self` to a `float16` Tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([-1, 2, 3], dtype=dtypes.int32)
print(t.dtype, t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.half()
print(t.dtype, t.numpy())
```
"""
return self.cast(dtypes.float16)
def int(self) -> Tensor:
"""
Convenience method to cast `self` to a `int32` Tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([-1.5, -0.5, 0.0, 0.5, 1.5])
print(t.dtype, t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.int()
print(t.dtype, t.numpy())
```
"""
return self.cast(dtypes.int32)
def bool(self) -> Tensor:
"""
Convenience method to cast `self` to a `bool` Tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([-1, 0, 1])
print(t.dtype, t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.bool()
print(t.dtype, t.numpy())
```
"""
return self.cast(dtypes.bool)
def bfloat16(self) -> Tensor: return self.cast(dtypes.bfloat16)
def double(self) -> Tensor: return self.cast(dtypes.double)
def long(self) -> Tensor: return self.cast(dtypes.long)
def short(self) -> Tensor: return self.cast(dtypes.short)
# *** image Tensor function replacements ***
def image_dot(self, w:Tensor, dtype:DTypeLike|None=None) -> Tensor:

View file

@ -893,13 +893,13 @@ def get_location() -> tuple[str, int]:
return frm.f_code.co_filename, frm.f_lineno
class UPat(OpMixin):
__slots__ = ("op", "dtype", "arg", "name", "src", "is_any")
__slots__ = ("op", "match_dtype", "arg", "name", "src", "is_any")
def __init__(self, op:Ops|tuple[Ops, ...]|set[Ops]|None=None, dtype:DType|tuple[DType, ...]|set[DType]|None=None,
src:tuple[UPat, ...]|list[UPat]|UPat|None=None, arg:Any=None,
name:str|None=None, allow_any_len:bool=False, custom_early_reject:set[Ops]|None=None, location=None, is_any:bool=False):
assert op is None or isinstance(op, (Ops, tuple, set)), "op must be Ops or tuple of Ops"
self.op: tuple[Ops, ...]|None = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
self.dtype: tuple[DType, ...]|None = (dtype,) if isinstance(dtype, DType) else (tuple(dtype) if isinstance(dtype, set) else dtype)
self.match_dtype: tuple[DType, ...]|None = (dtype,) if isinstance(dtype, DType) else (tuple(dtype) if isinstance(dtype, set) else dtype)
self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
self.src: Any = None
self.is_any = is_any
@ -922,9 +922,14 @@ class UPat(OpMixin):
upat_match = [src] if isinstance(src, UPat) else ([] if src is None else self.src[0])
self.early_reject = {pp.op[0] for pp in upat_match if pp.op is not None and len(pp.op) == 1}
@property
def dtype(self) -> DType: return self.match_dtype[0] if self.match_dtype is not None else dtypes.void
def _check_dtype(self) -> None: pass
def __reduce__(self):
return UPat, (self.op, self.dtype, self._in_src, self.arg, self.name, not self.strict_length, self.custom_early_reject, self.location)
def named(self, name:str): return UPat(self.op, self.dtype, self._in_src, self.arg, name, not self.strict_length, self.custom_early_reject)
return UPat, (self.op, self.match_dtype, self._in_src, self.arg, self.name, not self.strict_length, self.custom_early_reject, self.location)
def named(self, name:str): return UPat(self.op, self.match_dtype, self._in_src, self.arg, name, not self.strict_length, self.custom_early_reject)
@staticmethod
def any(*src): return UPat(src=src, is_any=True)
@ -948,23 +953,23 @@ class UPat(OpMixin):
# copied from UOp
def sink(self, *srcs:UPat|None, **kwargs): return UPat(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def index(self, idx:UPat, valid:UPat|None=None, **kwargs):
return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx), **kwargs)
return UPat(Ops.INDEX, self.match_dtype, (self,idx,valid) if valid is not None else (self,idx), **kwargs)
def cast(self, dtype=None, **kwargs): return UPat(Ops.CAST, dtype, (self,), **kwargs)
def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,))
def gep(self, i:int|None=None, **kwargs): return UPat(Ops.GEP, None, (self,), (i,) if i is not None else None, **kwargs)
def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, self.dtype, (self,)+src, **kwargs)
def assign(self, x:UPat, **kwargs): return UPat(Ops.ASSIGN, self.dtype, (self,x), **kwargs)
def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.dtype, src=(self,)+src, **kwargs)
def broadcast(self, **kwargs): return UPat(Ops.VECTORIZE, self.dtype, src=self, **kwargs)
def contiguous(self, *args, **kwargs): return UPat(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)
def after(self, *src:UPat, **kwargs): return UPat(Ops.AFTER, self.dtype, (self,)+src, **kwargs)
def end(self, *src:UPat, **kwargs): return UPat(Ops.END, self.dtype, (self,)+src, **kwargs)
def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, self.match_dtype, (self,)+src, **kwargs)
def assign(self, x:UPat, **kwargs): return UPat(Ops.ASSIGN, self.match_dtype, (self,x), **kwargs)
def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.match_dtype, src=(self,)+src, **kwargs)
def broadcast(self, **kwargs): return UPat(Ops.VECTORIZE, self.match_dtype, src=self, **kwargs)
def contiguous(self, *args, **kwargs): return UPat(Ops.CONTIGUOUS, dtype=self.match_dtype, src=(self,)+args, **kwargs)
def after(self, *src:UPat, **kwargs): return UPat(Ops.AFTER, self.match_dtype, (self,)+src, **kwargs)
def end(self, *src:UPat, **kwargs): return UPat(Ops.END, self.match_dtype, (self,)+src, **kwargs)
def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
def const_like(self, b:ConstLike): return UPat.const(self.match_dtype, cast(ConstType, b))
def alu(self, op:Ops, *src:UPat):
asrc = (self,)+src
return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].match_dtype, list(asrc) if op in GroupOp.Commutative else asrc)
def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
if self.is_any:
@ -972,7 +977,7 @@ class UPat(OpMixin):
return flatten([x for x in matches if x is not None])
if (self.op is not None and uop.op not in self.op) or \
(self.name is not None and store.setdefault(self.name, uop) is not uop) or \
(self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
(self.match_dtype is not None and uop.dtype not in self.match_dtype and uop.dtype.scalar() not in self.match_dtype) or \
(self.arg is not None and self.arg != uop.arg) or \
(len(uop.src) < self.required_len) or \
(self.strict_length and len(uop.src) != self.required_len): return []

View file

@ -22,10 +22,10 @@ def _get_clause(self:UPat, base:UOp, depth=0) -> UOp:
if self.strict_length or self.required_len > 0:
and_clause.append(UOp(Ops.CUSTOM, src=(base,), arg=("len({0}.src)"+(" == " if self.strict_length else " >= ")+str(self.required_len))))
if self.name is not None: and_clause.append(UOp(Ops.STORE, src=(UOp(Ops.DEFINE_VAR, arg=self.name), base)))
if self.dtype is not None:
if len(self.dtype) > 1:
and_clause.append(UOp(Ops.CUSTOM, src=(base, UOp(Ops.BIND, arg=tuple(self.dtype))), arg="({0}.dtype in {1} or {0}.dtype._scalar in {1})"))
else: and_clause.append(UOp(Ops.CUSTOM, src=(base, UOp(Ops.BIND, arg=self.dtype[0])), arg="({0}.dtype == {1} or {0}.dtype._scalar == {1})"))
if self.match_dtype is not None:
if len(self.match_dtype) > 1:
and_clause.append(UOp(Ops.CUSTOM, src=(base, UOp(Ops.BIND, arg=tuple(self.match_dtype))), arg="({0}.dtype in {1} or {0}.dtype._scalar in {1})"))
else: and_clause.append(UOp(Ops.CUSTOM, src=(base, UOp(Ops.BIND, arg=self.match_dtype[0])), arg="({0}.dtype == {1} or {0}.dtype._scalar == {1})"))
if self.src is not None:
# single match
if len(self.src) == 1 and isinstance(self.src[0], tuple):