mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
42b6bf0b7a
commit
9759fd6193
6 changed files with 112 additions and 95 deletions
|
|
@ -203,7 +203,7 @@ class TestPatternMatcher(unittest.TestCase):
|
|||
|
||||
def _assert_eq_upat(self, a:UPat, b:UPat):
|
||||
assert (sorted(map(str,a.op)) if a.op else [] == (sorted(map(str,b.op)) if b.op else []))
|
||||
assert (sorted(a.dtype) if a.dtype else [] == (sorted(b.dtype) if b.dtype else []))
|
||||
assert (sorted(a.match_dtype) if a.match_dtype else [] == (sorted(b.match_dtype) if b.match_dtype else []))
|
||||
assert (a.name, type(a.src)) == (b.name, type(b.src))
|
||||
def simple_src(u:UPat):
|
||||
if u.src is None: return []
|
||||
|
|
|
|||
81
tinygrad/mixin/dtype.py
Normal file
81
tinygrad/mixin/dtype.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
from typing import Self
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
|
||||
class DTypeMixin:
|
||||
@property
|
||||
def dtype(self) -> DType: raise NotImplementedError
|
||||
|
||||
def cast(self, dtype:DType) -> Self: raise NotImplementedError
|
||||
|
||||
def element_size(self) -> int:
|
||||
"""Returns the number of bytes of a single element in the tensor."""
|
||||
return self.dtype.itemsize
|
||||
|
||||
def is_floating_point(self) -> bool:
|
||||
"""Returns `True` if the tensor contains floating point types, i.e. is one of `bool`, `float16`, `bfloat16`, `float32`, `float64`."""
|
||||
return dtypes.is_float(self.dtype)
|
||||
|
||||
def float(self) -> Self:
|
||||
"""
|
||||
Convenience method to cast `self` to a `float32` Tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([-1, 2, 3], dtype=dtypes.int32)
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = t.float()
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
"""
|
||||
return self.cast(dtypes.float32)
|
||||
|
||||
def half(self) -> Self:
|
||||
"""
|
||||
Convenience method to cast `self` to a `float16` Tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([-1, 2, 3], dtype=dtypes.int32)
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = t.half()
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
"""
|
||||
return self.cast(dtypes.float16)
|
||||
|
||||
def int(self) -> Self:
|
||||
"""
|
||||
Convenience method to cast `self` to a `int32` Tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([-1.5, -0.5, 0.0, 0.5, 1.5])
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = t.int()
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
"""
|
||||
return self.cast(dtypes.int32)
|
||||
|
||||
def bool(self) -> Self:
|
||||
"""
|
||||
Convenience method to cast `self` to a `bool` Tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([-1, 0, 1])
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = t.bool()
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
"""
|
||||
return self.cast(dtypes.bool)
|
||||
|
||||
def bfloat16(self) -> Self: return self.cast(dtypes.bfloat16)
|
||||
def double(self) -> Self: return self.cast(dtypes.double)
|
||||
def long(self) -> Self: return self.cast(dtypes.long)
|
||||
def short(self) -> Self: return self.cast(dtypes.short)
|
||||
|
|
@ -2,9 +2,10 @@ import math
|
|||
from typing import Self
|
||||
from tinygrad.uop import Ops
|
||||
from tinygrad.dtype import dtypes, ConstType
|
||||
from tinygrad.mixin.dtype import DTypeMixin
|
||||
|
||||
|
||||
class MathMixin:
|
||||
class MathMixin(DTypeMixin):
|
||||
# required to implement
|
||||
def alu(self, op: Ops, *src: Self) -> Self:
|
||||
raise NotImplementedError
|
||||
|
|
@ -23,16 +24,11 @@ class MathMixin:
|
|||
return self.ne(True)
|
||||
|
||||
def neg(self) -> Self:
|
||||
if (dtype := getattr(self, "dtype")) is None:
|
||||
raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")
|
||||
return self.logical_not() if dtype.scalar() == dtypes.bool else self * (-1)
|
||||
return self.logical_not() if self.dtype.scalar() == dtypes.bool else self * (-1)
|
||||
|
||||
def _check_dtype(self) -> None:
|
||||
if (dtype := getattr(self, "dtype")) is not None:
|
||||
if isinstance(dtype, tuple):
|
||||
dtype = dtype[0]
|
||||
if not (dtypes.is_bool(dtype) or dtypes.is_int(dtype)):
|
||||
raise RuntimeError(f"{dtype} is not supported")
|
||||
if not (dtypes.is_bool(self.dtype) or dtypes.is_int(self.dtype)):
|
||||
raise RuntimeError(f"{self.dtype} is not supported")
|
||||
|
||||
def add(self, x: Self | ConstType, reverse: bool = False) -> Self:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -3843,71 +3843,6 @@ class Tensor(OpMixin):
|
|||
return Tensor.stack(*(tmp>>8*i*ns for i in range(os//ns)), dim=-1).flatten(-2).cast(new_uint).bitcast(dtype)
|
||||
return self._apply_uop(UOp.bitcast, dtype=dt) if self.dtype != dt else self
|
||||
|
||||
def float(self) -> Tensor:
|
||||
"""
|
||||
Convenience method to cast `self` to a `float32` Tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([-1, 2, 3], dtype=dtypes.int32)
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = t.float()
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
"""
|
||||
return self.cast(dtypes.float32)
|
||||
|
||||
def half(self) -> Tensor:
|
||||
"""
|
||||
Convenience method to cast `self` to a `float16` Tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([-1, 2, 3], dtype=dtypes.int32)
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = t.half()
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
"""
|
||||
return self.cast(dtypes.float16)
|
||||
|
||||
def int(self) -> Tensor:
|
||||
"""
|
||||
Convenience method to cast `self` to a `int32` Tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([-1.5, -0.5, 0.0, 0.5, 1.5])
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = t.int()
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
"""
|
||||
return self.cast(dtypes.int32)
|
||||
|
||||
def bool(self) -> Tensor:
|
||||
"""
|
||||
Convenience method to cast `self` to a `bool` Tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([-1, 0, 1])
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = t.bool()
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
"""
|
||||
return self.cast(dtypes.bool)
|
||||
|
||||
def bfloat16(self) -> Tensor: return self.cast(dtypes.bfloat16)
|
||||
def double(self) -> Tensor: return self.cast(dtypes.double)
|
||||
def long(self) -> Tensor: return self.cast(dtypes.long)
|
||||
def short(self) -> Tensor: return self.cast(dtypes.short)
|
||||
|
||||
# *** image Tensor function replacements ***
|
||||
|
||||
def image_dot(self, w:Tensor, dtype:DTypeLike|None=None) -> Tensor:
|
||||
|
|
|
|||
|
|
@ -893,13 +893,13 @@ def get_location() -> tuple[str, int]:
|
|||
return frm.f_code.co_filename, frm.f_lineno
|
||||
|
||||
class UPat(OpMixin):
|
||||
__slots__ = ("op", "dtype", "arg", "name", "src", "is_any")
|
||||
__slots__ = ("op", "match_dtype", "arg", "name", "src", "is_any")
|
||||
def __init__(self, op:Ops|tuple[Ops, ...]|set[Ops]|None=None, dtype:DType|tuple[DType, ...]|set[DType]|None=None,
|
||||
src:tuple[UPat, ...]|list[UPat]|UPat|None=None, arg:Any=None,
|
||||
name:str|None=None, allow_any_len:bool=False, custom_early_reject:set[Ops]|None=None, location=None, is_any:bool=False):
|
||||
assert op is None or isinstance(op, (Ops, tuple, set)), "op must be Ops or tuple of Ops"
|
||||
self.op: tuple[Ops, ...]|None = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
|
||||
self.dtype: tuple[DType, ...]|None = (dtype,) if isinstance(dtype, DType) else (tuple(dtype) if isinstance(dtype, set) else dtype)
|
||||
self.match_dtype: tuple[DType, ...]|None = (dtype,) if isinstance(dtype, DType) else (tuple(dtype) if isinstance(dtype, set) else dtype)
|
||||
self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
|
||||
self.src: Any = None
|
||||
self.is_any = is_any
|
||||
|
|
@ -922,9 +922,14 @@ class UPat(OpMixin):
|
|||
upat_match = [src] if isinstance(src, UPat) else ([] if src is None else self.src[0])
|
||||
self.early_reject = {pp.op[0] for pp in upat_match if pp.op is not None and len(pp.op) == 1}
|
||||
|
||||
@property
|
||||
def dtype(self) -> DType: return self.match_dtype[0] if self.match_dtype is not None else dtypes.void
|
||||
|
||||
def _check_dtype(self) -> None: pass
|
||||
|
||||
def __reduce__(self):
|
||||
return UPat, (self.op, self.dtype, self._in_src, self.arg, self.name, not self.strict_length, self.custom_early_reject, self.location)
|
||||
def named(self, name:str): return UPat(self.op, self.dtype, self._in_src, self.arg, name, not self.strict_length, self.custom_early_reject)
|
||||
return UPat, (self.op, self.match_dtype, self._in_src, self.arg, self.name, not self.strict_length, self.custom_early_reject, self.location)
|
||||
def named(self, name:str): return UPat(self.op, self.match_dtype, self._in_src, self.arg, name, not self.strict_length, self.custom_early_reject)
|
||||
|
||||
@staticmethod
|
||||
def any(*src): return UPat(src=src, is_any=True)
|
||||
|
|
@ -948,23 +953,23 @@ class UPat(OpMixin):
|
|||
# copied from UOp
|
||||
def sink(self, *srcs:UPat|None, **kwargs): return UPat(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
||||
def index(self, idx:UPat, valid:UPat|None=None, **kwargs):
|
||||
return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx), **kwargs)
|
||||
return UPat(Ops.INDEX, self.match_dtype, (self,idx,valid) if valid is not None else (self,idx), **kwargs)
|
||||
def cast(self, dtype=None, **kwargs): return UPat(Ops.CAST, dtype, (self,), **kwargs)
|
||||
def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,))
|
||||
def gep(self, i:int|None=None, **kwargs): return UPat(Ops.GEP, None, (self,), (i,) if i is not None else None, **kwargs)
|
||||
def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
|
||||
def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, self.dtype, (self,)+src, **kwargs)
|
||||
def assign(self, x:UPat, **kwargs): return UPat(Ops.ASSIGN, self.dtype, (self,x), **kwargs)
|
||||
def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.dtype, src=(self,)+src, **kwargs)
|
||||
def broadcast(self, **kwargs): return UPat(Ops.VECTORIZE, self.dtype, src=self, **kwargs)
|
||||
def contiguous(self, *args, **kwargs): return UPat(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)
|
||||
def after(self, *src:UPat, **kwargs): return UPat(Ops.AFTER, self.dtype, (self,)+src, **kwargs)
|
||||
def end(self, *src:UPat, **kwargs): return UPat(Ops.END, self.dtype, (self,)+src, **kwargs)
|
||||
def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, self.match_dtype, (self,)+src, **kwargs)
|
||||
def assign(self, x:UPat, **kwargs): return UPat(Ops.ASSIGN, self.match_dtype, (self,x), **kwargs)
|
||||
def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.match_dtype, src=(self,)+src, **kwargs)
|
||||
def broadcast(self, **kwargs): return UPat(Ops.VECTORIZE, self.match_dtype, src=self, **kwargs)
|
||||
def contiguous(self, *args, **kwargs): return UPat(Ops.CONTIGUOUS, dtype=self.match_dtype, src=(self,)+args, **kwargs)
|
||||
def after(self, *src:UPat, **kwargs): return UPat(Ops.AFTER, self.match_dtype, (self,)+src, **kwargs)
|
||||
def end(self, *src:UPat, **kwargs): return UPat(Ops.END, self.match_dtype, (self,)+src, **kwargs)
|
||||
|
||||
def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
|
||||
def const_like(self, b:ConstLike): return UPat.const(self.match_dtype, cast(ConstType, b))
|
||||
def alu(self, op:Ops, *src:UPat):
|
||||
asrc = (self,)+src
|
||||
return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
|
||||
return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].match_dtype, list(asrc) if op in GroupOp.Commutative else asrc)
|
||||
|
||||
def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
|
||||
if self.is_any:
|
||||
|
|
@ -972,7 +977,7 @@ class UPat(OpMixin):
|
|||
return flatten([x for x in matches if x is not None])
|
||||
if (self.op is not None and uop.op not in self.op) or \
|
||||
(self.name is not None and store.setdefault(self.name, uop) is not uop) or \
|
||||
(self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
|
||||
(self.match_dtype is not None and uop.dtype not in self.match_dtype and uop.dtype.scalar() not in self.match_dtype) or \
|
||||
(self.arg is not None and self.arg != uop.arg) or \
|
||||
(len(uop.src) < self.required_len) or \
|
||||
(self.strict_length and len(uop.src) != self.required_len): return []
|
||||
|
|
|
|||
|
|
@ -22,10 +22,10 @@ def _get_clause(self:UPat, base:UOp, depth=0) -> UOp:
|
|||
if self.strict_length or self.required_len > 0:
|
||||
and_clause.append(UOp(Ops.CUSTOM, src=(base,), arg=("len({0}.src)"+(" == " if self.strict_length else " >= ")+str(self.required_len))))
|
||||
if self.name is not None: and_clause.append(UOp(Ops.STORE, src=(UOp(Ops.DEFINE_VAR, arg=self.name), base)))
|
||||
if self.dtype is not None:
|
||||
if len(self.dtype) > 1:
|
||||
and_clause.append(UOp(Ops.CUSTOM, src=(base, UOp(Ops.BIND, arg=tuple(self.dtype))), arg="({0}.dtype in {1} or {0}.dtype._scalar in {1})"))
|
||||
else: and_clause.append(UOp(Ops.CUSTOM, src=(base, UOp(Ops.BIND, arg=self.dtype[0])), arg="({0}.dtype == {1} or {0}.dtype._scalar == {1})"))
|
||||
if self.match_dtype is not None:
|
||||
if len(self.match_dtype) > 1:
|
||||
and_clause.append(UOp(Ops.CUSTOM, src=(base, UOp(Ops.BIND, arg=tuple(self.match_dtype))), arg="({0}.dtype in {1} or {0}.dtype._scalar in {1})"))
|
||||
else: and_clause.append(UOp(Ops.CUSTOM, src=(base, UOp(Ops.BIND, arg=self.match_dtype[0])), arg="({0}.dtype == {1} or {0}.dtype._scalar == {1})"))
|
||||
if self.src is not None:
|
||||
# single match
|
||||
if len(self.src) == 1 and isinstance(self.src[0], tuple):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue