just get dtype from kwargs (#4355)

This commit is contained in:
Sohaib 2024-04-30 00:26:14 -07:00 committed by GitHub
commit a2d81514fd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -514,7 +514,8 @@ class Tensor:
t = Tensor.randint(2, 3, low=5, high=10)
print(t.numpy())
"""
return Tensor.uniform(*shape, low=low, high=high, dtype=dtypes.int32, **kwargs)
assert dtypes.is_int(dtype := kwargs.pop("dtype", dtypes.int32)), f"Unsupported dtype {dtype} for randint"
return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
@staticmethod
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: