mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
input validation for rand functions (#15990)
This commit is contained in:
parent
11e1a2b89f
commit
e0b09f288f
2 changed files with 16 additions and 1 deletions
|
|
@ -307,17 +307,26 @@ class TestRandomness(unittest.TestCase):
|
|||
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0, high=3.5)
|
||||
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=1, high=3, dtype="float")
|
||||
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0, high=3, dtype=dtypes.float32)
|
||||
# check low < high
|
||||
with self.assertRaises(ValueError): Tensor.randint((3, 4), low=10, high=5)
|
||||
with self.assertRaises(ValueError): Tensor.randint((3, 4), low=10, high=10)
|
||||
np.testing.assert_array_equal(Tensor.randint(16, low=5, high=6).numpy(), 5)
|
||||
|
||||
def test_normal(self):
|
||||
self.assertTrue(normal_test(Tensor.normal))
|
||||
self.assertTrue(equal_distribution(Tensor.normal, lambda x: torch.nn.init.normal_(torch.empty(x), mean=0, std=1),
|
||||
lambda x: np.random.normal(loc=0, scale=1, size=x)))
|
||||
# check std >= 0
|
||||
with self.assertRaises(ValueError): Tensor.normal((3, 4), mean=0, std=-1)
|
||||
|
||||
def test_uniform(self):
|
||||
self.assertFalse(normal_test(Tensor.uniform))
|
||||
self.assertTrue(equal_distribution(Tensor.uniform, lambda x: torch.nn.init.uniform_(torch.empty(x)), lambda x: np.random.uniform(size=x)))
|
||||
self.assertTrue(equal_distribution(partial(Tensor.uniform, low=-100, high=100, dtype=dtypes.int32),
|
||||
numpy_func=lambda x: np.random.randint(low=-100, high=100, size=x)))
|
||||
# check low < high
|
||||
with self.assertRaises(ValueError): Tensor.uniform((3, 4), low=5.0, high=3.0)
|
||||
with self.assertRaises(ValueError): Tensor.uniform((3, 4), low=1.0, high=1.0)
|
||||
|
||||
def test_scaled_uniform(self):
|
||||
self.assertFalse(normal_test(Tensor.scaled_uniform))
|
||||
|
|
|
|||
|
|
@ -692,7 +692,7 @@ class Tensor(OpMixin):
|
|||
def randint(*shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Tensor:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`.
|
||||
If `dtype` is not specified, the default type is used.
|
||||
Requires `low < high`. If `dtype` is not specified, the default type is used.
|
||||
|
||||
You can pass in the `device` keyword argument to control device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
|
@ -704,12 +704,14 @@ class Tensor(OpMixin):
|
|||
"""
|
||||
if not all_int([low, high]): raise TypeError(f"{low=} and {high=} must be integers")
|
||||
if not dtypes.is_int(dtype := to_dtype(dtype)): raise TypeError(f"{dtype=} must be int")
|
||||
if low >= high: raise ValueError(f"Tensor.randint requires low < high, got {low=}, {high=}")
|
||||
return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def normal(*shape, mean=0.0, std=1.0, requires_grad:bool|None=None, **kwargs) -> Tensor:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.
|
||||
Requires `std >= 0`.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
|
@ -719,12 +721,14 @@ class Tensor(OpMixin):
|
|||
print(Tensor.normal(2, 3, mean=10, std=2).numpy())
|
||||
```
|
||||
"""
|
||||
if std < 0: raise ValueError(f"Tensor.normal requires std >= 0, got {std=}")
|
||||
return (std * Tensor.randn(*shape, **kwargs) + mean).requires_grad_(requires_grad)
|
||||
|
||||
@staticmethod
|
||||
def uniform(*shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, requires_grad:bool|None=None, **kwargs) -> Tensor:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.
|
||||
Requires `low < high`.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
|
@ -734,6 +738,8 @@ class Tensor(OpMixin):
|
|||
print(Tensor.uniform(2, 3, low=2, high=10).numpy())
|
||||
```
|
||||
"""
|
||||
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
|
||||
if low >= high: raise ValueError(f"Tensor.uniform requires low < high, got {low=}, {high=}")
|
||||
return (((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype or dtypes.default_float) + low).requires_grad_(requires_grad)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue