VTS/vts/model/module.py
2026-06-12 23:35:56 +09:00

447 lines
15 KiB
Python

import logging
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any
import numpy as np
import torch
import torchaudio
from einops import rearrange, repeat
from hydra.core.hydra_config import HydraConfig
from jaxtyping import Float
from lightning.pytorch import LightningModule
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.utilities import rank_zero_only
from lightning.pytorch.utilities.types import (
LRSchedulerConfigType,
OptimizerLRSchedulerConfig,
)
from matplotlib import pyplot as plt
from torch import Tensor
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.optim.adam import Adam
from torch.optim.adamw import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR
from torchdiffeq import odeint
from transformers import AutoTokenizer, T5EncoderModel
from vocos_custom import get_voco
from model.vts import VTS
from model.loss import masked_loss
from torchode.interface import solve_ivp
from utils.mask import min_span_mask, prob_mask_like
from utils.typing import AudioTensor, EncMaskTensor, EncTensor, LossTensor
from utils.utils import plot_with_cmap, write_html
class VTSModule(LightningModule):
def __init__(
self,
dim: int,
depth: int,
heads: int,
attn_dropout: float,
ff_dropout: float,
kernel_size: int,
voco_type: str,
max_audio_len: int,
optimizer: str = "Adam",
lr: float = 1e-4,
scheduler: str = "linear_warmup_decay",
use_torchode=True,
torchdiffeq_ode_method="midpoint",
torchode_method_klass="tsit5",
max_steps: int = 1000000,
text_repo_id: str = "google/flan-t5-base",
):
super().__init__()
self.save_hyperparameters()
self.voco_type = voco_type
voco = get_voco(voco_type)
self.sampling_rate = voco.sampling_rate
self.t5 = T5EncoderModel.from_pretrained(text_repo_id)
self.t5.eval()
for param in self.t5.parameters():
param.requires_grad_(False)
phoneme_dim = self.t5.config.d_model
self.vts = VTS(
audio_dim=voco.latent_dim,
phoneme_dim=phoneme_dim,
dim=dim,
depth=depth,
heads=heads,
attn_dropout=attn_dropout,
ff_dropout=ff_dropout,
kernel_size=kernel_size,
)
self.mask_fracs = (0.7, 1.0)
self.min_span = 10
self.drop_prob = 0.1
self.max_audio_len = max_audio_len
self.use_torchode = use_torchode
self.torchode_method_klass = torchode_method_klass
self.steps = 64
self.sigma = 1e-5
self.method = torchdiffeq_ode_method
self.optim = optimizer
self.lr = lr
self.scheduler = scheduler
self.max_steps = max_steps
self.debug_logger = logging.getLogger("vts")
self.debug_logger.setLevel(logging.DEBUG)
try:
output_dir = HydraConfig.get().runtime.output_dir
except Exception:
output_dir = "outputs"
if not Path(output_dir).exists():
Path(output_dir).mkdir()
handler = logging.FileHandler(f"{output_dir}/vts.log")
handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
handler.setFormatter(formatter)
self.debug_logger.addHandler(handler)
def load_state_dict(self, state_dict, strict=True, assign=False):
old_prefix = "audio" + "box."
if any(key.startswith(old_prefix) for key in state_dict):
state_dict = {
f"vts.{key[len(old_prefix):]}" if key.startswith(old_prefix) else key: value
for key, value in state_dict.items()
}
return super().load_state_dict(state_dict, strict=strict, assign=assign)
def solve(
self,
context: EncTensor,
mask: EncMaskTensor,
phoneme: Tensor,
phoneme_mask: Tensor,
alpha=0.0,
) -> EncTensor:
with torch.autocast(device_type=self.device.type, enabled=False):
phoneme_emb = self.t5(
input_ids=phoneme, attention_mask=phoneme_mask
).last_hidden_state
def fn(t: Float[Tensor, "..."], y: Float[Tensor, "..."]):
out = self.vts.cfg(
w=y,
context=context,
times=t,
alpha=alpha,
mask=mask,
phoneme_emb=phoneme_emb,
phoneme_mask=phoneme_mask,
)
return out
y0 = torch.randn_like(context)
t = torch.linspace(0, 1, self.steps, device=self.device)
if self.use_torchode:
batch = context.shape[0]
t = repeat(t, "n -> b n", b=batch)
sol = solve_ivp(
torch.compile(fn, dynamic=False),
y0,
t,
method_class=self.torchode_method_klass,
)
sampled = sol.ys[-1]
else:
trajectory = odeint(
fn,
y0,
t,
atol=self.atol,
rtol=self.rtol,
method=self.method,
options=dict(step_size=1 / self.steps),
)
sampled = trajectory[-1]
return sampled
@torch.compiler.disable
@torch.no_grad
def get_span_mask(self, audio_mask: EncMaskTensor):
audio_lens = audio_mask.sum(dim=1).detach().cpu().numpy()
span_mask = pad_sequence(
[
torch.from_numpy(
min_span_mask(
int(audio_len),
fmin=self.mask_fracs[0],
fmax=self.mask_fracs[1],
min_span=self.min_span,
)
).to(self.device)
for audio_len in audio_lens
],
batch_first=True,
)
return F.pad(span_mask, (0, self.max_audio_len - span_mask.shape[1]))
def sample(
self,
audio_enc: AudioTensor,
audio_mask: EncMaskTensor,
phoneme: Tensor,
phoneme_mask: Tensor,
alpha=0.0,
):
span_mask = self.get_span_mask(audio_mask)
span_mask = torch.ones_like(span_mask)
audio_context = torch.where(rearrange(span_mask, "b l -> b l ()"), 0, audio_enc)
sampled_audio_enc = self.solve(
audio_context, audio_mask, phoneme, phoneme_mask, alpha=alpha
)
return sampled_audio_enc
def forward(
self,
audio_enc: AudioTensor,
audio_mask: EncMaskTensor,
phoneme: Tensor,
phoneme_mask: Tensor,
) -> LossTensor:
batch = audio_enc.shape[0]
with torch.no_grad():
span_mask = self.get_span_mask(audio_mask)
with torch.autocast(device_type=self.device.type, enabled=False):
phoneme_emb = self.t5(
input_ids=phoneme, attention_mask=phoneme_mask
).last_hidden_state
audio_x0 = torch.randn_like(audio_enc)
times = torch.rand((batch,), dtype=audio_enc.dtype, device=self.device)
t = rearrange(times, "b -> b () ()")
w = (1 - (1 - self.sigma) * t) * audio_x0 + t * audio_enc
cond_drop_mask = prob_mask_like((batch, 1), self.drop_prob, self.device)
audio_cond_mask = span_mask | cond_drop_mask
audio_context = torch.where(
rearrange(audio_cond_mask, "b l -> b l ()"), 0, audio_enc
)
phon_drop_mask = prob_mask_like((batch,), self.drop_prob, self.device)
phoneme_emb = torch.where(
rearrange(phon_drop_mask, "b -> b () ()"), 0, phoneme_emb
)
pred_audio_flow = self.vts(
w=w,
times=times,
audio_mask=audio_mask,
context=audio_context,
phoneme_emb=phoneme_emb,
phoneme_mask=phoneme_mask,
)
target_audio_flow = audio_enc - (1 - self.sigma) * audio_x0
loss = masked_loss(pred_audio_flow, target_audio_flow, audio_cond_mask, "mse")
return loss
@torch.compiler.disable
def log_loss(self, id: str, loss: Tensor | float, train: bool):
if train:
self.log(id, loss, on_step=True, on_epoch=False)
else:
self.log(id, loss, on_step=False, on_epoch=True, sync_dist=True)
def single_step(
self, batch: tuple[Tensor, Tensor, Tensor, Tensor], prefix: str
) -> LossTensor:
audio, audio_mask, phoneme, phoneme_mask = batch
loss = self(audio, audio_mask, phoneme, phoneme_mask)
train = prefix == "train"
self.log_loss(f"{prefix}/loss", loss, train)
return loss
def training_step(
self, batch: tuple[Tensor, Tensor, Tensor, Tensor], batch_idx: int
):
return self.single_step(batch, "train")
def validation_step(
self, batch: tuple[Tensor, Tensor, Tensor, Tensor], batch_idx: int
):
self.single_step(batch, "val")
if batch_idx < 5:
self.log_table(batch, "val", batch_idx)
def test_step(self, batch: tuple[Tensor, Tensor, Tensor, Tensor], batch_idx: int):
self.single_step(batch, "test")
if batch_idx < 5:
self.log_table(batch, "test", batch_idx)
@rank_zero_only
@torch.no_grad
def log_table(
self, batch: tuple[Tensor, Tensor, Tensor, Tensor], prefix: str, batch_idx: int
):
audio_enc, audio_mask, phoneme, phoneme_mask = batch
random_index = torch.randint(0, audio_enc.shape[0], (1,))
random_audio_mask = audio_mask[[random_index]]
random_audio_enc = audio_enc[[random_index]]
random_phoneme = phoneme[[random_index]]
random_phoneme_mask = phoneme_mask[[random_index]]
gen_audio_enc = self.sample(
audio_enc=torch.zeros_like(random_audio_enc),
audio_mask=random_audio_mask,
phoneme=random_phoneme,
phoneme_mask=random_phoneme_mask,
alpha=1.0,
)
span_mask = self.get_span_mask(random_audio_mask)
context = torch.where(
rearrange(span_mask, "b l -> b l ()"), 0, random_audio_enc
)
pred_audio_enc = self.solve(
context=context,
mask=random_audio_mask,
phoneme=random_phoneme,
phoneme_mask=random_phoneme_mask,
)
pred_audio_enc = torch.where(
rearrange(span_mask, "b l -> b l ()"),
pred_audio_enc,
random_audio_enc,
)
self.log_data(
random_audio_enc,
pred_audio_enc,
context,
gen_audio_enc,
random_phoneme,
random_audio_mask,
prefix,
batch_idx,
)
@torch.compiler.disable
def log_data(
self,
random_audio_enc: Tensor,
pred_audio_enc: Tensor,
cond_audio_enc: Tensor,
gen_audio_enc: Tensor,
random_phoneme: Tensor,
random_audio_mask: Tensor,
prefix: str,
batch_idx: int,
):
data: list[tuple[np.ndarray, np.ndarray]] = []
mel_voco = get_voco("mel").to(self.device)
random_audio_len = get_voco(self.voco_type).decode_length(
int(random_audio_mask.sum().item())
)
voco = get_voco(self.voco_type).to(self.device)
for enc in [random_audio_enc, pred_audio_enc, cond_audio_enc, gen_audio_enc]:
try:
audio = voco.decode(enc)
audio = audio[:, :random_audio_len]
audio = audio.float()
mel_audio = torchaudio.functional.resample(
rearrange(audio, "() l c -> () c l"),
orig_freq=self.sampling_rate,
new_freq=mel_voco.sampling_rate,
)
mel_audio = rearrange(mel_audio, "() c l -> c l ()")
mel = mel_voco.encode(mel_audio).detach().cpu().numpy()
audio = audio.detach().cpu().numpy()
audio /= np.maximum(audio.max(axis=(1, 2)), 1)
audio_numpy = (audio[0] * np.iinfo(np.int16).max).astype(np.int16)
data.append((audio_numpy, rearrange(mel, "c h w -> c w h")))
except Exception as e:
self.debug_logger.debug(
f"Error occured while plotting\n{e}\n", exc_info=True
)
if isinstance(self.logger, TensorBoardLogger):
for (audio, _), name in zip(data, ["real", "pred", "cond", "gen"]):
self.logger.experiment.add_audio(
f"{prefix}/audio/{name}",
(audio / np.iinfo(np.int16).max).mean(axis=-1),
self.global_step,
sample_rate=self.sampling_rate,
)
mel_plot = plot_with_cmap(
list(rearrange(mel, "c w h -> (c w) h") for _, mel in data)
)
self.logger.experiment.add_image(
f"{prefix}/mel",
mel_plot,
self.global_step,
dataformats="HWC",
)
def configure_optimizers(self):
match self.optim:
case "Adam":
optimizer = Adam(self.parameters(), lr=self.lr)
case "AdamW":
optimizer = AdamW(self.parameters(), lr=self.lr, weight_decay=1e-2)
case _:
raise ValueError(f"Unknown optimizer: {self.optim}")
match self.scheduler:
case "linear_warmup_decay":
warmup_scheduler = LinearLR(
optimizer, start_factor=1 / 5000, end_factor=1.0, total_iters=5000
)
decay_scheduler = LinearLR(
optimizer,
start_factor=1.0,
end_factor=0.0,
total_iters=self.max_steps,
)
scheduler = SequentialLR(
optimizer,
schedulers=[warmup_scheduler, decay_scheduler],
milestones=[5000],
)
case _:
raise ValueError(f"Unknown scheduler: {self.scheduler}")
return OptimizerLRSchedulerConfig(
optimizer=optimizer,
lr_scheduler=LRSchedulerConfigType(
scheduler=scheduler,
interval="step",
frequency=1,
reduce_on_plateau=False,
strict=True,
),
)
def on_save_checkpoint(self, checkpoint: dict[str, dict[str, Any]]):
for key in list(checkpoint["state_dict"]):
if key.startswith("t5"):
del checkpoint["state_dict"][key]
def on_load_checkpoint(self, checkpoint: dict[str, dict[str, Any]]):
for name, param in self.t5.named_parameters():
checkpoint["state_dict"][f"t5.{name}"] = param.data.clone()
checkpoint["state_dict"]["t5.encoder.embed_tokens.weight"] = (
self.t5.encoder.embed_tokens.weight.data.clone()
)