VTS/vts/utils/typing.py
2026-06-12 23:35:56 +09:00

11 lines
408 B
Python

from jaxtyping import Bool, Float, Int
from torch import Tensor
# Shape/dtype-annotated aliases for torch tensors, built with jaxtyping.
# Axes are named by the strings below; where a name repeats across aliases
# (e.g. "batch"), jaxtyping's shape checking -- when enabled -- requires those
# axes to agree in size within a single checked call.

# --- Waveform-domain tensors -------------------------------------------------
# Float, (batch, audio, audio_channel): channel-last layout.
# NOTE(review): "audio" is presumably the per-sample time axis -- confirm.
AudioTensor = Float[Tensor, "batch audio audio_channel"]
# Bool, (batch, audio): one flag per position on the "audio" axis.
# NOTE(review): presumably True marks valid (non-padded) samples -- confirm.
AudioMaskTensor = Bool[Tensor, "batch audio"]

# --- Codec/encoder-domain tensors --------------------------------------------
# Float, (batch, codec, channel).
# NOTE(review): "codec" is presumably the encoded-frame axis -- confirm.
EncTensor = Float[Tensor, "batch codec channel"]
# Bool, (batch, codec): one flag per position on the "codec" axis.
EncMaskTensor = Bool[Tensor, "batch codec"]

# --- Per-example tensors: a single axis of size "batch" ----------------------
# Integer-valued (jaxtyping's Int accepts any integer dtype).
LengthTensor = Int[Tensor, "batch"]
# Float-valued, one entry per example.
TimeTensor = Float[Tensor, "batch"]

# --- Scalars -----------------------------------------------------------------
# Zero-dimensional float tensor (presumably a reduced loss value).
LossTensor = Float[Tensor, ""]

# A batch as an (audio, audio-mask) pair; defined last because it refers to
# the two aliases above.
Batch = tuple[AudioTensor, AudioMaskTensor]