mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
no replacement multinomial (#15995)
* no replacement multinomial Efraimidis–Spirakis * num_samples == 1 can use fast path
This commit is contained in:
parent
e0b09f288f
commit
52c92e15ae
2 changed files with 31 additions and 6 deletions
|
|
@ -361,7 +361,7 @@ class TestRandomness(unittest.TestCase):
|
|||
_check_with_torch(w=[0.231, 0., 1., 0.5], num_samples=300, replacement=True)
|
||||
_check_with_torch(w=[[0.2, 0.8]], num_samples=300, replacement=True) # 2D but only 1 row
|
||||
_check_with_torch(w=[[0.453, 0., 1., 0.81], [0.1, 0.8, 0., 0.1]], num_samples=300, replacement=True)
|
||||
# no-replacement isn't supported, unless taking only one sample
|
||||
# no-replacement
|
||||
w = [0.1, 0.9]
|
||||
self.assertRaises(AssertionError, lambda: Tensor(w).multinomial(100, replacement=False))
|
||||
|
||||
|
|
@ -372,6 +372,23 @@ class TestRandomness(unittest.TestCase):
|
|||
torch_samples = [torch.tensor(w).multinomial(1, replacement=False).item() for _ in range(1000)]
|
||||
self.assertTrue(equal_distribution(lambda *_: Tensor(tiny_samples), lambda _: torch.tensor(torch_samples)))
|
||||
|
||||
w = list(range(32))
|
||||
s1 = Tensor(w).multinomial(5, replacement=False).numpy()
|
||||
self.assertEqual(len(set(s1.tolist())), 5)
|
||||
s2 = Tensor(w).multinomial(5, replacement=False).numpy()
|
||||
self.assertFalse(np.array_equal(s1, s2))
|
||||
full = Tensor(w).multinomial(len(w), replacement=False).numpy()
|
||||
self.assertEqual(sorted(full.tolist()), w)
|
||||
|
||||
w = [0.1, 0.2, 0.3, 0.4]
|
||||
@TinyJit
|
||||
def sample_three(): return Tensor(w).multinomial(3, replacement=False).realize()
|
||||
|
||||
tiny_draws = np.array([sample_three().numpy() for _ in range(1000)])
|
||||
torch_draws = np.array([torch.tensor(w).multinomial(3, replacement=False).numpy() for _ in range(1000)])
|
||||
for pos in range(3):
|
||||
self.assertTrue(equal_distribution(lambda *_: Tensor(tiny_draws[:, pos]), lambda _: torch.tensor(torch_draws[:, pos])))
|
||||
|
||||
@unittest.skip("this test is flaky")
|
||||
def test_multinomial_counterexample(self):
|
||||
tiny_res = Tensor([0.3, 0.6, 0.1]).multinomial(4000, replacement=True)
|
||||
|
|
|
|||
|
|
@ -822,19 +822,27 @@ class Tensor(OpMixin):
|
|||
"""
|
||||
Returns a tensor with `num_samples` indices sampled from a multinomial distribution weighted by `self`.
|
||||
|
||||
NOTE: `replacement=False` for `num_samples > 1` is not supported yet.
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor([1, 2, 3, 4])
|
||||
print(t.multinomial(20, replacement=True).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor([1, 2, 3, 4])
|
||||
print(t.multinomial(3, replacement=False).numpy())
|
||||
```
|
||||
"""
|
||||
assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
|
||||
assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
|
||||
weight = self.unsqueeze(0) if self.ndim == 1 else self
|
||||
cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
|
||||
unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1).to(self.device)
|
||||
indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
|
||||
assert replacement or num_samples <= weight.shape[1], "no replacement samples must not exceed population size"
|
||||
if replacement or num_samples == 1:
|
||||
cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
|
||||
unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1).to(self.device)
|
||||
indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
|
||||
else:
|
||||
# Efraimidis–Spirakis
|
||||
indices = (weight.rand_like(dtype=dtypes.float32).log2() / weight).topk(num_samples, dim=1)[1]
|
||||
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
|
||||
|
||||
# ***** toposort and backward pass *****
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue