mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Whisper audio helpers (mel filters in tinygrad) (#13478)
* add whisper audio helpers for stft/mel/resample * cleanup * add whisper stft test * make only stft test explicitly depend on librosa * extract sinc_window_kernel * dehardcode device * use same device argument * simplify * type annotate * ruff format audio_helpers.py * ruff format test_whisper.py * add WHISPER_NEW_STFT * rename * undo ruff format changes * use new stft and mel for whisper * remove stft test that depends on librosa * remove whitespace * add Tensor.log10 with test\test_ops.py::TestOps::test_log10 * use Tensor.log10 * fix lint * future: remove unused STFT class * future: remove resample code since it isn't used (yet) * match openai with pad_mode="reflect" * pad_to * future: cut resample leftovers * cleanup * add mel tests * future: cut stft * future: cut non-mel prep_audio changes * reduce diff * move audio_helpers.py to examples * reduce whitespace * fix imports * reduce whitespace --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
parent
dc82856084
commit
26f8b12e01
3 changed files with 104 additions and 2 deletions
79
examples/audio_helpers.py
Normal file
79
examples/audio_helpers.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
from typing import Optional
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.dtype import DTypeLike, dtypes
|
||||
import math
|
||||
|
||||
# rewritten from numpy
|
||||
def rfftfreq(n: int, d: float = 1.0, device=None) -> Tensor:
|
||||
val = 1.0 / (n * d)
|
||||
N = n // 2 + 1
|
||||
results = Tensor.arange(N, device=device)
|
||||
return results * val
|
||||
|
||||
# just like in librosa
|
||||
def fft_frequencies(sr: float, n_fft: int) -> Tensor:
|
||||
return rfftfreq(n=n_fft, d=1.0 / sr)
|
||||
|
||||
def hz_to_mel(freq: Tensor) -> Tensor:
|
||||
# linear part
|
||||
f_min = 0.0
|
||||
f_sp = 200.0 / 3
|
||||
mels = (freq - f_min) / f_sp
|
||||
|
||||
# log-scale part
|
||||
min_log_hz = 1000.0 # beginning of log region (Hz)
|
||||
mask = freq >= min_log_hz
|
||||
return mask.where(((min_log_hz - f_min) / f_sp) + (freq / min_log_hz).log() / (math.log(6.4) / 27.0), mels)
|
||||
|
||||
def mel_to_hz(mels: Tensor) -> Tensor:
|
||||
# linear scale
|
||||
f_min = 0.0
|
||||
f_sp = 200.0 / 3
|
||||
freqs = f_min + f_sp * mels
|
||||
|
||||
# nonlinear scale
|
||||
min_log_hz = 1000.0 # beginning of log region (Hz)
|
||||
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
|
||||
logstep = math.log(6.4) / 27.0 # step size for log region
|
||||
|
||||
log_t = mels >= min_log_mel
|
||||
freqs = log_t.where(min_log_hz * ((logstep * (mels - min_log_mel)).exp()), freqs)
|
||||
return freqs
|
||||
|
||||
def mel_frequencies(n_mels: int = 128, *, fmin: float = 0.0, fmax: float = 11025.0) -> Tensor:
|
||||
# center freqs of mel bands - uniformly spaced between limits
|
||||
min_max_mel = hz_to_mel(Tensor([fmin, fmax]))
|
||||
|
||||
mels = Tensor.linspace(min_max_mel[0], min_max_mel[1], n_mels)
|
||||
hz = mel_to_hz(mels)
|
||||
return hz
|
||||
|
||||
def mel(
|
||||
*,
|
||||
sr: float,
|
||||
n_fft: int,
|
||||
n_mels: int = 128,
|
||||
fmin: float = 0.0,
|
||||
fmax: Optional[float] = None,
|
||||
dtype: DTypeLike = dtypes.default_float,
|
||||
) -> Tensor:
|
||||
if fmax is None:
|
||||
fmax = float(sr) / 2
|
||||
|
||||
n_mels = int(n_mels)
|
||||
|
||||
fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft) # center freqs of each FFT bin
|
||||
mel_f = mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax) # center freqs of mel bands
|
||||
|
||||
fdiff = mel_f[1:] - mel_f[:-1]
|
||||
ramps = mel_f[None].T.expand(-1, fftfreqs.shape[-1]) - fftfreqs
|
||||
|
||||
lower = -ramps[:n_mels] / fdiff[:n_mels][None].T
|
||||
upper = ramps[2 : n_mels + 2] / fdiff[1 : n_mels + 1][None].T
|
||||
weights = lower.minimum(upper).maximum(0)
|
||||
|
||||
# Slaney-style mel is scaled to be approx constant energy per channel
|
||||
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
|
||||
weights *= enorm[:, None]
|
||||
|
||||
return weights
|
||||
|
|
@ -7,6 +7,7 @@ from tinygrad import Tensor, TinyJit, Variable, nn, dtypes
|
|||
from tinygrad.nn.state import torch_load, load_state_dict
|
||||
from tinygrad.helpers import getenv, fetch
|
||||
|
||||
from examples.audio_helpers import mel
|
||||
import numpy as np
|
||||
import librosa
|
||||
|
||||
|
|
@ -159,7 +160,7 @@ def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate=False) ->
|
|||
|
||||
stft = librosa.stft(waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.csingle)
|
||||
magnitudes = np.absolute(stft[..., :-1]) ** 2
|
||||
mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes
|
||||
mel_spec = mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS).numpy() @ magnitudes
|
||||
|
||||
log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
|
||||
log_spec = np.maximum(log_spec, log_spec.max((1,2), keepdims=True) - 8.0)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
import unittest
|
||||
import pathlib
|
||||
from examples.whisper import init_whisper, load_file_waveform, transcribe_file, transcribe_waveform
|
||||
from examples.audio_helpers import mel
|
||||
import examples.mlperf.metrics as metrics
|
||||
from tinygrad.helpers import fetch
|
||||
from test.helpers import slow
|
||||
from tinygrad import Device, dtypes
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.device import is_dtype_supported
|
||||
import numpy as np
|
||||
|
||||
# Audio generated with the command on MacOS:
|
||||
# say "Could you please let me out of the box?" --file-format=WAVE --data-format=LEUI8@16000 -o test
|
||||
|
|
@ -130,5 +132,25 @@ class TestWhisper(unittest.TestCase):
|
|||
reference = TRANSCRIPTION_3
|
||||
self.assertWER(reference[:len(reference)//2], reference, 0.524)
|
||||
|
||||
def test_mel_filters(self):
|
||||
# reference = librosa.filters.mel(sr=16000, n_fft=16, n_mels=16)
|
||||
reference = Tensor([[-0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0021111054811626673, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.003133024089038372, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0017568661132827401, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0009823603322729468, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0007768510840833187, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0010490329004824162, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0011341988574713469, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.000231665835599415, 0.0006950111710466444, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0, 0.00040073052514344454, 0.0005822855746373534, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.00033081238507293165, 0.0006097797304391861, 0.0]])
|
||||
np.testing.assert_allclose(mel(sr=16000, n_fft=16, n_mels=16, dtype=dtypes.float32).numpy(), reference.numpy(), atol=1e-6)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue