Created
January 16, 2025 22:11
-
-
Save lucasnewman/8a8e5d1791ad2d1547e8437d84795f37 to your computer and use it in GitHub Desktop.
mel_filterbank.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
def mel_filterbank(
    n_freqs,
    f_min,
    f_max,
    n_mels,
    sample_rate,
    norm=None,
    mel_scale="htk"
):
    """Build a bank of triangular mel filters.

    Args:
        n_freqs: Number of linearly spaced frequency bins covering
            ``0 .. sample_rate / 2`` (both ends included).
        f_min: Lower edge, in Hz, of the first filter.
        f_max: Upper edge, in Hz, of the last filter.
        n_mels: Number of triangular filters to generate.
        sample_rate: Sample rate of the signal, in Hz.
        norm: ``None`` for unit-peak triangles, or ``"slaney"`` to scale each
            filter by ``2 / (upper_edge_hz - lower_edge_hz)``.
        mel_scale: ``"htk"`` or ``"slaney"`` frequency warping.

    Returns:
        A tensor of shape ``(n_mels, n_freqs)``.

    Raises:
        ValueError: If ``norm`` or ``mel_scale`` is not a supported value.
    """
    # Reject unknown options up front instead of silently falling back to the
    # Slaney scale / skipping normalization.
    if norm is not None and norm != "slaney":
        raise ValueError('norm must be one of None or "slaney"')
    if mel_scale not in ("htk", "slaney"):
        raise ValueError('mel_scale must be one of "htk" or "slaney"')

    # Constants of the Slaney scale: linear below 1 kHz, logarithmic above.
    lin_step_hz = 200.0 / 3
    min_log_hz = 1000.0
    min_log_mel = min_log_hz / lin_step_hz
    logstep = torch.log(torch.tensor(6.4)) / 27.0

    def hz_to_mel(freq):
        """Convert a frequency in Hz (scalar or tensor) to mels."""
        # as_tensor avoids the copy warning when a tensor is passed in.
        freq = torch.as_tensor(freq, dtype=torch.float32)
        if mel_scale == "htk":
            return 2595.0 * torch.log10(1.0 + freq / 700.0)
        linear_part = freq / lin_step_hz
        # Clamp before the log so the branch that is discarded below never
        # evaluates log(0) for frequencies under 1 kHz.
        log_part = min_log_mel + torch.log(
            torch.clamp(freq, min=min_log_hz) / min_log_hz
        ) / logstep
        return torch.where(freq >= min_log_hz, log_part, linear_part)

    def mel_to_hz(mels):
        """Convert a tensor of mel values back to Hz."""
        if mel_scale == "htk":
            return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
        linear_part = lin_step_hz * mels
        log_part = min_log_hz * torch.exp(logstep * (mels - min_log_mel))
        return torch.where(mels >= min_log_mel, log_part, linear_part)

    # True division keeps the Nyquist frequency exact for odd sample rates
    # (e.g. 11025 -> 5512.5 rather than 5512).
    all_freqs = torch.linspace(0, sample_rate / 2, n_freqs)

    # n_mels + 2 points equally spaced on the mel axis: each filter uses three
    # consecutive points as its lower edge, peak and upper edge.
    m_min = hz_to_mel(f_min).item()
    m_max = hz_to_mel(f_max).item()
    m_pts = torch.linspace(m_min, m_max, n_mels + 2)
    f_pts = mel_to_hz(m_pts)

    f_diff = f_pts[1:] - f_pts[:-1]
    # slopes[i, j] = f_pts[j] - all_freqs[i], shape (n_freqs, n_mels + 2).
    slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)
    down_slopes = (-slopes[:, :-2]) / f_diff[:-1]
    up_slopes = slopes[:, 2:] / f_diff[1:]
    # A triangle is the lower of its rising and falling ramps, floored at zero.
    filterbank = torch.max(
        torch.zeros_like(down_slopes), torch.min(down_slopes, up_slopes)
    )

    if norm == "slaney":
        enorm = 2.0 / (f_pts[2:n_mels + 2] - f_pts[:n_mels])
        filterbank *= enorm.unsqueeze(0)

    # (n_freqs, n_mels) -> (n_mels, n_freqs)
    return filterbank.moveaxis(0, 1)
# test: compare the torch filterbank against librosa's reference filterbank
import librosa

sr = 24000
n_fft = 1024
n_mels = 100
f_min = 0.0
f_max = sr / 2
norm = 'slaney'  # or None
mel_scale = "htk"

torch_filterbank = mel_filterbank(
    n_freqs=n_fft // 2 + 1,
    f_min=f_min,
    f_max=f_max,
    n_mels=n_mels,
    sample_rate=sr,
    norm=norm,
    mel_scale=mel_scale,
)

librosa_filterbank = librosa.filters.mel(
    sr=sr,
    n_fft=n_fft,
    n_mels=n_mels,
    fmin=f_min,
    fmax=f_max,
    norm=norm,
    htk=(mel_scale == "htk"),
)
librosa_filterbank_torch = torch.tensor(librosa_filterbank, dtype=torch.float32)

# Check shapes before comparing values: allclose and the subtraction below
# cannot be evaluated meaningfully (and may raise) when the shapes differ.
shapes_match = torch_filterbank.shape == librosa_filterbank_torch.shape
is_close = shapes_match and torch.allclose(
    torch_filterbank, librosa_filterbank_torch, atol=1e-4
)

print(f"filterbank shapes match: {shapes_match}")
print(f"nearly equal: {is_close}")

if shapes_match and not is_close:
    difference = torch.abs(torch_filterbank - librosa_filterbank_torch)
    max_diff = torch.max(difference).item()
    print(f"Maximum difference between filterbanks: {max_diff}")
    print("Sample differences (first 5 filters):")
    print(difference[:5, :10])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment