mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
516b00e286
commit
072db9924c
11 changed files with 48 additions and 63 deletions
|
|
@ -66,7 +66,6 @@ Elementwise ops operate on a per element basis. They don't change the shape of t
|
|||
::: tinygrad.Tensor.sub
|
||||
::: tinygrad.Tensor.mul
|
||||
::: tinygrad.Tensor.div
|
||||
::: tinygrad.Tensor.idiv
|
||||
::: tinygrad.Tensor.mod
|
||||
::: tinygrad.Tensor.fmod
|
||||
::: tinygrad.Tensor.bitwise_xor
|
||||
|
|
|
|||
|
|
@ -123,7 +123,7 @@ def NF4Linear(block_size):
|
|||
def __call__(self, x: Tensor) -> Tensor:
|
||||
high_bits = self.weight
|
||||
low_bits = (self.weight * 2 ** 4).contiguous()
|
||||
unpacked = Tensor.stack(high_bits, low_bits, dim=-1).idiv(2 ** 4)
|
||||
unpacked = Tensor.stack(high_bits, low_bits, dim=-1).div(2 ** 4, rounding_mode="trunc")
|
||||
unscaled = CODE[unpacked].to(x.device).reshape(-1, block_size) * self.scale
|
||||
return x.linear(unscaled.reshape(self.out_features, self.in_features).T)
|
||||
|
||||
|
|
|
|||
|
|
@ -606,10 +606,11 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op(None, lambda x,y: x//y, forward_only=True, vals=[[5, 6, 7],[1, 2, 3]])
|
||||
helper_test_op(None, lambda x: x/2, forward_only=True, vals=[[3, 4, 5]])
|
||||
helper_test_op(None, lambda x: x//2, forward_only=True, vals=[[3, 4, 5]])
|
||||
helper_test_op(None, functools.partial(torch.div, rounding_mode="trunc"), Tensor.idiv, forward_only=True,
|
||||
helper_test_op(None, functools.partial(torch.div, rounding_mode="trunc"),
|
||||
functools.partial(Tensor.div, rounding_mode="trunc"), forward_only=True,
|
||||
vals=[[-4, 7, 5, 4, -7, 8], [2, -3, 8, -2, 3, 5]])
|
||||
if not COMPILE_ONLY:
|
||||
x = Tensor(2**64 - 1, dtype=dtypes.uint64).idiv(1)
|
||||
x = Tensor(2**64 - 1, dtype=dtypes.uint64).div(1, rounding_mode="trunc")
|
||||
np.testing.assert_equal(x.numpy(), 2**64 - 1)
|
||||
|
||||
def test_scalar_div(self):
|
||||
|
|
@ -878,10 +879,10 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op([], lambda: tor >> 31, lambda: ten >> 31, forward_only=True)
|
||||
|
||||
def test_idiv_shift_rewrite_negative(self):
|
||||
a = Tensor(-5).idiv(2).item()
|
||||
b = Tensor(-5).contiguous().idiv(2).item()
|
||||
a = Tensor(-5).div(2, rounding_mode="trunc").item()
|
||||
b = Tensor(-5).contiguous().div(2, rounding_mode="trunc").item()
|
||||
self.assertEqual(a, b)
|
||||
self.assertEqual(Tensor(-1).contiguous().idiv(4).item(), 0) # NOTE this is trunc-div behaviour
|
||||
self.assertEqual(Tensor(-1).contiguous().div(4, rounding_mode="trunc").item(), 0) # NOTE this is trunc-div behaviour
|
||||
|
||||
@unittest.skipIf(DEV.renderer == "NAK", "MUFU.SIN is not accurate enough")
|
||||
def test_sin(self):
|
||||
|
|
|
|||
|
|
@ -957,7 +957,7 @@ class TestSchedule(unittest.TestCase):
|
|||
|
||||
def test_div_padded_arange(self):
|
||||
x = Tensor.full((2,2), 16)
|
||||
y = x.idiv(Tensor.linspace(2, 8, steps=4, dtype=dtypes.int).reshape(2,2)).pad(((1,1), (1,1)))
|
||||
y = x.div(Tensor.linspace(2, 8, steps=4, dtype=dtypes.int).reshape(2,2), rounding_mode="trunc").pad(((1,1), (1,1)))
|
||||
out = y.sum(axis=1)
|
||||
run_linear(*check_schedule(out, 1))
|
||||
self.assertListEqual(out.tolist(), [0, 12, 4, 0])
|
||||
|
|
|
|||
|
|
@ -79,9 +79,9 @@ class TestBinaryOpsConstFolding(unittest.TestCase):
|
|||
def test_div_tensor_one(self):
|
||||
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) / Tensor.ones(4))
|
||||
|
||||
def test_idiv_literal_one(self):
|
||||
def test_floordiv_literal_one(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3, 4]) // 1)
|
||||
def test_idiv_tensor_one(self):
|
||||
def test_floordiv_tensor_one(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3, 4]) // Tensor.ones(4, dtype=dtypes.int32))
|
||||
|
||||
def test_pow_literal_zero(self):
|
||||
|
|
|
|||
|
|
@ -35,13 +35,14 @@ class TestTensorUOpBinop(unittest.TestCase):
|
|||
def test_isclose(self):
|
||||
t = _t(4).float()
|
||||
self.assertIs(_strip_unique(t.isclose(t).uop), _strip_unique(t.uop.isclose(t.uop)))
|
||||
# __floordiv__/mod/idiv/fmod dispatch on dtype in mixin
|
||||
# __floordiv__/mod/fmod and div(rounding_mode=...) dispatch on dtype in mixin
|
||||
def test_floordiv_int(self): _check(self, _t(4), lambda x: x // 3)
|
||||
def test_floordiv_float(self): _check(self, _t(4).float() + 1.5, lambda x: x // 2.0)
|
||||
def test_rfloordiv_int(self): _check(self, _t(4)+1, lambda x: 7 // x)
|
||||
def test_mod_int(self): _check(self, _t(4), lambda x: x % 3)
|
||||
def test_mod_float(self): _check(self, _t(4).float() + 1.5, lambda x: x % 2.0)
|
||||
def test_idiv_int(self): _check(self, _t(4), lambda x: x.idiv(3))
|
||||
def test_div_trunc_int(self): _check(self, _t(4), lambda x: x.div(3, rounding_mode="trunc"))
|
||||
def test_div_trunc_float(self):_check(self, _t(4).float() + 1.5, lambda x: x.div(2.0, rounding_mode="trunc"))
|
||||
def test_fmod_int(self): _check(self, _t(4), lambda x: x.fmod(3))
|
||||
def test_fmod_float(self): _check(self, _t(4).float() + 1.5, lambda x: x.fmod(2.0))
|
||||
def test_floordiv_bool(self): _check(self, _t(4).cast(dtypes.bool), lambda x: x // True)
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ class TestValidateOOB(unittest.TestCase):
|
|||
to_uops_list([buf.index(v.valid(v < 20)).store(0)]) # oob
|
||||
|
||||
# ALU ops in index
|
||||
def test_idiv(self):
|
||||
def test_floordiv(self):
|
||||
with Context(CHECK_OOB=1, SPEC=2):
|
||||
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
|
||||
to_uops_list([buf.index(UOp.range(32, 0, AxisType.GLOBAL) // 2, ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
|
|||
def q_to_uint8(t: Tensor, b: int) -> Tensor:
|
||||
# TODO: rewrite with arange?
|
||||
shift_tensor, bitmask = Tensor.stack(*[ Tensor(2**(i*b), device=t.device, dtype=t.dtype) for i in range(8//b) ]), 0xff >> (8 - b)
|
||||
return t.unsqueeze(-1).expand((*t.shape,8//b)).idiv(shift_tensor).bitwise_and(bitmask).transpose(-1, -2).flatten(-2)
|
||||
return t.unsqueeze(-1).expand((*t.shape,8//b)).div(shift_tensor, rounding_mode="trunc").bitwise_and(bitmask).transpose(-1, -2).flatten(-2)
|
||||
|
||||
# map to (number of elements, number of bytes)
|
||||
if (nelements_nbytes := {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import math, functools, operator
|
||||
from typing import Self
|
||||
from typing import Literal, Self
|
||||
from tinygrad.uop import Ops
|
||||
from tinygrad.dtype import dtypes, ConstType, PyConst, least_upper_dtype, least_upper_float
|
||||
from tinygrad.helpers import argfix, polyN
|
||||
|
|
@ -167,19 +167,6 @@ class ElementwiseMixin(DTypeMixin, CreationMixin):
|
|||
self._check_dtype()
|
||||
return self._binop(Ops.XOR, x, reverse)
|
||||
|
||||
def idiv(self, x: Self | ConstType, reverse: bool = False) -> Self:
|
||||
"""
|
||||
Divides `self` by `x`.
|
||||
Equivalent to `self // x`.
|
||||
Supports broadcasting to a common shape, type promotion, and integer inputs.
|
||||
`idiv` performs integer division (truncate towards zero).
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor([-4, 7, 5, 4, -7, 8]).idiv(Tensor([2, -3, 8, -2, 3, 5])).numpy())
|
||||
```
|
||||
"""
|
||||
return self._binop(Ops.IDIV, x, reverse)
|
||||
|
||||
def mod(self, x: Self | ConstType, reverse: bool = False) -> Self:
|
||||
"""
|
||||
Mod `self` by `x`.
|
||||
|
|
@ -207,9 +194,35 @@ class ElementwiseMixin(DTypeMixin, CreationMixin):
|
|||
if dtypes.is_int(a.dtype): return a.alu(Ops.MOD, b)
|
||||
return a - (a*b.reciprocal()).trunc() * b
|
||||
|
||||
def div(self, x: Self | ConstType, reverse: bool = False) -> Self:
|
||||
def div(self, x: Self | ConstType, reverse: bool = False, rounding_mode: Literal["trunc", "floor"] | None = None) -> Self:
|
||||
"""
|
||||
Divides `self` by `x`.
|
||||
Equivalent to `self / x`.
|
||||
Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
|
||||
`div` performs true division by default; pass `rounding_mode="trunc"` for truncating toward zero
|
||||
or `rounding_mode="floor"` for floor division.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(4)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.div(3).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4])).numpy())
|
||||
```
|
||||
"""
|
||||
lhs, rhs = self._broadcasted(x, reverse)
|
||||
return lhs * rhs.reciprocal()
|
||||
if rounding_mode is None: return lhs * rhs.reciprocal()
|
||||
if dtypes.is_int(lhs.dtype):
|
||||
if rounding_mode == "trunc": return lhs.alu(Ops.IDIV, rhs)
|
||||
if rounding_mode == "floor": return lhs // rhs
|
||||
d = lhs.cast(least_upper_float(lhs.dtype)) * rhs.cast(least_upper_float(rhs.dtype)).reciprocal()
|
||||
if rounding_mode == "trunc": return d.trunc()
|
||||
if rounding_mode == "floor": return d.floor()
|
||||
raise RuntimeError(f"{rounding_mode=} is not supported")
|
||||
|
||||
def __neg__(self) -> Self:
|
||||
return self.neg()
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
from __future__ import annotations
|
||||
import time, math, itertools, functools, struct, sys, inspect, pathlib, hashlib, weakref
|
||||
from contextlib import ContextDecorator
|
||||
from typing import Any, Callable, ClassVar, Sequence, cast, get_args, Literal, ParamSpec, TypeVar, Generic, TYPE_CHECKING
|
||||
from typing import Any, Callable, ClassVar, Sequence, cast, get_args, ParamSpec, TypeVar, Generic, TYPE_CHECKING
|
||||
if TYPE_CHECKING: import numpy
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_float, least_upper_dtype, to_dtype, truncate
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_dtype, to_dtype, truncate
|
||||
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid
|
||||
from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, all_same, fully_flatten, ceildiv, fetch, flat_to_grouped
|
||||
from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile
|
||||
|
|
@ -1284,35 +1284,6 @@ class Tensor(OpMixin):
|
|||
assert isinstance(x, (*get_args(ConstType), UOp)), f"{type(x)=}, {x=}"
|
||||
return Tensor(x, self.device, self.dtype if self._ufix_keep_dtype(x) else None, requires_grad=False)
|
||||
|
||||
def div(self, x:Tensor|ConstType|UOp, reverse=False, rounding_mode:Literal["trunc", "floor"]|None=None) -> Tensor:
|
||||
"""
|
||||
Divides `self` by `x`.
|
||||
Equivalent to `self / x`.
|
||||
Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
|
||||
`div` performs true division.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(4)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.div(3).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4])).numpy())
|
||||
```
|
||||
"""
|
||||
if rounding_mode is None: return super().div(x, reverse) # type: ignore[arg-type]
|
||||
numerator, denominator = self._broadcasted(x, reverse)
|
||||
if dtypes.is_int(numerator.dtype):
|
||||
if rounding_mode == "trunc": return numerator.idiv(denominator)
|
||||
if rounding_mode == "floor": return numerator // denominator
|
||||
d = numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal()
|
||||
if rounding_mode == "trunc": return d.trunc()
|
||||
if rounding_mode == "floor": return d.floor()
|
||||
raise RuntimeError(f"{rounding_mode=} is not supported")
|
||||
|
||||
def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor:
|
||||
"""
|
||||
Returns a tensor of elements selected from either `x` or `y`, depending on `self`.
|
||||
|
|
|
|||
|
|
@ -103,8 +103,8 @@ pm_pyrender_extra = PatternMatcher([
|
|||
# TODO: movement ops simplify stuff, this can break SPEC=2
|
||||
#(UPat(GroupOp.Movement, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({render_marg(ctx,x)})"),
|
||||
# NOTE: CMPNE doesn't work cause there's no __rne__
|
||||
# explicit trunc ops: `//` and `%` parse as FLOORDIV/FLOORMOD, so render IDIV/MOD via their named methods
|
||||
(UPat(Ops.IDIV, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.idiv({ctx[x.src[1]]})"),
|
||||
# explicit trunc ops: `//` and `%` parse as FLOORDIV/FLOORMOD, so render IDIV/MOD via .alu()
|
||||
(UPat(Ops.IDIV, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.alu(Ops.IDIV, {ctx[x.src[1]]})"),
|
||||
(UPat(Ops.MOD, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.alu(Ops.MOD, {ctx[x.src[1]]})"),
|
||||
# NOTE: only match CONSTs without UNIQUE (len(src)==1), unique_const needs explicit rendering
|
||||
(UPat(set(syms.keys())-{Ops.SUB, Ops.CMPNE, Ops.IDIV, Ops.MOD}, src=(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),), name="y"), UPat(name="z")), name="x"),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue