no replacement multinomial (#15995)

* no replacement multinomial

Efraimidis–Spirakis

* num_samples == 1 can use fast path
This commit is contained in:
chenyu 2026-04-30 17:35:26 -04:00 committed by GitHub
commit 52c92e15ae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 31 additions and 6 deletions

View file

@ -361,7 +361,7 @@ class TestRandomness(unittest.TestCase):
_check_with_torch(w=[0.231, 0., 1., 0.5], num_samples=300, replacement=True)
_check_with_torch(w=[[0.2, 0.8]], num_samples=300, replacement=True) # 2D but only 1 row
_check_with_torch(w=[[0.453, 0., 1., 0.81], [0.1, 0.8, 0., 0.1]], num_samples=300, replacement=True)
# no-replacement isn't supported, unless taking only one sample
# no-replacement
w = [0.1, 0.9]
self.assertRaises(AssertionError, lambda: Tensor(w).multinomial(100, replacement=False))
@ -372,6 +372,23 @@ class TestRandomness(unittest.TestCase):
torch_samples = [torch.tensor(w).multinomial(1, replacement=False).item() for _ in range(1000)]
self.assertTrue(equal_distribution(lambda *_: Tensor(tiny_samples), lambda _: torch.tensor(torch_samples)))
w = list(range(32))
s1 = Tensor(w).multinomial(5, replacement=False).numpy()
self.assertEqual(len(set(s1.tolist())), 5)
s2 = Tensor(w).multinomial(5, replacement=False).numpy()
self.assertFalse(np.array_equal(s1, s2))
full = Tensor(w).multinomial(len(w), replacement=False).numpy()
self.assertEqual(sorted(full.tolist()), w)
w = [0.1, 0.2, 0.3, 0.4]
@TinyJit
def sample_three(): return Tensor(w).multinomial(3, replacement=False).realize()
tiny_draws = np.array([sample_three().numpy() for _ in range(1000)])
torch_draws = np.array([torch.tensor(w).multinomial(3, replacement=False).numpy() for _ in range(1000)])
for pos in range(3):
self.assertTrue(equal_distribution(lambda *_: Tensor(tiny_draws[:, pos]), lambda _: torch.tensor(torch_draws[:, pos])))
@unittest.skip("this test is flaky")
def test_multinomial_counterexample(self):
tiny_res = Tensor([0.3, 0.6, 0.1]).multinomial(4000, replacement=True)

View file

@ -822,19 +822,27 @@ class Tensor(OpMixin):
"""
Returns a tensor with `num_samples` indices sampled from a multinomial distribution weighted by `self`.
NOTE: `replacement=False` for `num_samples > 1` is not supported yet.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor([1, 2, 3, 4])
print(t.multinomial(20, replacement=True).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor([1, 2, 3, 4])
print(t.multinomial(3, replacement=False).numpy())
```
"""
assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
weight = self.unsqueeze(0) if self.ndim == 1 else self
cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1).to(self.device)
indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
assert replacement or num_samples <= weight.shape[1], "no replacement samples must not exceed population size"
if replacement or num_samples == 1:
cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1).to(self.device)
indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
else:
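# Efraimidis–Spirakis: rank every item by log2(u) / weight (u uniform random) and keep the top num_samples as the no-replacement draw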
indices = (weight.rand_like(dtype=dtypes.float32).log2() / weight).topk(num_samples, dim=1)[1]
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
# ***** toposort and backward pass *****