div to mixin (#16078)

also deleted idiv method
This commit is contained in:
chenyu 2026-05-07 12:52:37 -04:00 committed by GitHub
commit 072db9924c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 48 additions and 63 deletions

View file

@ -66,7 +66,6 @@ Elementwise ops operate on a per element basis. They don't change the shape of t
::: tinygrad.Tensor.sub
::: tinygrad.Tensor.mul
::: tinygrad.Tensor.div
::: tinygrad.Tensor.idiv
::: tinygrad.Tensor.mod
::: tinygrad.Tensor.fmod
::: tinygrad.Tensor.bitwise_xor

View file

@ -123,7 +123,7 @@ def NF4Linear(block_size):
def __call__(self, x: Tensor) -> Tensor:
high_bits = self.weight
low_bits = (self.weight * 2 ** 4).contiguous()
unpacked = Tensor.stack(high_bits, low_bits, dim=-1).idiv(2 ** 4)
unpacked = Tensor.stack(high_bits, low_bits, dim=-1).div(2 ** 4, rounding_mode="trunc")
unscaled = CODE[unpacked].to(x.device).reshape(-1, block_size) * self.scale
return x.linear(unscaled.reshape(self.out_features, self.in_features).T)

View file

@ -606,10 +606,11 @@ class TestOps(unittest.TestCase):
helper_test_op(None, lambda x,y: x//y, forward_only=True, vals=[[5, 6, 7],[1, 2, 3]])
helper_test_op(None, lambda x: x/2, forward_only=True, vals=[[3, 4, 5]])
helper_test_op(None, lambda x: x//2, forward_only=True, vals=[[3, 4, 5]])
helper_test_op(None, functools.partial(torch.div, rounding_mode="trunc"), Tensor.idiv, forward_only=True,
helper_test_op(None, functools.partial(torch.div, rounding_mode="trunc"),
functools.partial(Tensor.div, rounding_mode="trunc"), forward_only=True,
vals=[[-4, 7, 5, 4, -7, 8], [2, -3, 8, -2, 3, 5]])
if not COMPILE_ONLY:
x = Tensor(2**64 - 1, dtype=dtypes.uint64).idiv(1)
x = Tensor(2**64 - 1, dtype=dtypes.uint64).div(1, rounding_mode="trunc")
np.testing.assert_equal(x.numpy(), 2**64 - 1)
def test_scalar_div(self):
@ -878,10 +879,10 @@ class TestOps(unittest.TestCase):
helper_test_op([], lambda: tor >> 31, lambda: ten >> 31, forward_only=True)
def test_idiv_shift_rewrite_negative(self):
a = Tensor(-5).idiv(2).item()
b = Tensor(-5).contiguous().idiv(2).item()
a = Tensor(-5).div(2, rounding_mode="trunc").item()
b = Tensor(-5).contiguous().div(2, rounding_mode="trunc").item()
self.assertEqual(a, b)
self.assertEqual(Tensor(-1).contiguous().idiv(4).item(), 0) # NOTE this is trunc-div behaviour
self.assertEqual(Tensor(-1).contiguous().div(4, rounding_mode="trunc").item(), 0) # NOTE this is trunc-div behaviour
@unittest.skipIf(DEV.renderer == "NAK", "MUFU.SIN is not accurate enough")
def test_sin(self):

View file

@ -957,7 +957,7 @@ class TestSchedule(unittest.TestCase):
def test_div_padded_arange(self):
x = Tensor.full((2,2), 16)
y = x.idiv(Tensor.linspace(2, 8, steps=4, dtype=dtypes.int).reshape(2,2)).pad(((1,1), (1,1)))
y = x.div(Tensor.linspace(2, 8, steps=4, dtype=dtypes.int).reshape(2,2), rounding_mode="trunc").pad(((1,1), (1,1)))
out = y.sum(axis=1)
run_linear(*check_schedule(out, 1))
self.assertListEqual(out.tolist(), [0, 12, 4, 0])

View file

@ -79,9 +79,9 @@ class TestBinaryOpsConstFolding(unittest.TestCase):
def test_div_tensor_one(self):
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) / Tensor.ones(4))
def test_idiv_literal_one(self):
def test_floordiv_literal_one(self):
_check_ast_count(0, Tensor([1, 2, 3, 4]) // 1)
def test_idiv_tensor_one(self):
def test_floordiv_tensor_one(self):
_check_ast_count(0, Tensor([1, 2, 3, 4]) // Tensor.ones(4, dtype=dtypes.int32))
def test_pow_literal_zero(self):

View file

@ -35,13 +35,14 @@ class TestTensorUOpBinop(unittest.TestCase):
def test_isclose(self):
t = _t(4).float()
self.assertIs(_strip_unique(t.isclose(t).uop), _strip_unique(t.uop.isclose(t.uop)))
# __floordiv__/mod/idiv/fmod dispatch on dtype in mixin
# __floordiv__/mod/fmod and div(rounding_mode=...) dispatch on dtype in mixin
def test_floordiv_int(self): _check(self, _t(4), lambda x: x // 3)
def test_floordiv_float(self): _check(self, _t(4).float() + 1.5, lambda x: x // 2.0)
def test_rfloordiv_int(self): _check(self, _t(4)+1, lambda x: 7 // x)
def test_mod_int(self): _check(self, _t(4), lambda x: x % 3)
def test_mod_float(self): _check(self, _t(4).float() + 1.5, lambda x: x % 2.0)
def test_idiv_int(self): _check(self, _t(4), lambda x: x.idiv(3))
def test_div_trunc_int(self): _check(self, _t(4), lambda x: x.div(3, rounding_mode="trunc"))
def test_div_trunc_float(self):_check(self, _t(4).float() + 1.5, lambda x: x.div(2.0, rounding_mode="trunc"))
def test_fmod_int(self): _check(self, _t(4), lambda x: x.fmod(3))
def test_fmod_float(self): _check(self, _t(4).float() + 1.5, lambda x: x.fmod(2.0))
def test_floordiv_bool(self): _check(self, _t(4).cast(dtypes.bool), lambda x: x // True)

View file

@ -53,7 +53,7 @@ class TestValidateOOB(unittest.TestCase):
to_uops_list([buf.index(v.valid(v < 20)).store(0)]) # oob
# ALU ops in index
def test_idiv(self):
def test_floordiv(self):
with Context(CHECK_OOB=1, SPEC=2):
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
to_uops_list([buf.index(UOp.range(32, 0, AxisType.GLOBAL) // 2, ptr=True).load(dtype=dtypes.int)]) # 0..15 valid

View file

@ -34,7 +34,7 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
def q_to_uint8(t: Tensor, b: int) -> Tensor:
# TODO: rewrite with arange?
shift_tensor, bitmask = Tensor.stack(*[ Tensor(2**(i*b), device=t.device, dtype=t.dtype) for i in range(8//b) ]), 0xff >> (8 - b)
return t.unsqueeze(-1).expand((*t.shape,8//b)).idiv(shift_tensor).bitwise_and(bitmask).transpose(-1, -2).flatten(-2)
return t.unsqueeze(-1).expand((*t.shape,8//b)).div(shift_tensor, rounding_mode="trunc").bitwise_and(bitmask).transpose(-1, -2).flatten(-2)
# map to (number of elements, number of bytes)
if (nelements_nbytes := {

View file

@ -1,5 +1,5 @@
import math, functools, operator
from typing import Self
from typing import Literal, Self
from tinygrad.uop import Ops
from tinygrad.dtype import dtypes, ConstType, PyConst, least_upper_dtype, least_upper_float
from tinygrad.helpers import argfix, polyN
@ -167,19 +167,6 @@ class ElementwiseMixin(DTypeMixin, CreationMixin):
self._check_dtype()
return self._binop(Ops.XOR, x, reverse)
def idiv(self, x: Self | ConstType, reverse: bool = False) -> Self:
"""
Divides `self` by `x`.
Equivalent to `self // x`.
Supports broadcasting to a common shape, type promotion, and integer inputs.
`idiv` performs integer division (truncate towards zero).
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-4, 7, 5, 4, -7, 8]).idiv(Tensor([2, -3, 8, -2, 3, 5])).numpy())
```
"""
return self._binop(Ops.IDIV, x, reverse)
def mod(self, x: Self | ConstType, reverse: bool = False) -> Self:
"""
Mod `self` by `x`.
@ -207,9 +194,35 @@ class ElementwiseMixin(DTypeMixin, CreationMixin):
if dtypes.is_int(a.dtype): return a.alu(Ops.MOD, b)
return a - (a*b.reciprocal()).trunc() * b
def div(self, x: Self | ConstType, reverse: bool = False) -> Self:
def div(self, x: Self | ConstType, reverse: bool = False, rounding_mode: Literal["trunc", "floor"] | None = None) -> Self:
"""
Divides `self` by `x`.
Equivalent to `self / x`.
Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
`div` performs true division by default; pass `rounding_mode="trunc"` for truncating toward zero
or `rounding_mode="floor"` for floor division.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.randn(4)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.div(3).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4])).numpy())
```
"""
lhs, rhs = self._broadcasted(x, reverse)
return lhs * rhs.reciprocal()
if rounding_mode is None: return lhs * rhs.reciprocal()
if dtypes.is_int(lhs.dtype):
if rounding_mode == "trunc": return lhs.alu(Ops.IDIV, rhs)
if rounding_mode == "floor": return lhs // rhs
d = lhs.cast(least_upper_float(lhs.dtype)) * rhs.cast(least_upper_float(rhs.dtype)).reciprocal()
if rounding_mode == "trunc": return d.trunc()
if rounding_mode == "floor": return d.floor()
raise RuntimeError(f"{rounding_mode=} is not supported")
def __neg__(self) -> Self:
return self.neg()

View file

@ -2,9 +2,9 @@
from __future__ import annotations
import time, math, itertools, functools, struct, sys, inspect, pathlib, hashlib, weakref
from contextlib import ContextDecorator
from typing import Any, Callable, ClassVar, Sequence, cast, get_args, Literal, ParamSpec, TypeVar, Generic, TYPE_CHECKING
from typing import Any, Callable, ClassVar, Sequence, cast, get_args, ParamSpec, TypeVar, Generic, TYPE_CHECKING
if TYPE_CHECKING: import numpy
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_float, least_upper_dtype, to_dtype, truncate
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_dtype, to_dtype, truncate
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid
from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, all_same, fully_flatten, ceildiv, fetch, flat_to_grouped
from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile
@ -1284,35 +1284,6 @@ class Tensor(OpMixin):
assert isinstance(x, (*get_args(ConstType), UOp)), f"{type(x)=}, {x=}"
return Tensor(x, self.device, self.dtype if self._ufix_keep_dtype(x) else None, requires_grad=False)
def div(self, x:Tensor|ConstType|UOp, reverse=False, rounding_mode:Literal["trunc", "floor"]|None=None) -> Tensor:
"""
Divides `self` by `x`.
Equivalent to `self / x`.
Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
`div` performs true division.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.randn(4)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.div(3).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4])).numpy())
```
"""
if rounding_mode is None: return super().div(x, reverse) # type: ignore[arg-type]
numerator, denominator = self._broadcasted(x, reverse)
if dtypes.is_int(numerator.dtype):
if rounding_mode == "trunc": return numerator.idiv(denominator)
if rounding_mode == "floor": return numerator // denominator
d = numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal()
if rounding_mode == "trunc": return d.trunc()
if rounding_mode == "floor": return d.floor()
raise RuntimeError(f"{rounding_mode=} is not supported")
def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor:
"""
Returns a tensor of elements selected from either `x` or `y`, depending on `self`.

View file

@ -103,8 +103,8 @@ pm_pyrender_extra = PatternMatcher([
# TODO: movement ops simplify stuff, this can break SPEC=2
#(UPat(GroupOp.Movement, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({render_marg(ctx,x)})"),
# NOTE: CMPNE doesn't work cause there's no __rne__
# explicit trunc ops: `//` and `%` parse as FLOORDIV/FLOORMOD, so render IDIV/MOD via their named methods
(UPat(Ops.IDIV, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.idiv({ctx[x.src[1]]})"),
# explicit trunc ops: `//` and `%` parse as FLOORDIV/FLOORMOD, so render IDIV/MOD via .alu()
(UPat(Ops.IDIV, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.alu(Ops.IDIV, {ctx[x.src[1]]})"),
(UPat(Ops.MOD, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.alu(Ops.MOD, {ctx[x.src[1]]})"),
# NOTE: only match CONSTs without UNIQUE (len(src)==1), unique_const needs explicit rendering
(UPat(set(syms.keys())-{Ops.SUB, Ops.CMPNE, Ops.IDIV, Ops.MOD}, src=(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),), name="y"), UPat(name="z")), name="x"),