mirror of
https://github.com/shitagaki-lab/see-through.git
synced 2026-05-05 19:58:57 +00:00
24 lines
No EOL
788 B
Python
24 lines
No EOL
788 B
Python
import random
|
|
|
|
class NameSampler:
    """Draw names at random according to a ``name -> probability`` mapping.

    The distribution is discretised into a lookup table of ``sample_num``
    slots. Each name receives ``int(prob * sample_num)`` slots, and
    :meth:`sample` picks one slot uniformly at random. If the probabilities
    sum to less than 1, the leftover slots are filled with the placeholder
    name ``'_'``, which is then returned with the remaining probability mass.

    Because slot counts are truncated with ``int``, each realised probability
    may fall short of the requested one by up to ``1 / sample_num``.
    """

    # Slack for float rounding when summing probabilities, e.g.
    # sum([0.1] * 10) == 0.9999999999999999 should count as "sums to 1".
    _PROB_TOL = 1e-6

    def __init__(self, name_prob_dict, sample_num=2048) -> None:
        """Build the slot table.

        Args:
            name_prob_dict: mapping from name to its sampling probability.
                The probabilities must sum to at most 1 (within a small
                float tolerance).
            sample_num: number of slots used to discretise the distribution;
                larger values give finer probability resolution.

        Raises:
            ValueError: if the probabilities sum to more than 1.
        """
        self.name_prob_dict = name_prob_dict
        # Slot index -> name; the position in this list is the slot id.
        self._id2name = list(name_prob_dict.keys())
        self.sample_ids = []

        total_prob = 0.
        # dict keys() and values() iterate in the same order, so ii lines up
        # with the position of the name in self._id2name.
        for ii, prob in enumerate(name_prob_dict.values()):
            total_prob += prob
            tgt_num = int(prob * sample_num)
            if tgt_num > 0:
                self.sample_ids += [ii] * tgt_num

        # Validate the accumulated mass, not just a single entry.
        if total_prob > 1 + self._PROB_TOL:
            raise ValueError(
                f'probabilities must sum to at most 1, got {total_prob}')

        nsamples = len(self.sample_ids)
        # Only a genuine probability deficit earns a '_' placeholder. When the
        # total is 1, slots lost to int() truncation are dropped instead, and
        # the named entries share the table proportionally.
        if total_prob < 1 - self._PROB_TOL and nsamples < sample_num:
            self.sample_ids += [len(self._id2name)] * (sample_num - nsamples)
            self._id2name.append('_')

    def sample(self) -> str:
        """Return one randomly drawn name (``'_'`` for the leftover mass).

        Raises:
            IndexError: if the slot table is empty. This can happen when every
                ``int(prob * sample_num)`` is 0 and the probabilities sum to 1.
        """
        return self._id2name[random.choice(self.sample_ids)]