Skip to content

Instantly share code, notes, and snippets.

@eldrin
Last active August 31, 2024 18:18
Show Gist options
  • Save eldrin/0f8f8e5594f5dfeff5b9261e43242436 to your computer and use it in GitHub Desktop.
Save eldrin/0f8f8e5594f5dfeff5b9261e43242436 to your computer and use it in GitHub Desktop.
A quick investigation into what causes the mel-spectrogram discrepancy between torchaudio and librosa
import math
from typing import Callable, Optional
from warnings import warn
import torch
from torch import Tensor
from torchaudio import functional as F
from torchaudio.compliance import kaldi
class MyMelScale(torch.nn.Module):
    r"""Turn a normal STFT into a mel frequency STFT, using a conversion
    matrix. This uses triangular filter banks.

    User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)).

    Args:
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        n_stft (int, optional): Number of bins in STFT. Calculated from first input
            if None is given. See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
        fb_norm (str or None, optional): Filter bank normalization. ``'librosa_slaney'``
            builds the bank with ``librosa.filters.mel(..., norm='slaney')``; any other
            value is forwarded as ``norm`` to ``F.create_fb_matrix``. (Default: ``None``)
    """
    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

    def __init__(self,
                 n_mels: int = 128,
                 sample_rate: int = 16000,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 n_stft: Optional[int] = None,
                 fb_norm: Optional[str] = None) -> None:
        super(MyMelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min
        self.fb_norm = fb_norm

        assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)

        # Without ``n_stft`` the bank cannot be sized yet, whatever ``fb_norm`` is:
        # register an empty placeholder and let ``forward`` build it lazily.
        fb = torch.empty(0) if n_stft is None else self._build_fb(n_stft)
        self.register_buffer('fb', fb)

    def _build_fb(self, n_stft: int) -> Tensor:
        r"""Create the (``n_stft``, ``n_mels``) filter bank selected by ``fb_norm``.

        Args:
            n_stft (int): Number of STFT frequency bins.

        Returns:
            Tensor: Filter bank of size (``n_stft``, ``n_mels``).
        """
        if self.fb_norm == 'librosa_slaney':
            # Imported here so librosa is only needed when this mode is requested.
            import librosa

            # Keyword arguments work on every librosa release; the positional
            # ``mel(sr, n_fft)`` form is rejected by newer ones.
            kernel = librosa.filters.mel(
                sr=self.sample_rate,
                n_fft=int(2 * (n_stft - 1)),  # invert n_stft = n_fft // 2 + 1
                n_mels=self.n_mels,
                fmin=self.f_min,
                fmax=self.f_max,
                norm='slaney',
            )
            # librosa returns (n_mels, n_stft); ``forward`` multiplies by (n_stft, n_mels).
            return torch.tensor(kernel.T, dtype=torch.get_default_dtype())

        return F.create_fb_matrix(
            n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate,
            norm=self.fb_norm
        )

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        # pack batch
        shape = specgram.size()
        specgram = specgram.reshape(-1, shape[-2], shape[-1])

        if self.fb.numel() == 0:
            # Lazy construction: size the bank from the input's frequency axis.
            tmp_fb = self._build_fb(specgram.size(1))
            # Attributes cannot be reassigned outside __init__ so workaround:
            # fill the registered buffer in place.
            self.fb.resize_(tmp_fb.size())
            self.fb.copy_(tmp_fb)

        # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
        # -> (channel, time, n_mels).transpose(...)
        mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)

        # unpack batch
        mel_specgram = mel_specgram.reshape(shape[:-2] + mel_specgram.shape[-2:])
        return mel_specgram
class MyMelSpectrogram(torch.nn.Module):
    r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
    and MyMelScale.

    Sources
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``None``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float or None, optional): Forwarded as ``power`` to ``Spectrogram``. (Default: ``2.``)
        normalized (bool, optional): Forwarded as ``normalized`` to ``Spectrogram``. (Default: ``False``)
        fb_norm (str or None, optional): Filter bank normalization, forwarded to
            ``MyMelScale``. (Default: ``None``)
        wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav')
        >>> mel_specgram = MyMelSpectrogram(sample_rate)(waveform)  # (channel, n_mels, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']

    def __init__(self,
                 sample_rate: int = 16000,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 pad: int = 0,
                 n_mels: int = 128,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: Optional[float] = 2.,
                 normalized: bool = False,
                 fb_norm: Optional[str] = None,
                 wkwargs: Optional[dict] = None) -> None:
        super(MyMelSpectrogram, self).__init__()
        # ``Spectrogram`` is not among this module's imports; bring it into scope
        # here so that constructing the transform does not raise NameError.
        from torchaudio.transforms import Spectrogram

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        self.pad = pad
        self.power = power
        self.normalized = normalized
        self.n_mels = n_mels  # number of mel frequency bins
        self.f_max = f_max
        self.f_min = f_min
        self.fb_norm = fb_norm
        self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
                                       hop_length=self.hop_length,
                                       pad=self.pad, window_fn=window_fn, power=self.power,
                                       normalized=self.normalized, wkwargs=wkwargs)
        # n_fft // 2 + 1 is the number of STFT bins, so the filter bank is built eagerly.
        self.mel_scale = MyMelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max,
                                    self.n_fft // 2 + 1, self.fb_norm)

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        specgram = self.spectrogram(waveform)
        mel_specgram = self.mel_scale(specgram)
        return mel_specgram
import torch
import torchaudio
import librosa
import matplotlib.pyplot as plt
from torchaudio import functional as Fa
# Visual comparison of the mel filter banks produced by torchaudio and librosa.

# These were used below without being defined in this script; the values are
# the ones set in the accompanying mel-spectrogram comparison script.
sr = 6000  # sampling rate
n_fft = 2048
n_freqs = n_fft // 2 + 1  # number of STFT frequency bins

# torchaudio filter banks, shaped (frequency bin, mel bin)
mel_kernel_torchaudio = Fa.create_fb_matrix(
    n_freqs,
    n_mels=128,
    f_min=0.,
    f_max=sr / 2.,
    sample_rate=sr,
    norm=None
)
mel_kernel_torchaudio_slaney = Fa.create_fb_matrix(
    n_freqs,
    n_mels=128,
    f_min=0.,
    f_max=sr / 2.,
    sample_rate=sr,
    norm='slaney'
)

# librosa filter banks, shaped (mel bin, frequency bin); transposed for plotting.
# ``sr`` / ``n_fft`` are passed by keyword because newer librosa releases no
# longer accept them positionally.
mel_kernel_librosa_htk = librosa.filters.mel(
    sr=sr,
    n_fft=n_fft,
    n_mels=128,
    fmin=0.,
    fmax=sr / 2.,
    norm='slaney',
    htk=True,
)
mel_kernel_librosa_slaney = librosa.filters.mel(
    sr=sr,
    n_fft=n_fft,
    n_mels=128,
    fmin=0.,
    fmax=sr / 2.,
    norm='slaney',
    htk=False,
)

fig, axs = plt.subplots(2, 2, figsize=(10, 10))
fig.suptitle('mel-Kernel')

panels = [
    (axs[0][0], 'torchaudio[None]', mel_kernel_torchaudio),
    (axs[0][1], 'torchaudio[slaney]', mel_kernel_torchaudio_slaney),
    (axs[1][0], 'librosa[htk + slaney]', mel_kernel_librosa_htk.T),
    (axs[1][1], 'librosa[audiotory_toolbox + slaney]', mel_kernel_librosa_slaney.T),
]
for ax, title, kernel in panels:
    ax.set_title(title)
    ax.imshow(kernel, aspect='auto')
    ax.set_xlabel('mel bin')
# only the first panel carries the y-axis label
axs[0][0].set_ylabel('frequency bin')
import torch
import torchaudio
import librosa
import matplotlib.pyplot as plt
from torchaudio import functional as Fa
# assume MyMelSpectrogram is declared somewhere
from my_melspec import MyMelSpectrogram
# Compare MyMelSpectrogram (for each filter-bank normalization) against
# librosa.feature.melspectrogram on the same clip.

# some variables
fn = '4538556.clip.mp3'
sr = 6000  # sampling rate

# librosa default
n_fft = 2048
win_len = None
hop_len = 512

# # with torchaudio.load
# waveform, sample_rate = torchaudio.load(
#     '4538556.clip.mp3',
#     normalization=True
# )
# waveform = waveform.mean(0)

# With librosa.load
audio, sample_rate = librosa.load(fn, sr=sr)
waveform = torch.from_numpy(audio)

melspecs = {}
for fb_norm in [None, 'slaney', 'librosa_slaney']:
    # instantiate mel-spectrogram transform
    torch_melspectrogram = MyMelSpectrogram(
        sample_rate,
        n_fft=n_fft,
        win_length=win_len,
        hop_length=hop_len,
        fb_norm=fb_norm
    )

    # compute; the signal goes in as ``y=`` because newer librosa releases
    # no longer accept it positionally
    X1 = torch_melspectrogram(waveform)
    X2 = librosa.feature.melspectrogram(y=audio,
                                        sr=sample_rate,
                                        n_fft=n_fft,
                                        hop_length=hop_len,
                                        win_length=win_len)

    # for plot
    melspecs[fb_norm] = (X1, X2)

    # compute error on tensors of the same kind (X2 is a NumPy array)
    mse = ((X1 - torch.from_numpy(X2)) ** 2).mean().item()

    # log error
    print(f'Mean Squared Error[fb:{fb_norm}]:\t{mse:.4f}')

fig, axs = plt.subplots(1, 4, figsize=(20, 5))
panels = [
    ('torchaudio / filter bank [None]', melspecs[None][0]),
    ('torchaudio / filter bank [slaney]', melspecs['slaney'][0]),
    ('torchaudio / filter bank [librosa_slaney]', melspecs['librosa_slaney'][0]),
    # the librosa result does not depend on fb_norm, so any entry's X2 will do
    ('librosa', melspecs[None][1]),
]
for ax, (title, melspec) in zip(axs, panels):
    ax.set_title(title)
    ax.set_xlabel('frame')
    ax.imshow(librosa.power_to_db(melspec), aspect='auto')
# only the first panel carries the y-axis label
axs[0].set_ylabel('mel bin')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment