Skip to content

Instantly share code, notes, and snippets.

@lucasnewman
Created January 16, 2025 22:11
Show Gist options
  • Save lucasnewman/8a8e5d1791ad2d1547e8437d84795f37 to your computer and use it in GitHub Desktop.
Save lucasnewman/8a8e5d1791ad2d1547e8437d84795f37 to your computer and use it in GitHub Desktop.
mel_filterbank.py
import torch
def mel_filterbank(
    n_freqs,
    f_min,
    f_max,
    n_mels,
    sample_rate,
    norm=None,
    mel_scale="htk"
):
    """Build a bank of triangular mel filters.

    Args:
        n_freqs: Number of linearly spaced frequency bins; they cover
            0 Hz up to ``sample_rate / 2`` inclusive.
        f_min: Lower edge of the first filter, in Hz.
        f_max: Upper edge of the last filter, in Hz.
        n_mels: Number of mel filters to create.
        sample_rate: Sample rate of the signal, in Hz.
        norm: ``None`` to leave the triangles unscaled, or ``"slaney"`` to
            scale each filter by ``2 / (upper_edge - lower_edge)``.
        mel_scale: ``"htk"`` or ``"slaney"``.

    Returns:
        A tensor of shape ``(n_mels, n_freqs)``.

    Raises:
        ValueError: If ``mel_scale`` or ``norm`` is not a supported value.
    """
    import math

    # Reject typos up front: previously any unknown mel_scale silently meant
    # "slaney" and any unknown norm silently meant "no normalisation".
    if mel_scale not in ("htk", "slaney"):
        raise ValueError('mel_scale must be one of "htk" or "slaney"')
    if norm is not None and norm != "slaney":
        raise ValueError('norm must be one of None or "slaney"')

    # Slaney scale constants: linear below 1 kHz, logarithmic above.
    f_sp = 200.0 / 3
    min_log_hz = 1000.0
    min_log_mel = min_log_hz / f_sp
    logstep = math.log(6.4) / 27.0

    def hz_to_mel(freq, mel_scale="htk"):
        """Convert a scalar frequency in Hz to mels, returned as a float."""
        freq = float(freq)
        if mel_scale == "htk":
            return 2595.0 * math.log10(1.0 + freq / 700.0)
        if freq >= min_log_hz:
            return min_log_mel + math.log(freq / min_log_hz) / logstep
        return freq / f_sp

    def mel_to_hz(mels, mel_scale="htk"):
        """Convert a tensor of mel values to Hz, element-wise."""
        if mel_scale == "htk":
            return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
        # Pick the log branch at or above the 1 kHz break point, the linear
        # branch below it; no in-place masked write is needed.
        return torch.where(
            mels >= min_log_mel,
            min_log_hz * torch.exp(logstep * (mels - min_log_mel)),
            f_sp * mels,
        )

    # True division keeps the exact Nyquist frequency for odd sample rates
    # (e.g. 11025 Hz -> 5512.5 Hz); floor division would shift every bin.
    all_freqs = torch.linspace(0, sample_rate / 2, n_freqs)

    # n_mels filters need n_mels + 2 edge points, equally spaced in mels.
    # The endpoints are plain floats, so linspace never receives tensors.
    m_pts = torch.linspace(
        hz_to_mel(f_min, mel_scale), hz_to_mel(f_max, mel_scale), n_mels + 2
    )
    f_pts = mel_to_hz(m_pts, mel_scale)

    # Width in Hz between neighbouring edge points.
    f_diff = f_pts[1:] - f_pts[:-1]
    # slopes[i, j] = f_pts[j] - all_freqs[i], shape (n_freqs, n_mels + 2).
    slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)
    # Rising edge of each triangle (from its lower edge to its centre) ...
    down_slopes = (-slopes[:, :-2]) / f_diff[:-1]
    # ... and falling edge (from its centre to its upper edge).
    up_slopes = slopes[:, 2:] / f_diff[1:]
    # The triangle is the lower of the two ramps, floored at zero.
    filterbank = torch.max(
        torch.zeros_like(down_slopes), torch.min(down_slopes, up_slopes)
    )

    if norm == "slaney":
        # Scale each filter by 2 / (upper edge - lower edge).
        enorm = 2.0 / (f_pts[2:] - f_pts[:-2])
        filterbank *= enorm.unsqueeze(0)

    # (n_freqs, n_mels) -> (n_mels, n_freqs)
    return filterbank.moveaxis(0, 1)
# test: build the same filterbank with mel_filterbank and with librosa, then
# report whether the two agree.
sr = 24000
n_fft = 1024
n_mels = 100
f_min = 0.0
f_max = sr / 2
norm = 'slaney' # or None
mel_scale = "htk"

torch_filterbank = mel_filterbank(
    n_freqs=n_fft // 2 + 1,
    f_min=f_min,
    f_max=f_max,
    n_mels=n_mels,
    sample_rate=sr,
    norm=norm,
    mel_scale=mel_scale,
)

import librosa

librosa_filterbank = librosa.filters.mel(
    sr=sr,
    n_fft=n_fft,
    n_mels=n_mels,
    fmin=f_min,
    fmax=f_max,
    norm=norm,
    htk=(mel_scale == "htk"),
)
librosa_filterbank_torch = torch.tensor(librosa_filterbank, dtype=torch.float32)

# Check shapes before comparing values: torch.allclose raises on incompatible
# shapes and broadcasts compatible-but-different ones, so calling it first
# would either hide the shape report or give a misleading "nearly equal".
shapes_match = torch_filterbank.shape == librosa_filterbank_torch.shape
is_close = shapes_match and torch.allclose(
    torch_filterbank, librosa_filterbank_torch, atol=1e-4
)
print(f"filterbank shapes match: {shapes_match}")
print(f"nearly equal: {is_close}")

# An element-wise difference only makes sense for equally shaped tensors.
if shapes_match and not is_close:
    difference = torch.abs(torch_filterbank - librosa_filterbank_torch)
    max_diff = torch.max(difference).item()
    print(f"Maximum difference between filterbanks: {max_diff}")
    print("Sample differences (first 5 filters):")
    print(difference[:5, :10])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment