VTS/vts/utils/utils.py
2026-06-12 23:35:56 +09:00

448 lines
No EOL
15 KiB
Python

import base64
import json
import subprocess
from pathlib import Path
import torchaudio
import librosa
import random
import numpy as np
import torch
from einops import rearrange
from matplotlib import pyplot as plt
from torch import Tensor
from librosa import filters
from torch.nn import functional as F
from utils.mask import mask_from_lengths
from scipy.ndimage import median_filter
from scipy.interpolate import interp1d
from utils.typing import EncTensor, LengthTensor
from scipy.signal import medfilt
plt.switch_backend("agg")
def pad_sequence(
sequences: list[Tensor], batch_first: bool = False, padding_value: int = 0
) -> Tensor:
"""
Pad a list of variable length Tensors with zero padding to the right.
Return a Tensor of shape (batch, max_time, channel) if batch_first is True,
else (max_time, batch, channel).
The original pad_sequence function from PyTorch errors when compiling.
Args:
sequences: List of variable length Tensors.
batch_first: If True, return Tensor of shape (batch, max_time, channel).
padding_value: Value to pad with.
Returns:
Padded Tensor.
"""
max_len = max([seq.size(0) for seq in sequences])
if batch_first:
out_dims = (len(sequences), max_len) + sequences[0].size()[1:]
else:
out_dims = (max_len, len(sequences)) + sequences[0].size()[1:]
out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
for i, tensor in enumerate(sequences):
length = tensor.size(0)
if batch_first:
out_tensor[i, :length, ...] = tensor
else:
out_tensor[:length, i, ...] = tensor
return out_tensor
def plot_with_cmap(mels: list[np.ndarray], sharex: bool = True):
fig, axes = plt.subplots(len(mels), 1, figsize=(20, 8), sharex=sharex)
if len(mels) == 1:
axes = np.array([axes])
im = None
for i, mel in enumerate(mels):
im = axes[i].imshow(mel, aspect="auto", origin="lower", interpolation="none")
fig.colorbar(im, ax=axes.ravel().tolist())
fig.canvas.draw()
plt.close(fig)
return np.array(
fig.canvas.buffer_rgba() # pyright: ignore [reportAttributeAccessIssue]
)
def normalize_audio(
audio_enc: EncTensor, audio_lens: LengthTensor
) -> tuple[EncTensor, Tensor, Tensor]:
"""
Normalize audio encodings to have zero mean and unit variance.
Each audio encoding is 2-dimensional and has zero padding to the right.
Return normalized audio encodings, mean, and standard deviation.
Args:
audio_enc: Audio encodings. Shape: (batch, time, channel).
audio_lens: Lengths of audio encodings. Shape: (batch,).
Returns:
audio_enc: Normalized audio encodings. Shape: (batch, time, channel).
audio_mean: Mean of audio encodings. Shape: (batch,).
audio_std: Standard deviation of audio encodings. Shape: (batch,).
"""
audio_mask = mask_from_lengths(audio_lens, audio_enc.shape[1])
# audio_mean = (audio_enc.mean(dim=2) * audio_mask).sum(dim=1) / audio_lens
# audio_sq_mean = ((audio_enc**2).mean(dim=2) * audio_mask).sum(dim=1) / audio_lens
# nelem = audio_lens * audio_enc.shape[2]
# bessel_correction = nelem / (nelem - 1)
# audio_std = torch.sqrt((audio_sq_mean - audio_mean**2) * bessel_correction)
batch_size = audio_enc.shape[0]
audio_mean = torch.full((batch_size,), -1.430645).to(audio_enc.device)
audio_std = torch.full((batch_size,), 2.1208718).to(audio_enc.device)
audio_mean = rearrange(audio_mean, "b -> b () ()")
audio_std = rearrange(audio_std, "b -> b () ()")
normalized_audio_enc = (
(audio_enc - audio_mean)
/ (audio_std + 1e-5)
* rearrange(audio_mask, "b l -> b l ()")
)
return normalized_audio_enc, audio_mean, audio_std
def write_html(audio_paths: list[Path], image_paths: list[Path], description: str):
html = f"""
<html>
<head>
<title>Audio and Mel Preview</title>
<!-- Lightbox2 CSS -->
<link href="https://cdnjs.cloudflare.com/ajax/libs/lightbox2/2.11.3/css/lightbox.min.css" rel="stylesheet" />
<style>
body {{
font-family: Arial, sans-serif;
background-color: #f9f9f9;
margin: 0;
padding: 0;
}}
.container {{
/* Removed max-width to use full screen width */
margin: 0 auto;
padding: 20px;
}}
.description {{
background-color: #fff;
padding: 20px;
margin-bottom: 20px;
border-radius: 5px;
box-shadow: 0 2px 5px rgba(0,0,0,0.1);
max-width: 1000px;
margin-left: auto;
margin-right: auto;
}}
.grid {{
display: grid;
grid-template-columns: repeat(2, 1fr); /* Set to 2 columns */
grid-gap: 20px;
}}
.card {{
background-color: #fff;
border-radius: 5px;
box-shadow: 0 2px 5px rgba(0,0,0,0.1);
padding: 20px;
text-align: center;
}}
.card h3 {{
margin-top: 0;
text-transform: capitalize;
}}
audio {{
width: 100%;
margin: 10px 0;
}}
img {{
width: 100%;
height: auto;
border-radius: 5px;
cursor: pointer;
transition: transform 0.2s;
}}
img:hover {{
transform: scale(1.02);
}}
@media (max-width: 800px) {{
.grid {{
grid-template-columns: 1fr; /* Stack cards on small screens */
}}
}}
</style>
</head>
<body>
<div class="container">
<div class="description">
<h2>Description</h2>
<p>{description}</p>
</div>
<div class="grid">
"""
names = ["real", "pred", "cond", "gen"]
for row_name, audio_path, image_path in zip(names, audio_paths, image_paths):
with open(audio_path, "rb") as f:
audio_base64 = base64.b64encode(f.read()).decode("utf-8")
with open(image_path, "rb") as f:
image_base64 = base64.b64encode(f.read()).decode("utf-8")
html += f"""
<div class="card">
<h3>{row_name}</h3>
<audio controls>
<source src="data:audio/flac;base64,{audio_base64}" type="audio/flac">
Your browser does not support the audio element.
</audio>
<a href="data:image/png;base64,{image_base64}" data-lightbox="mel-spectrograms" data-title="{row_name} Mel Spectrogram">
<img src="data:image/png;base64,{image_base64}" alt="{row_name} Mel Spectrogram">
</a>
</div>
"""
html += """
</div>
</div>
<!-- Lightbox2 JS -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/lightbox2/2.11.3/js/lightbox-plus-jquery.min.js"></script>
</body>
</html>
"""
return html
def get_audio_info(file: str) -> tuple:
command = [
"ffprobe",
"-v",
"error",
"-select_streams",
"a", # Select only audio streams
"-show_entries",
"stream=channels", # Get channel count from the stream
"-show_entries",
"format=duration", # Get the duration from the format section
"-of",
"json",
file,
]
# Run ffprobe command
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# Parse the JSON output
output = json.loads(result.stdout)
# Extract channels and duration
channels = int(output["streams"][0]["channels"])
duration = float(output["format"]["duration"])
return channels, duration
def extract_audio_segment(
file: str, start_time: float, dur: float, sr: int, num_channels: int
) -> np.ndarray:
# Define the ffmpeg command to output raw PCM data
command = [
"ffmpeg",
"-ss",
str(start_time),
"-i",
file,
"-t",
str(dur),
"-f",
"s16le",
"-ac",
str(num_channels),
"-ar",
str(sr),
"-loglevel",
"error",
"pipe:1",
]
# Run the command and capture the output
process = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# Convert the raw PCM data to a NumPy array
wav_data = np.frombuffer(process.stdout, dtype=np.int16).copy()
# Reshape based on the number of channels
wav_data = wav_data.reshape(-1, num_channels)
return wav_data
def blur_latent(latent: torch.Tensor, is_augmentation: bool = False):
def noise_audio(latent, corrupt=0.6):
c = 1.0 - corrupt
noised_enc = (latent * c) + torch.randn_like(latent) * (1 - (1 - 1e-4) * c)
return noised_enc
def shift_tensor(tensor: torch.Tensor, shift: int):
"""
shift > 0: 뒤로 밀기 (앞에 zero padding 추가)
shift < 0: 앞으로 밀기 (뒤에 zero padding 추가)
shift = 0: 변화 없음
"""
tensor = tensor.unsqueeze(dim=0)
B, T, D = tensor.shape
if shift == 0:
return tensor
if shift > 0:
pad = (0, 0, shift, 0)
padded = F.pad(tensor, pad, mode='constant', value=0)
return padded[:, :T, :] # 앞에 pad 추가된 만큼 자름
else:
shift = -shift # 뒤에 pad 추가, 앞에서 자름
pad = (0, 0, 0, shift)
padded = F.pad(tensor, pad, mode='constant', value=0)
return padded[:, shift:, :]
voice_cond = torch.from_numpy(latent.copy()).unsqueeze(dim=0)
voice_cond = voice_cond.permute(0, 2, 1) # (B, 64, 48)
voice_cond = F.avg_pool1d(voice_cond, kernel_size=2, stride=2) # (B, 64, 24)
voice_cond = F.interpolate(voice_cond, scale_factor=2, mode='nearest') # (B, 64, 48)
voice_cond = voice_cond.permute(0, 2, 1).squeeze()
voice_cond = noise_audio(voice_cond)
if is_augmentation:
shift = random.randint(-3, 3)
if random.random()<0.5:
shift = 0
voice_cond = shift_tensor(voice_cond, shift=shift)
return voice_cond
import torch
import torchaudio
import numpy as np
from scipy.signal import medfilt
from torchaudio.functional import spectral_centroid
N_CHROMA = 24
RADIX2_EXP = 14
WIN_LENGTH = 2 ** RADIX2_EXP
SAMPLE_RATE = 44100
TARGET_FRAMERATE = 21.5
HOP_LENGTH = int(SAMPLE_RATE / TARGET_FRAMERATE)
N_FFT = 2048
# Chroma filterbank
chroma_filterbank = torch.from_numpy(
filters.chroma(sr=SAMPLE_RATE, n_fft=N_FFT, tuning=0, n_chroma=N_CHROMA)
).float()
# Spectrogram transform
spec_transform = torchaudio.transforms.Spectrogram(
n_fft=N_FFT,
hop_length=HOP_LENGTH,
power=2,
center=True,
pad=0,
normalized=True,
)
def min_max_normalize(tensor: torch.Tensor) -> torch.Tensor:
return (tensor - tensor.min()) / (tensor.max() - tensor.min() + 1e-6)
def compute_centroid(waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
spec = spec_transform(waveform)
freqs = torch.linspace(0, sample_rate // 2, spec.shape[1], device=spec.device)
centroid = (spec * freqs[None, :, None]).sum(dim=1) / (spec.sum(dim=1) + 1e-6)
midi_like = 69 + 12 * torch.log2(centroid / 440.0 + 1e-6)
return (midi_like / 127.0).clamp(0, 1)
def add_noise(tensor: torch.Tensor, std: float = 0.005) -> torch.Tensor:
noise = torch.randn_like(tensor) * std
return (tensor + noise).clamp(0.0, 1.0) # 정규화 유지
def get_dynamic(waveform: torch.Tensor, max_len: int) -> torch.Tensor:
if waveform.ndim == 3:
waveform = waveform.mean(dim=1)
elif waveform.ndim == 2 and waveform.size(0) == 2:
waveform = waveform.mean(dim=0)
spec = spec_transform(waveform) # [B, F, T]
chroma = torch.einsum('cf,...ft->...ct', chroma_filterbank, spec)
chroma_normed = torch.nn.functional.normalize(chroma, p=float('inf'), dim=-2, eps=1e-6)
# max_indices = chroma_normed.argmax(dim=-2, keepdim=True) # (B, 1, T)
# mask = torch.zeros_like(chroma_normed)
# mask.scatter_(-2, max_indices, 1.0) # set 1 where chroma bin is max
# chroma_maxonly = chroma_normed * mask
chroma_indices = chroma_normed.argmax(dim=-2, keepdim=True) # (B, 1, T)
chroma_maxonly = chroma_indices.expand(-1, 4, -1) # (B, T, 1)
# mask = (chroma_normed >= 0.8).float()
# chroma_maxonly = chroma_normed * mask
# RMS 계산 (PyTorch 버전)
frame_size = HOP_LENGTH
waveform_mono = waveform.mean(dim=0) if waveform.dim() == 2 else waveform
rms_frames = waveform_mono.unfold(0, frame_size, frame_size)
rms = torch.sqrt((rms_frames ** 2).mean(dim=1)) # [T_rms]
rms = rms.clamp(min=1e-8)
# RMS downsample to chroma time resolution
T_chroma = chroma.size(-1)
rms_down = torch.nn.functional.interpolate(rms.unsqueeze(0).unsqueeze(0), size=T_chroma, mode='linear', align_corners=False)
expanded_rms = rms_down.squeeze(0).expand(4, -1).unsqueeze(0) # [1, 4, T]
expanded_rms = min_max_normalize(expanded_rms)
centroid = compute_centroid(waveform, SAMPLE_RATE)[0].unsqueeze(0).expand(4, -1).unsqueeze(0) # [1, 4, T]
# add noise
if random.random() < 0.5:
expanded_rms = add_noise(expanded_rms)
if random.random() < 0.5:
centroid = add_noise(centroid)
combined = torch.cat((centroid, expanded_rms, chroma_maxonly), dim=-2) # [1, D, T]
combined = combined.permute(0, 2, 1) # [1, T, D]
if combined.shape[1] < max_len:
pad_len = max_len - combined.shape[1]
pad = torch.zeros((1, pad_len, combined.shape[2]), device=combined.device)
combined = torch.cat([combined, pad], dim=1)
else:
combined = combined[:, :max_len, :]
return combined.squeeze()
def span_mask_strided(tensor, span_len=6, stride=16):
"""
일정한 간격(stride)마다 일정 길이(span_len)로 마스크를 적용하는 함수
Args:
tensor (Tensor): 입력 텐서, shape (1, T, D)
span_len (int): 마스크될 구간 길이
stride (int): 마스크가 시작되는 위치 간격
Returns:
masked_tensor: 마스킹이 적용된 텐서
"""
B, T, D = tensor.shape
assert B == 1, "이 코드는 배치 크기 1 기준입니다."
mask = torch.zeros(T, dtype=torch.bool)
for start in range(0, T, stride):
end = min(start + span_len, T)
mask[start:end] = True
# [1, T, D]로 broadcast
mask = mask[None, :, None].expand(B, T, D)
masked_tensor = tensor.masked_fill(mask, 0.0)
return masked_tensor