RandMixin [PR] (#16543)

This commit is contained in:
chenyu 2026-06-08 19:11:28 -04:00 committed by GitHub
commit 11fee53527
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 46 additions and 34 deletions

View file

@ -222,6 +222,12 @@ class TestTensorUOpRand(unittest.TestCase):
bits_uop = UOp.empty((8,), dtype=dtypes.uint32)
for shape in ((8,), (2, 4), (5,)):
self.assertIs(Tensor._bits_to_rand(Tensor(bits_uop), shape, dtypes.float32).uop, UOp._bits_to_rand(bits_uop, shape, dtypes.float32))
def test_threefry(self):
  # Tensor.threefry must wrap exactly the UOp that UOp.threefry builds (identity, not just equality)
  tensor = _t(4).cast(dtypes.uint64)
  via_tensor = tensor.threefry(tensor).uop
  via_uop = tensor.uop.threefry(tensor.uop)
  self.assertIs(via_tensor, via_uop)
def test_threefry_random_bits(self):
  # the Tensor-level helper must build the very same UOp as the UOp-level helper
  key = UOp.empty((2,), dtype=dtypes.uint32)
  c0 = UOp.arange(4, dtype=dtypes.uint32)
  c1 = UOp.arange(4, dtype=dtypes.uint32)
  via_tensor = Tensor._threefry_random_bits(Tensor(key), Tensor(c0), Tensor(c1)).uop
  via_uop = UOp._threefry_random_bits(key, c0, c1)
  self.assertIs(via_tensor, via_uop)
class TestTensorUOpGather(unittest.TestCase):
def _check(self, t, dim, idx):

View file

@ -6,7 +6,7 @@ from tinygrad.mixin.movement import MovementMixin
from tinygrad.mixin.reduce import ReduceMixin
from tinygrad.uop import Ops
from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element
from tinygrad.dtype import ConstType, DType, DTypeLike, Invalid, InvalidType, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
from tinygrad.dtype import ConstType, DTypeLike, Invalid, InvalidType, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
from tinygrad.helpers import all_int, argfix, ceildiv, flatten, flat_to_grouped, fully_flatten, get_shape, make_tuple, prod
from tinygrad.helpers import resolve_pool_pads, round_up
@ -308,36 +308,6 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
"""
return self._tri(self.shape[-2], self.shape[-1], diagonal+1).where(self.const_like(0), self)
# ***** random *****
# Pre-move copy of this helper: the same body is added as RandMixin._threefry_random_bits in tinygrad/mixin/rand.py.
# Runs threefry over packed counters with a packed two-word key and returns the result as uint32 words.
@staticmethod
def _threefry_random_bits(key, counts0, counts1):
# pack each counter pair into one uint64: counts1 in the upper 32 bits, counts0 in the lower 32 bits
x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
# pack the key the same way (key[1] upper, key[0] lower) after broadcasting it to the counter shape, then run threefry
x = x.threefry((key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
# split every uint64 result back into uint32 words: all lower halves first, then all upper halves
return (x & 0xffffffff).cast(dtypes.uint32).cat(((x >> 32) & 0xffffffff).cast(dtypes.uint32))
# Pre-move copy of this method: the same body is added as RandMixin.random_bits in tinygrad/mixin/rand.py.
# Builds `num` uint32 random words from `key` and a two-word `counter` (index 0 = low word, index 1 = high word).
# NOTE(review): with num == 0 the loop never runs, `bits` stays empty and `bits[0]` raises IndexError;
# presumably callers never pass 0 - confirm.
@classmethod
def random_bits(cls, key:Self, counter:Self, num:int) -> Self:
low, high = counter[0:1], counter[1:2]
bits = []
# generate in chunks of at most dtypes.uint32.max words
for i in range(0, num, dtypes.uint32.max):
chunk_num = min(num - i, dtypes.uint32.max)
# advance the 64-bit counter by the chunk offset i: low word first, then the high word plus a carry
c_low = low + (i & 0xffffffff)
# (c_low < low) flags a wrapped low word - assumes uint32 addition wraps around, TODO confirm
c_high = high + (i >> 32) + (c_low < low).cast(dtypes.uint32)
# derive a per-chunk key from the caller's key and the advanced counter
new_key = cls._threefry_random_bits(key, c_low, c_high)
# counts0 indexes the first half of the chunk, counts1 the second half (both of length ceil(chunk_num/2))
counts0 = cls.arange(ceildiv(chunk_num, 2), dtype=dtypes.uint32)
counts1 = counts0 + ceildiv(chunk_num, 2)
# the helper concatenates lower and upper words, so trim the result back to chunk_num entries
bits.append(cls._threefry_random_bits(new_key, counts0, counts1)[:chunk_num])
# join all chunks into one result
return bits[0].cat(*bits[1:])
# Pre-move copy of this helper: the same body is added as RandMixin._bits_to_rand in tinygrad/mixin/rand.py.
# Converts raw random `bits` into values of float `dtype` with the requested `shape`.
@staticmethod
def _bits_to_rand(bits, shape:tuple[int, ...], dtype:DType):
# second element of finfo is used as the mantissa bit count (per the name nmant)
_, nmant = dtypes.finfo(dtype)
# unsigned integer dtype with the same byte width as the target float dtype
uint_dtype = {1: dtypes.uint8, 2: dtypes.uint16, 4: dtypes.uint32, 8: dtypes.uint64}[dtype.itemsize]
uint_bits = bits.bitcast(uint_dtype)
# bit pattern of the value 1 in `dtype`, viewed as an unsigned integer
float_one_bits = uint_bits.const_like(1).cast(dtype).bitcast(uint_dtype)
# keep the top nmant random bits as mantissa, OR in the pattern of 1, reinterpret as `dtype`,
# take the first prod(shape) values, subtract 1 and reshape
# (presumably this maps [1, 2) onto [0, 1) for IEEE-style floats - confirm)
return uint_bits.rshift(dtype.bitsize - nmant).bitwise_or(float_one_bits).bitcast(dtype)[:prod(shape)].sub(1).reshape(shape)
def _pad_constant(self, pX, value:ConstType) -> Self:
# shrink first for negative pads, then pad with only non-negative values
pX = tuple((0, 0) if p is None else p for p in pX)

35
tinygrad/mixin/rand.py Normal file
View file

@ -0,0 +1,35 @@
from __future__ import annotations
from typing import Self
from tinygrad.dtype import DType, dtypes
from tinygrad.helpers import ceildiv, prod
from tinygrad.mixin import OpMixin
class RandMixin(OpMixin):
  """Threefry-based random helpers layered on top of `OpMixin`.

  Raw bits come from `threefry`; `_bits_to_rand` reinterprets raw bits as floats of a requested dtype.
  """

  @staticmethod
  def _threefry_random_bits(key, counts0, counts1):
    """Run threefry over packed counters with a packed two-word key.

    Each (counts1, counts0) pair is packed into one uint64 (counts1 in the upper 32 bits) and `key` is
    packed the same way from key[1] and key[0]. The result is returned as uint32 words: the lower
    halves of every 64-bit output first, followed by the upper halves.
    """
    packed_counts = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
    # the key words are broadcast to the counter shape before packing
    key_upper = key[1]._broadcast_to(packed_counts.shape).cast(dtypes.uint64) << 32
    key_lower = key[0]._broadcast_to(packed_counts.shape).cast(dtypes.uint64)
    mixed = packed_counts.threefry(key_upper | key_lower)
    lower_words = (mixed & 0xffffffff).cast(dtypes.uint32)
    upper_words = ((mixed >> 32) & 0xffffffff).cast(dtypes.uint32)
    return lower_words.cat(upper_words)

  @classmethod
  def random_bits(cls, key:Self, counter:Self, num:int) -> Self:
    """Build `num` uint32 random words from `key` and a two-word `counter`.

    `counter[0:1]` is treated as the low word and `counter[1:2]` as the high word. Output is produced in
    chunks of at most `dtypes.uint32.max` words; each chunk derives its own key from `key` and the
    counter advanced by the chunk offset.

    NOTE(review): `num == 0` produces no chunks, so indexing the first chunk raises IndexError;
    presumably callers never pass 0 - confirm.
    """
    counter_lo, counter_hi = counter[0:1], counter[1:2]
    max_chunk = dtypes.uint32.max
    chunks = []
    for offset in range(0, num, max_chunk):
      chunk_num = min(num - offset, max_chunk)
      half = ceildiv(chunk_num, 2)
      # advance the counter by the offset; (chunk_lo < counter_lo) supplies the carry into the high word
      # (assumes uint32 addition wraps around - TODO confirm)
      chunk_lo = counter_lo + (offset & 0xffffffff)
      hi_base = counter_hi + (offset >> 32)
      chunk_hi = hi_base + (chunk_lo < counter_lo).cast(dtypes.uint32)
      chunk_key = cls._threefry_random_bits(key, chunk_lo, chunk_hi)
      # first-half and second-half indices of the chunk; the helper output is trimmed back to chunk_num
      counts0 = cls.arange(half, dtype=dtypes.uint32)
      counts1 = counts0 + half
      chunks.append(cls._threefry_random_bits(chunk_key, counts0, counts1)[:chunk_num])
    return chunks[0].cat(*chunks[1:])

  @staticmethod
  def _bits_to_rand(bits, shape:tuple[int, ...], dtype:DType):
    """Convert raw random `bits` into values of float `dtype` with the given `shape`.

    The top `nmant` bits of each word become the mantissa, OR-ed with the bit pattern of 1 in `dtype`;
    the first `prod(shape)` values are kept, 1 is subtracted and the result is reshaped.
    Presumably this maps [1, 2) onto [0, 1) for IEEE-style floats - confirm.
    """
    _, nmant = dtypes.finfo(dtype)
    # unsigned integer dtype with the same byte width as the float dtype
    same_width_uint = {1: dtypes.uint8, 2: dtypes.uint16, 4: dtypes.uint32, 8: dtypes.uint64}[dtype.itemsize]
    raw = bits.bitcast(same_width_uint)
    one_pattern = raw.const_like(1).cast(dtype).bitcast(same_width_uint)
    mantissa = raw.rshift(dtype.bitsize - nmant)
    as_float = mantissa.bitwise_or(one_pattern).bitcast(dtype)
    return as_float[:prod(shape)].sub(1).reshape(shape)

View file

@ -10,7 +10,7 @@ from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, f
from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile
from tinygrad.helpers import suppress_finalizing, disable_gc
from tinygrad.uop.ops import UOp, Ops, sint, all_metadata, _index_to_concrete_int, Variable, _broadcast_shape
from tinygrad.mixin import OpMixin
from tinygrad.mixin.rand import RandMixin
from tinygrad.schedule import create_linear_with_vars
from tinygrad.device import Buffer, canonicalize_device
from tinygrad.engine.realize import run_linear
@ -59,7 +59,7 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
assert isinstance(ret, Tensor), "sum didn't return a Tensor"
return ret
class Tensor(OpMixin):
class Tensor(RandMixin):
"""
A `Tensor` is a multi-dimensional matrix containing elements of a single data type.

View file

@ -135,10 +135,11 @@ class recursive_property(property):
# we import this late so we can use resolve/smax in mixins
from tinygrad.mixin import OpMixin
from tinygrad.mixin.rand import RandMixin
# NOTE: this should be frozen, but frozen is slower
@dataclass(eq=False, slots=True)
class UOp(OpMixin, metaclass=UOpMetaClass):
class UOp(RandMixin, metaclass=UOpMetaClass):
op:Ops
dtype:DType = dtypes.void
src:tuple[UOp, ...] = tuple()