input validation for rand functions (#15990)

This commit is contained in:
chenyu 2026-04-30 14:00:44 -04:00 committed by GitHub
commit e0b09f288f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 16 additions and 1 deletions

View file

@ -307,17 +307,26 @@ class TestRandomness(unittest.TestCase):
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0, high=3.5)
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=1, high=3, dtype="float")
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0, high=3, dtype=dtypes.float32)
# check low < high
with self.assertRaises(ValueError): Tensor.randint((3, 4), low=10, high=5)
with self.assertRaises(ValueError): Tensor.randint((3, 4), low=10, high=10)
np.testing.assert_array_equal(Tensor.randint(16, low=5, high=6).numpy(), 5)
def test_normal(self):
self.assertTrue(normal_test(Tensor.normal))
self.assertTrue(equal_distribution(Tensor.normal, lambda x: torch.nn.init.normal_(torch.empty(x), mean=0, std=1),
lambda x: np.random.normal(loc=0, scale=1, size=x)))
# check std >= 0
with self.assertRaises(ValueError): Tensor.normal((3, 4), mean=0, std=-1)
def test_uniform(self):
self.assertFalse(normal_test(Tensor.uniform))
self.assertTrue(equal_distribution(Tensor.uniform, lambda x: torch.nn.init.uniform_(torch.empty(x)), lambda x: np.random.uniform(size=x)))
self.assertTrue(equal_distribution(partial(Tensor.uniform, low=-100, high=100, dtype=dtypes.int32),
numpy_func=lambda x: np.random.randint(low=-100, high=100, size=x)))
# check low < high
with self.assertRaises(ValueError): Tensor.uniform((3, 4), low=5.0, high=3.0)
with self.assertRaises(ValueError): Tensor.uniform((3, 4), low=1.0, high=1.0)
def test_scaled_uniform(self):
self.assertFalse(normal_test(Tensor.scaled_uniform))

View file

@ -692,7 +692,7 @@ class Tensor(OpMixin):
def randint(*shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`.
If `dtype` is not specified, the default type is used.
Requires `low < high`. If `dtype` is not specified, the default type is used.
You can pass in the `device` keyword argument to control device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
@ -704,12 +704,14 @@ class Tensor(OpMixin):
"""
if not all_int([low, high]): raise TypeError(f"{low=} and {high=} must be integers")
if not dtypes.is_int(dtype := to_dtype(dtype)): raise TypeError(f"{dtype=} must be int")
if low >= high: raise ValueError(f"Tensor.randint requires low < high, got {low=}, {high=}")
return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
@staticmethod
def normal(*shape, mean=0.0, std=1.0, requires_grad:bool|None=None, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.
Requires `std >= 0`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
@ -719,12 +721,14 @@ class Tensor(OpMixin):
print(Tensor.normal(2, 3, mean=10, std=2).numpy())
```
"""
if std < 0: raise ValueError(f"Tensor.normal requires std >= 0, got {std=}")
return (std * Tensor.randn(*shape, **kwargs) + mean).requires_grad_(requires_grad)
@staticmethod
def uniform(*shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, requires_grad:bool|None=None, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.
Requires `low < high`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
@ -734,6 +738,8 @@ class Tensor(OpMixin):
print(Tensor.uniform(2, 3, low=2, high=10).numpy())
```
"""
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
if low >= high: raise ValueError(f"Tensor.uniform requires low < high, got {low=}, {high=}")
return (((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype or dtypes.default_float) + low).requires_grad_(requires_grad)
@staticmethod