VTS/vts/model/module_voice.py
2026-06-12 23:35:56 +09:00

627 lines
22 KiB
Python

import logging
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any
import random
import numpy as np
import torch
import torchaudio
from einops import rearrange, repeat
from hydra.core.hydra_config import HydraConfig
from jaxtyping import Float
from lightning.pytorch import LightningModule
from lightning.pytorch.utilities import rank_zero_only
from lightning.pytorch.utilities.types import (
LRSchedulerConfigType,
OptimizerLRSchedulerConfig,
)
from matplotlib import pyplot as plt
from torch import Tensor
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.optim.adam import Adam
from torch.optim.adamw import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR
from torchdiffeq import odeint
from transformers import AutoTokenizer, T5EncoderModel
from vocos_custom import get_voco
from model.vts_voice import VTS
from model.loss import masked_loss
from torchode.interface import solve_ivp
from utils.mask import min_span_mask, prob_mask_like
from utils.typing import AudioTensor, EncMaskTensor, EncTensor, LossTensor
from utils.utils import plot_with_cmap, write_html, get_dynamic
try:
from lightning.pytorch.loggers.mlflow import MLFlowLogger
except Exception:
MLFlowLogger = ()
try:
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
except Exception:
TensorBoardLogger = ()
class VTSModule(LightningModule):
def __init__(
self,
dim: int,
depth: int,
heads: int,
attn_dropout: float,
ff_dropout: float,
kernel_size: int,
voco_type: str,
max_audio_len: int,
optimizer: str = "Adam",
lr: float = 1e-4,
scheduler: str = "linear_warmup_decay",
use_torchode=True,
torchdiffeq_ode_method="midpoint",
torchode_method_klass="tsit5",
max_steps: int = 1000000,
text_repo_id: str = "google/flan-t5-base",
):
super().__init__()
self.save_hyperparameters()
self.voco_type = voco_type
print("voco type : ", voco_type)
voco = get_voco(self.voco_type)
self.sampling_rate = voco.sampling_rate
self.t5 = T5EncoderModel.from_pretrained(text_repo_id)
self.t5.eval()
for param in self.t5.parameters():
param.requires_grad_(False)
phoneme_dim = self.t5.config.d_model
self.vts = VTS(
audio_dim=voco.latent_dim,
phoneme_dim=phoneme_dim,
dim=dim,
depth=depth,
heads=heads,
attn_dropout=attn_dropout,
ff_dropout=ff_dropout,
kernel_size=kernel_size,
)
self.mask_fracs = (0.7, 1.0)
self.min_span = 10
self.drop_prob = 0.4
self.max_audio_len = max_audio_len
self.use_torchode = use_torchode
self.torchode_method_klass = torchode_method_klass
self.steps = 64
self.sigma = 1e-5
self.method = torchdiffeq_ode_method
self.optim = optimizer
self.lr = lr
self.scheduler = scheduler
self.max_steps = max_steps
self.debug_logger = logging.getLogger("vts")
self.debug_logger.setLevel(logging.DEBUG)
try:
output_dir = HydraConfig.get().runtime.output_dir
except Exception:
output_dir = "outputs"
if not Path(output_dir).exists():
Path(output_dir).mkdir()
handler = logging.FileHandler(f"{output_dir}/vts.log")
handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
handler.setFormatter(formatter)
self.debug_logger.addHandler(handler)
def load_state_dict(self, state_dict, strict=True, assign=False):
old_prefix = "audio" + "box."
if any(key.startswith(old_prefix) for key in state_dict):
state_dict = {
f"vts.{key[len(old_prefix):]}" if key.startswith(old_prefix) else key: value
for key, value in state_dict.items()
}
return super().load_state_dict(state_dict, strict=strict, assign=assign)
def solve(
self,
context: EncTensor,
mask: EncMaskTensor,
phoneme: Tensor,
phoneme_mask: Tensor,
voice_enc: Tensor,
alpha=0.0,
) -> EncTensor:
with torch.autocast(device_type=self.device.type, enabled=False):
phoneme_emb = self.t5(
input_ids=phoneme, attention_mask=phoneme_mask
).last_hidden_state
def fn(t: Float[Tensor, "..."], y: Float[Tensor, "..."]):
out = self.vts.cfg(
w=y,
context=context,
times=t,
alpha=alpha,
mask=mask,
phoneme_emb=phoneme_emb,
phoneme_mask=phoneme_mask,
voice_enc=voice_enc
)
return out
y0 = torch.randn_like(context)
t = torch.linspace(0, 1, self.steps, device=self.device)
if self.use_torchode:
batch = context.shape[0]
t = repeat(t, "n -> b n", b=batch)
sol = solve_ivp(
torch.compile(fn, dynamic=False),
y0,
t,
method_class=self.torchode_method_klass,
)
sampled = sol.ys[-1]
else:
trajectory = odeint(
fn,
y0,
t,
atol=self.atol,
rtol=self.rtol,
method=self.method,
options=dict(step_size=1 / self.steps),
)
sampled = trajectory[-1]
return sampled
@torch.compiler.disable
@torch.no_grad
def get_span_mask(self, audio_mask: EncMaskTensor):
audio_lens = audio_mask.sum(dim=1).detach().cpu().numpy()
mask_len = audio_mask.shape[-1]
span_mask = pad_sequence(
[
torch.from_numpy(
min_span_mask(
int(audio_len),
fmin=self.mask_fracs[0],
fmax=self.mask_fracs[1],
min_span=self.min_span,
)
).to(self.device)
for audio_len in audio_lens
],
batch_first=True,
)
return F.pad(span_mask, (0, self.max_audio_len - span_mask.shape[1]))
@torch.compiler.disable
@torch.no_grad
def get_span_mask_cond(self, audio_mask: EncMaskTensor):
audio_lens = audio_mask.sum(dim=1).detach().cpu().numpy()
span_mask = pad_sequence(
[
torch.from_numpy(
min_span_mask(
int(audio_len),
fmin=0.5,
fmax=1.0,
min_span=2,
)
).to(self.device)
for audio_len in audio_lens
],
batch_first=True,
)
return F.pad(span_mask, (0, self.max_audio_len - span_mask.shape[1]))
def mask_voice_enc(self, voice_enc, mask_prob=0.1):
"""
voice_enc: [B, 400, 12] 모양의 텐서.
마지막 12 차원을 3개의 그룹(4,4,4)으로 나누고,
각 그룹에 대해 10% 확률로 전체 [400,4] 영역을 0으로 마스킹합니다.
반환:
- masked_voice_enc: 마스킹된 텐서.
- mask_flags: 배치별 각 그룹이 마스킹되었는지에 대한 boolean 텐서 ([B, 3]).
"""
B, L, D = voice_enc.shape # L=400, D=12
group_size = 4
n_groups = D // group_size # 3 그룹으로 가정
# 배치별 각 그룹에 대해 10% 확률 마스킹 여부 결정 (True면 마스킹)
mask_flags = torch.rand(B, n_groups) < mask_prob # shape: [B, 3]
# voice_enc를 복사하여 in-place 수정하지 않도록 함
masked_voice_enc = voice_enc.clone()
# 각 배치와 그룹에 대해 마스킹 수행
for b in range(B):
for g in range(n_groups):
if mask_flags[b, g]:
# 그룹 g의 인덱스 범위: g*4 ~ (g+1)*4
masked_voice_enc[b, :, g * group_size:(g + 1) * group_size] = 0.0
return masked_voice_enc
def sample(
self,
audio_enc: AudioTensor,
audio_mask: EncMaskTensor,
phoneme: Tensor,
phoneme_mask: Tensor,
voice_enc: AudioTensor,
alpha=0.0,
):
span_mask = self.get_span_mask(audio_mask)
span_mask = torch.ones_like(span_mask)
audio_context = torch.where(rearrange(span_mask, "b l -> b l ()"), 0, audio_enc)
sampled_audio_enc = self.solve(
audio_context, audio_mask, phoneme, phoneme_mask, voice_enc=voice_enc, alpha=alpha
)
return sampled_audio_enc
def forward(
self,
audio_enc: AudioTensor,
audio_mask: EncMaskTensor,
phoneme: Tensor,
phoneme_mask: Tensor,
voice_enc: AudioTensor
) -> LossTensor:
try:
batch = audio_enc.shape[0]
with torch.no_grad():
span_mask = self.get_span_mask(audio_mask)
with torch.autocast(device_type=self.device.type, enabled=False):
phoneme_emb = self.t5(
input_ids=phoneme, attention_mask=phoneme_mask
).last_hidden_state
audio_x0 = torch.randn_like(audio_enc)
times = torch.rand((batch,), dtype=audio_enc.dtype, device=self.device)
t = rearrange(times, "b -> b () ()")
w = (1 - (1 - self.sigma) * t) * audio_x0 + t * audio_enc
cond_drop_mask = prob_mask_like((batch, 1), self.drop_prob, self.device)
audio_cond_mask = span_mask | cond_drop_mask
audio_context = torch.where(
rearrange(audio_cond_mask, "b l -> b l ()"), 0, audio_enc
)
phon_drop_mask = prob_mask_like((batch,), self.drop_prob, self.device)
phoneme_emb = torch.where(
rearrange(phon_drop_mask, "b -> b () ()"), 0, phoneme_emb
)
B, T, D = voice_enc.shape
if random.random()<0.5:
try:
span_mask = self.get_span_mask_cond(audio_mask)
span_mask = rearrange(span_mask, "b l -> b l 1").expand(B, T, D)
voice_enc = torch.where(span_mask, torch.tensor(0.0, device=voice_enc.device), voice_enc)
except Exception as e:
print("\n\n\nexception ", e, "\n\n\n")
voice_enc = self.mask_voice_enc(voice_enc, mask_prob=0.2)
pred_audio_flow = self.vts(
w=w,
times=times,
audio_mask=audio_mask,
context=audio_context,
phoneme_emb=phoneme_emb,
phoneme_mask=phoneme_mask,
voice_enc=voice_enc
)
target_audio_flow = audio_enc - (1 - self.sigma) * audio_x0
loss = masked_loss(pred_audio_flow, target_audio_flow, audio_cond_mask, "mse")
# NaN 체크
if loss is None or torch.isnan(loss):
self.print(f"[step {self.global_step}] ❌ NaN loss detected, skipping step.")
dummy = pred_audio_flow.sum() * 0
return dummy
return loss
except Exception as e:
self.print(f"[step {self.global_step}] ❌ Exception in training_step: {e}")
return None # 이 step은 스킵됨
@torch.compiler.disable
def log_loss(self, id: str, loss: Tensor | float, train: bool):
if train:
self.log(id, loss, on_step=True, on_epoch=False, logger=True)
else:
self.log(id, loss, on_step=False, on_epoch=True, sync_dist=True, logger=True)
def single_step(
self, batch: tuple[Tensor, Tensor, Tensor, Tensor, Tensor], prefix: str
) -> LossTensor:
audio, audio_mask, phoneme, phoneme_mask, voice_enc = batch
loss = self(audio, audio_mask, phoneme, phoneme_mask, voice_enc=voice_enc)
train = prefix == "train"
self.log_loss(f"{prefix}/loss", loss, train)
return loss
def training_step(
self, batch: tuple[Tensor, Tensor, Tensor, Tensor, Tensor], batch_idx: int
):
return self.single_step(batch, "train")
def validation_step(
self, batch: tuple[Tensor, Tensor, Tensor, Tensor, Tensor], batch_idx: int
):
self.single_step(batch, "val")
if batch_idx < 5:
self.log_table(batch, "val", batch_idx)
# self.validate_generation(batch, batch_idx)
def validate_generation(self, batch, idx):
audios = [
'./voice_samples/piung.wav',
'./voice_samples/beepbeep.m4a',
'./voice_samples/charging.m4a',
]
voices = [
'./voice_samples/piung_voice.npy',
'./voice_samples/beepbeep_voice.npy',
'./voice_samples/charging_voice.npy',
]
texts = [
'./voice_samples/piung_token.npy',
'./voice_samples/beepbeep_token.npy',
'./voice_samples/charging_token.npy',
]
text_masks = [
'./voice_samples/piung_token_mask.npy',
'./voice_samples/beepbeep_token_mask.npy',
'./voice_samples/charging_token_mask.npy',
]
audio_len = [
int(3.7*21.5),
int(3*21.5),
int(4*21.5),
]
waveform, sr = torchaudio.load(audios[idx])
if sr != self.sampling_rate:
waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=self.sampling_rate)
waveform = waveform.mean(dim=0, keepdim=True)
dynamic_context = get_dynamic(waveform, 400).unsqueeze(dim=0).to(self.device)
audio_enc = batch[0][0].unsqueeze(dim=0).to(self.device)
audio_mask = torch.cat((
torch.ones((1, audio_len[idx])),
torch.zeros((1, self.max_audio_len - audio_len[idx]))
), dim=-1) > 0
audio_mask = audio_mask.to(self.device)
# audio_mask = batch[1][0].unsqueeze(dim=0).to(self.device) # 1, 400
text_embed = torch.from_numpy(np.load(texts[idx])).unsqueeze(dim=0).to(self.device)
text_mask = torch.from_numpy(np.load(text_masks[idx])).unsqueeze(dim=0).to(self.device)
self.log_table(
(audio_enc, audio_mask, text_embed, text_mask, dynamic_context),
'val',
idx
)
def test_step(self, batch: tuple[Tensor, Tensor, Tensor, Tensor, Tensor], batch_idx: int):
self.single_step(batch, "test")
if batch_idx < 5:
self.log_table(batch, "test", batch_idx)
@rank_zero_only
@torch.no_grad
def log_table(
self, batch: tuple[Tensor, Tensor, Tensor, Tensor, Tensor], prefix: str, batch_idx: int
):
audio_enc, audio_mask, phoneme, phoneme_mask, voice_enc = batch
random_index = torch.randint(0, audio_enc.shape[0], (1,))
random_audio_mask = audio_mask[[random_index]]
random_audio_enc = audio_enc[[random_index]]
random_phoneme = phoneme[[random_index]]
random_phoneme_mask = phoneme_mask[[random_index]]
rondom_voice_enc = voice_enc[[random_index]]
gen_audio_enc = self.sample(
audio_enc=torch.zeros_like(random_audio_enc),
audio_mask=random_audio_mask,
phoneme=random_phoneme,
phoneme_mask=random_phoneme_mask,
voice_enc=rondom_voice_enc,
alpha=1.0,
)
span_mask = self.get_span_mask(random_audio_mask)
context = torch.where(
rearrange(span_mask, "b l -> b l ()"), 0, random_audio_enc
)
pred_audio_enc = self.solve(
context=context,
mask=random_audio_mask,
phoneme=random_phoneme,
phoneme_mask=random_phoneme_mask,
voice_enc=rondom_voice_enc
)
pred_audio_enc = torch.where(
rearrange(span_mask, "b l -> b l ()"),
pred_audio_enc,
random_audio_enc,
)
self.log_data(
random_audio_enc,
pred_audio_enc,
context,
gen_audio_enc,
random_phoneme,
random_audio_mask,
prefix,
batch_idx,
rondom_voice_enc
)
@torch.compiler.disable
def log_data(
self,
random_audio_enc: Tensor,
pred_audio_enc: Tensor,
cond_audio_enc: Tensor,
gen_audio_enc: Tensor,
random_phoneme: Tensor,
random_audio_mask: Tensor,
prefix: str,
batch_idx: int,
rondom_voice_enc: Tensor
):
data: list[tuple[np.ndarray, np.ndarray]] = []
mel_voco = get_voco("mel").to(self.device)
random_audio_len = get_voco(self.voco_type).decode_length(
int(random_audio_mask.sum().item())
)
voco = get_voco(self.voco_type).to(self.device)
for enc in [random_audio_enc, pred_audio_enc, cond_audio_enc, gen_audio_enc]:
try:
audio = voco.decode(enc)
audio = audio[:, :random_audio_len]
audio = audio.float()
mel_audio = torchaudio.functional.resample(
rearrange(audio, "() l c -> () c l"),
orig_freq=self.sampling_rate,
new_freq=mel_voco.sampling_rate,
)
mel_audio = rearrange(mel_audio, "() c l -> c l ()")
mel = mel_voco.encode(mel_audio).detach().cpu().numpy()
audio = audio.detach().cpu().numpy()
audio /= np.maximum(audio.max(axis=(1, 2)), 1)
audio_numpy = (audio[0] * np.iinfo(np.int16).max).astype(np.int16)
data.append((audio_numpy, rearrange(mel, "c h w -> c w h")))
except Exception as e:
self.debug_logger.debug(
f"Error occured while plotting\n{e}\n", exc_info=True
)
if isinstance(self.logger, TensorBoardLogger):
for (audio, _), name in zip(data, ["real", "pred", "cond", "gen"]):
self.logger.experiment.add_audio(
f"{prefix}/audio/{name}",
(audio / np.iinfo(np.int16).max).mean(axis=-1),
self.global_step,
sample_rate=self.sampling_rate,
)
mel_plot = plot_with_cmap(
list(rearrange(mel, "c w h -> (c w) h") for _, mel in data)
)
self.logger.experiment.add_image(
f"{prefix}/mel",
mel_plot,
self.global_step,
dataformats="HWC",
)
elif isinstance(self.logger, MLFlowLogger):
assert self.logger.run_id is not None
step = self.global_step
audio_paths = []
image_paths = []
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
tokenizer.padding_side = "right"
caption = tokenizer.decode(random_phoneme[0], skip_special_tokens=True)
with TemporaryDirectory() as temp_dir:
Path(temp_dir, prefix).mkdir()
for (audio, mel), name in zip(data, ["real", "pred", "cond", "gen"]):
audio_path = Path(
temp_dir, prefix, f"{step:07d}_{batch_idx:03d}_{name}.flac"
)
torchaudio.save(
audio_path,
torch.from_numpy(audio / np.iinfo(np.int16).max),
self.sampling_rate,
channels_first=False,
)
self.logger.experiment.log_artifact(
self.logger.run_id, audio_path, f"{prefix}/audio"
)
audio_paths.append(audio_path)
mel_plot = plot_with_cmap([mel_item for mel_item in mel])
image_path = Path(
temp_dir, prefix, f"{step:07d}_{batch_idx:03d}_{name}.png"
)
plt.imsave(image_path, mel_plot)
self.logger.experiment.log_artifact(
self.logger.run_id, image_path, f"{prefix}/image"
)
image_paths.append(image_path)
html = write_html(audio_paths, image_paths, caption)
self.logger.experiment.log_text(
self.logger.run_id,
html,
f"{prefix}/{step:07d}_{batch_idx:03d}.html",
)
def configure_optimizers(self):
match self.optim:
case "Adam":
optimizer = Adam(self.parameters(), lr=self.lr)
case "AdamW":
optimizer = AdamW(self.parameters(), lr=self.lr, weight_decay=1e-2)
case _:
raise ValueError(f"Unknown optimizer: {self.optim}")
match self.scheduler:
case "linear_warmup_decay":
warmup_scheduler = LinearLR(
optimizer, start_factor=1 / 5000, end_factor=1.0, total_iters=5000
)
decay_scheduler = LinearLR(
optimizer,
start_factor=1.0,
end_factor=0.0,
total_iters=self.max_steps,
)
scheduler = SequentialLR(
optimizer,
schedulers=[warmup_scheduler, decay_scheduler],
milestones=[5000],
)
case _:
raise ValueError(f"Unknown scheduler: {self.scheduler}")
return OptimizerLRSchedulerConfig(
optimizer=optimizer,
lr_scheduler=LRSchedulerConfigType(
scheduler=scheduler,
interval="step",
frequency=1,
reduce_on_plateau=False,
strict=True,
),
)
def on_save_checkpoint(self, checkpoint: dict[str, dict[str, Any]]):
for key in list(checkpoint["state_dict"]):
if key.startswith("t5"):
del checkpoint["state_dict"][key]
def on_load_checkpoint(self, checkpoint: dict[str, dict[str, Any]]):
for name, param in self.t5.named_parameters():
checkpoint["state_dict"][f"t5.{name}"] = param.data.clone()
checkpoint["state_dict"]["t5.encoder.embed_tokens.weight"] = (
self.t5.encoder.embed_tokens.weight.data.clone()
)