mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
RandMixin [PR] (#16543)
This commit is contained in:
parent
e2ef5cf5c9
commit
11fee53527
5 changed files with 46 additions and 34 deletions
|
|
@ -222,6 +222,12 @@ class TestTensorUOpRand(unittest.TestCase):
|
|||
bits_uop = UOp.empty((8,), dtype=dtypes.uint32)
|
||||
for shape in ((8,), (2, 4), (5,)):
|
||||
self.assertIs(Tensor._bits_to_rand(Tensor(bits_uop), shape, dtypes.float32).uop, UOp._bits_to_rand(bits_uop, shape, dtypes.float32))
|
||||
def test_threefry(self):
|
||||
t = _t(4).cast(dtypes.uint64)
|
||||
self.assertIs(t.threefry(t).uop, t.uop.threefry(t.uop))
|
||||
def test_threefry_random_bits(self):
|
||||
key, c0, c1 = UOp.empty((2,), dtype=dtypes.uint32), UOp.arange(4, dtype=dtypes.uint32), UOp.arange(4, dtype=dtypes.uint32)
|
||||
self.assertIs(Tensor._threefry_random_bits(Tensor(key), Tensor(c0), Tensor(c1)).uop, UOp._threefry_random_bits(key, c0, c1))
|
||||
|
||||
class TestTensorUOpGather(unittest.TestCase):
|
||||
def _check(self, t, dim, idx):
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from tinygrad.mixin.movement import MovementMixin
|
|||
from tinygrad.mixin.reduce import ReduceMixin
|
||||
from tinygrad.uop import Ops
|
||||
from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element
|
||||
from tinygrad.dtype import ConstType, DType, DTypeLike, Invalid, InvalidType, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
|
||||
from tinygrad.dtype import ConstType, DTypeLike, Invalid, InvalidType, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
|
||||
from tinygrad.helpers import all_int, argfix, ceildiv, flatten, flat_to_grouped, fully_flatten, get_shape, make_tuple, prod
|
||||
from tinygrad.helpers import resolve_pool_pads, round_up
|
||||
|
||||
|
|
@ -308,36 +308,6 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
|||
"""
|
||||
return self._tri(self.shape[-2], self.shape[-1], diagonal+1).where(self.const_like(0), self)
|
||||
|
||||
# ***** random *****
|
||||
|
||||
@staticmethod
|
||||
def _threefry_random_bits(key, counts0, counts1):
|
||||
x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
|
||||
x = x.threefry((key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
|
||||
return (x & 0xffffffff).cast(dtypes.uint32).cat(((x >> 32) & 0xffffffff).cast(dtypes.uint32))
|
||||
|
||||
@classmethod
|
||||
def random_bits(cls, key:Self, counter:Self, num:int) -> Self:
|
||||
low, high = counter[0:1], counter[1:2]
|
||||
bits = []
|
||||
for i in range(0, num, dtypes.uint32.max):
|
||||
chunk_num = min(num - i, dtypes.uint32.max)
|
||||
c_low = low + (i & 0xffffffff)
|
||||
c_high = high + (i >> 32) + (c_low < low).cast(dtypes.uint32)
|
||||
new_key = cls._threefry_random_bits(key, c_low, c_high)
|
||||
counts0 = cls.arange(ceildiv(chunk_num, 2), dtype=dtypes.uint32)
|
||||
counts1 = counts0 + ceildiv(chunk_num, 2)
|
||||
bits.append(cls._threefry_random_bits(new_key, counts0, counts1)[:chunk_num])
|
||||
return bits[0].cat(*bits[1:])
|
||||
|
||||
@staticmethod
|
||||
def _bits_to_rand(bits, shape:tuple[int, ...], dtype:DType):
|
||||
_, nmant = dtypes.finfo(dtype)
|
||||
uint_dtype = {1: dtypes.uint8, 2: dtypes.uint16, 4: dtypes.uint32, 8: dtypes.uint64}[dtype.itemsize]
|
||||
uint_bits = bits.bitcast(uint_dtype)
|
||||
float_one_bits = uint_bits.const_like(1).cast(dtype).bitcast(uint_dtype)
|
||||
return uint_bits.rshift(dtype.bitsize - nmant).bitwise_or(float_one_bits).bitcast(dtype)[:prod(shape)].sub(1).reshape(shape)
|
||||
|
||||
def _pad_constant(self, pX, value:ConstType) -> Self:
|
||||
# shrink first for negative pads, then pad with only non-negative values
|
||||
pX = tuple((0, 0) if p is None else p for p in pX)
|
||||
|
|
|
|||
35
tinygrad/mixin/rand.py
Normal file
35
tinygrad/mixin/rand.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
from __future__ import annotations
|
||||
from typing import Self
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
from tinygrad.helpers import ceildiv, prod
|
||||
from tinygrad.mixin import OpMixin
|
||||
|
||||
|
||||
class RandMixin(OpMixin):
|
||||
@staticmethod
|
||||
def _threefry_random_bits(key, counts0, counts1):
|
||||
x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
|
||||
x = x.threefry((key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
|
||||
return (x & 0xffffffff).cast(dtypes.uint32).cat(((x >> 32) & 0xffffffff).cast(dtypes.uint32))
|
||||
|
||||
@classmethod
|
||||
def random_bits(cls, key:Self, counter:Self, num:int) -> Self:
|
||||
low, high = counter[0:1], counter[1:2]
|
||||
bits = []
|
||||
for i in range(0, num, dtypes.uint32.max):
|
||||
chunk_num = min(num - i, dtypes.uint32.max)
|
||||
c_low = low + (i & 0xffffffff)
|
||||
c_high = high + (i >> 32) + (c_low < low).cast(dtypes.uint32)
|
||||
new_key = cls._threefry_random_bits(key, c_low, c_high)
|
||||
counts0 = cls.arange(ceildiv(chunk_num, 2), dtype=dtypes.uint32)
|
||||
counts1 = counts0 + ceildiv(chunk_num, 2)
|
||||
bits.append(cls._threefry_random_bits(new_key, counts0, counts1)[:chunk_num])
|
||||
return bits[0].cat(*bits[1:])
|
||||
|
||||
@staticmethod
|
||||
def _bits_to_rand(bits, shape:tuple[int, ...], dtype:DType):
|
||||
_, nmant = dtypes.finfo(dtype)
|
||||
uint_dtype = {1: dtypes.uint8, 2: dtypes.uint16, 4: dtypes.uint32, 8: dtypes.uint64}[dtype.itemsize]
|
||||
uint_bits = bits.bitcast(uint_dtype)
|
||||
float_one_bits = uint_bits.const_like(1).cast(dtype).bitcast(uint_dtype)
|
||||
return uint_bits.rshift(dtype.bitsize - nmant).bitwise_or(float_one_bits).bitcast(dtype)[:prod(shape)].sub(1).reshape(shape)
|
||||
|
|
@ -10,7 +10,7 @@ from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, f
|
|||
from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile
|
||||
from tinygrad.helpers import suppress_finalizing, disable_gc
|
||||
from tinygrad.uop.ops import UOp, Ops, sint, all_metadata, _index_to_concrete_int, Variable, _broadcast_shape
|
||||
from tinygrad.mixin import OpMixin
|
||||
from tinygrad.mixin.rand import RandMixin
|
||||
from tinygrad.schedule import create_linear_with_vars
|
||||
from tinygrad.device import Buffer, canonicalize_device
|
||||
from tinygrad.engine.realize import run_linear
|
||||
|
|
@ -59,7 +59,7 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
|
|||
assert isinstance(ret, Tensor), "sum didn't return a Tensor"
|
||||
return ret
|
||||
|
||||
class Tensor(OpMixin):
|
||||
class Tensor(RandMixin):
|
||||
"""
|
||||
A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
|
||||
|
||||
|
|
|
|||
|
|
@ -135,10 +135,11 @@ class recursive_property(property):
|
|||
|
||||
# we import this late so we can use resolve/smax in mixins
|
||||
from tinygrad.mixin import OpMixin
|
||||
from tinygrad.mixin.rand import RandMixin
|
||||
|
||||
# NOTE: this should be frozen, but frozen is slower
|
||||
@dataclass(eq=False, slots=True)
|
||||
class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
class UOp(RandMixin, metaclass=UOpMetaClass):
|
||||
op:Ops
|
||||
dtype:DType = dtypes.void
|
||||
src:tuple[UOp, ...] = tuple()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue