VTS/vts/utils/mask.py
2026-06-12 23:35:56 +09:00

101 lines
3.8 KiB
Python

import numpy as np
import scipy.special
import scipy.stats
import torch
from einops import rearrange, repeat
from utils.typing import EncMaskTensor, LengthTensor
def prob_mask_like(
shape: torch.Size | tuple[int, ...] | list[int], prob: float, device: torch.device
):
match prob:
case 1:
return torch.ones(shape, device=device, dtype=torch.bool)
case 0:
return torch.zeros(shape, device=device, dtype=torch.bool)
case _:
return torch.rand(shape, device=device) < prob
def mask_from_fracs(lengths: LengthTensor, fmin: float, fmax: float, max_len: int):
batch = lengths.shape[0]
fracs = torch.zeros(batch, device=lengths.device).float().uniform_(fmin, fmax)
frac_lengths = (lengths.float() * fracs).clamp(min=1).round().long()
max_starts = lengths - frac_lengths
start_fracs = torch.zeros(batch, device=lengths.device).float().uniform_(0, 1)
starts = (max_starts * start_fracs).round().long()
ends = starts + frac_lengths
seq = torch.arange(max_len, device=lengths.device)
seq = repeat(seq, "l -> b l", b=lengths.shape[0])
mask = (seq >= rearrange(starts, "b -> b ()")) & (
seq < rearrange(ends, "b -> b ()")
)
return mask
def mask_from_lengths(lengths: LengthTensor, max_len: int) -> EncMaskTensor:
seq = torch.arange(max_len, device=lengths.device)
mask = seq < rearrange(lengths, "b -> b ()")
return mask
def constant_slice_mask(
lengths: LengthTensor, shape: int, max_len: int
) -> EncMaskTensor:
max_start = lengths - lengths.clamp(max=max_len)
starts = torch.rand(lengths.shape[0], device=lengths.device)
starts = (starts * max_start).round().long()
ends = starts + max_len
seq = torch.arange(shape, device=lengths.device)
seq = repeat(seq, "l -> b l", b=lengths.shape[0])
mask = (seq >= rearrange(starts, "b -> b ()")) & (
seq < rearrange(ends, "b -> b ()")
)
return mask
def num_windows(perc: float, length: int, min_span: int):
max_window = int(np.floor(min(perc * length / min_span, (1 - perc) * length + 1)))
if max_window == 1:
return 1
count = np.arange(1, max_window)
op1 = scipy.special.loggamma(perc * length - count * min_span + count)
op2 = scipy.special.loggamma(perc * length - count * min_span + 1)
op3 = scipy.special.loggamma(count)
op4 = scipy.special.loggamma((1 - perc) * length)
op5 = scipy.special.loggamma((1 - perc) * length - count)
op6 = scipy.special.loggamma(count + 1)
log_prob = op1 - op2 - op3 + op4 - op5 - op6
log_prob = scipy.special.softmax(log_prob, axis=0)
rng = np.random.default_rng()
return rng.choice(count.shape[0], p=log_prob) + 1
def min_span_mask(length: int, fmin: float, fmax: float, min_span: int):
rng = np.random.default_rng()
if min_span * 2 > length:
min_span = round(length / 2)
frac = rng.uniform(fmin, fmax)
frac_length = round(frac * length)
windows = num_windows(frac, length, min_span)
window_sum = frac_length - windows * min_span + windows - 1
blank_sum = length - frac_length + 1
window_lengths = rng.choice(window_sum, windows - 1, replace=False)
blank_lengths = rng.choice(blank_sum, windows, replace=False)
window_lengths = np.concatenate([window_lengths, np.array([-1, window_sum])])
blank_lengths = np.concatenate([blank_lengths, np.array([-1, blank_sum])])
window_lengths.sort()
blank_lengths.sort()
window_lengths = np.diff(window_lengths) - 1
blank_lengths = np.diff(blank_lengths) - 1
window_lengths += min_span
blank_lengths[1:-1] += 1
total_lengths = np.zeros((2 * windows + 1), dtype=np.int64)
total_lengths[1::2] = window_lengths
total_lengths[::2] = blank_lengths
binary_segments = np.arange(2 * windows + 1) % 2
return binary_segments.repeat(total_lengths).astype(bool)