Skip to content

Instantly share code, notes, and snippets.

@lostanlen
Last active June 26, 2019 21:54
Show Gist options
  • Save lostanlen/01816b8e3852e783b00aeb479dfc8364 to your computer and use it in GitHub Desktop.
Save lostanlen/01816b8e3852e783b00aeb479dfc8364 to your computer and use it in GitHub Desktop.
more efficient length trimming for librosa inverters
import librosa
from librosa import *
from librosa.core.spectrum import __overlap_add
from librosa.filters import get_window, window_sumsquare
import numpy as np
from numba import jit
def istft(stft_matrix, hop_length=None, win_length=None, window='hann',
          center=True, dtype=np.float32, length=None):
    """Invert a short-time Fourier transform matrix to a time-domain signal.

    Every column of ``stft_matrix`` is inverted with a real inverse FFT,
    multiplied by the synthesis window and overlap-added into the output
    buffer.  The buffer is then divided by the summed squared window
    wherever that sum is numerically non-zero.

    This variant always inverts *all* frames, even when ``length`` asks for
    a shorter output.

    Parameters
    ----------
    stft_matrix : np.ndarray, shape=(1 + n_fft // 2, n_frames)
        STFT coefficients; ``n_fft`` is inferred from the number of rows.
    hop_length : int or None
        Number of samples between successive frames.
        Defaults to ``win_length // 4``.
    win_length : int or None
        Length of the synthesis window before it is center-padded to
        ``n_fft``.  Defaults to ``n_fft``.
    window : window specification
        Forwarded to ``get_window`` and ``window_sumsquare``.
    center : bool
        If True, the first ``n_fft // 2`` samples are dropped from the
        output (and the last ``n_fft // 2`` as well when ``length`` is None).
    dtype : numeric type
        Data type of the output buffer.
    length : int or None
        If given, the output is trimmed or padded to exactly this many
        samples with ``util.fix_length``.

    Returns
    -------
    y : np.ndarray
        One-dimensional reconstructed signal.
    """
    n_fft = 2 * (stft_matrix.shape[0] - 1)

    # By default, use the entire frame
    if win_length is None:
        win_length = n_fft

    # Set the default hop, if it's not already specified
    if hop_length is None:
        hop_length = int(win_length // 4)

    ifft_window = get_window(window, win_length, fftbins=True)

    # Pad out to match n_fft, and add a broadcasting axis
    ifft_window = util.pad_center(ifft_window, n_fft)[:, np.newaxis]

    n_frames = stft_matrix.shape[1]
    expected_signal_len = n_fft + hop_length * (n_frames - 1)
    y = np.zeros(expected_signal_len, dtype=dtype)

    # Number of frames inverted per block, limited by MAX_MEM_BLOCK bytes.
    # Clamp to at least one column: a zero step would make range() raise
    # ValueError when a single column is larger than the memory block.
    n_columns = max(1, int(util.MAX_MEM_BLOCK //
                           (stft_matrix.shape[0] * stft_matrix.itemsize)))

    fft = get_fftlib()

    frame = 0
    for bl_s in range(0, n_frames, n_columns):
        bl_t = min(bl_s + n_columns, n_frames)

        # invert the block and apply the window function
        ytmp = ifft_window * fft.irfft(stft_matrix[:, bl_s:bl_t], axis=0)

        # Overlap-add the istft block starting at the i'th frame
        __overlap_add(y[frame * hop_length:], ytmp, hop_length)

        frame += (bl_t - bl_s)

    # Normalize by sum of squared window
    ifft_window_sum = window_sumsquare(window,
                                       n_frames,
                                       win_length=win_length,
                                       n_fft=n_fft,
                                       hop_length=hop_length,
                                       dtype=dtype)

    # Only divide where the window sum is safely away from zero
    approx_nonzero_indices = ifft_window_sum > util.tiny(ifft_window_sum)
    y[approx_nonzero_indices] /= ifft_window_sum[approx_nonzero_indices]

    if length is None:
        # If we don't need to control length, just do the usual center trimming
        # to eliminate padded data
        if center:
            y = y[int(n_fft // 2):-int(n_fft // 2)]
    else:
        if center:
            # If we're centering, crop off the first n_fft//2 samples
            # and then trim/pad to the target length.
            # We don't trim the end here, so that if the signal is zero-padded
            # to a longer duration, the decay is smooth by windowing
            start = int(n_fft // 2)
        else:
            # If we're not centering, start at 0 and trim/pad as necessary
            start = 0

        y = util.fix_length(y[start:], length)

    return y
def istft898(stft_matrix, hop_length=None, win_length=None, window='hann',
             center=True, dtype=np.float32, length=None):
    """Invert an STFT matrix, skipping frames that cannot affect the output.

    Every retained column of ``stft_matrix`` is inverted with a real inverse
    FFT, multiplied by the synthesis window and overlap-added into the
    output buffer.  The buffer is then divided by the summed squared window
    wherever that sum is numerically non-zero.

    When ``length`` is given and non-zero, only the leading frames needed to
    cover ``length`` samples (plus ``n_fft`` extra samples when ``center`` is
    True) are inverted, which saves work when the requested output is
    shorter than the full reconstruction.

    Parameters
    ----------
    stft_matrix : np.ndarray, shape=(1 + n_fft // 2, n_frames)
        STFT coefficients; ``n_fft`` is inferred from the number of rows.
    hop_length : int or None
        Number of samples between successive frames.
        Defaults to ``win_length // 4``.
    win_length : int or None
        Length of the synthesis window before it is center-padded to
        ``n_fft``.  Defaults to ``n_fft``.
    window : window specification
        Forwarded to ``get_window`` and ``window_sumsquare``.
    center : bool
        If True, the first ``n_fft // 2`` samples are dropped from the
        output (and the last ``n_fft // 2`` as well when ``length`` is None).
    dtype : numeric type
        Data type of the output buffer.
    length : int or None
        If given, the output is trimmed or padded to exactly this many
        samples with ``util.fix_length``.

    Returns
    -------
    y : np.ndarray
        One-dimensional reconstructed signal.
    """
    n_fft = 2 * (stft_matrix.shape[0] - 1)

    # By default, use the entire frame
    if win_length is None:
        win_length = n_fft

    # Set the default hop, if it's not already specified
    if hop_length is None:
        hop_length = int(win_length // 4)

    ifft_window = get_window(window, win_length, fftbins=True)

    # Pad out to match n_fft, and add a broadcasting axis
    ifft_window = util.pad_center(ifft_window, n_fft)[:, np.newaxis]

    # For efficiency, trim STFT frames according to signal length if available.
    # Note the truthiness test: length=0 behaves like "no trimming" here and
    # is only applied by the final fix_length call.
    if length:
        if center:
            padded_length = length + int(n_fft)
        else:
            padded_length = length
        n_frames = min(stft_matrix.shape[1],
                       int(np.ceil(padded_length / hop_length)))
    else:
        n_frames = stft_matrix.shape[1]

    expected_signal_len = n_fft + hop_length * (n_frames - 1)
    y = np.zeros(expected_signal_len, dtype=dtype)

    # Number of frames inverted per block, limited by MAX_MEM_BLOCK bytes.
    # Clamp to at least one column: a zero step would make range() raise
    # ValueError when a single column is larger than the memory block.
    n_columns = max(1, int(util.MAX_MEM_BLOCK //
                           (stft_matrix.shape[0] * stft_matrix.itemsize)))

    fft = get_fftlib()

    frame = 0
    for bl_s in range(0, n_frames, n_columns):
        bl_t = min(bl_s + n_columns, n_frames)

        # invert the block and apply the window function
        ytmp = ifft_window * fft.irfft(stft_matrix[:, bl_s:bl_t], axis=0)

        # Overlap-add the istft block starting at the i'th frame
        __overlap_add(y[frame * hop_length:], ytmp, hop_length)

        frame += (bl_t - bl_s)

    # Normalize by sum of squared window (computed for the retained frames
    # only, so it matches the overlap-added buffer sample for sample)
    ifft_window_sum = window_sumsquare(window,
                                       n_frames,
                                       win_length=win_length,
                                       n_fft=n_fft,
                                       hop_length=hop_length,
                                       dtype=dtype)

    # Only divide where the window sum is safely away from zero
    approx_nonzero_indices = ifft_window_sum > util.tiny(ifft_window_sum)
    y[approx_nonzero_indices] /= ifft_window_sum[approx_nonzero_indices]

    if length is None:
        # If we don't need to control length, just do the usual center trimming
        # to eliminate padded data
        if center:
            y = y[int(n_fft // 2):-int(n_fft // 2)]
    else:
        if center:
            # If we're centering, crop off the first n_fft//2 samples
            # and then trim/pad to the target length.
            # We don't trim the end here, so that if the signal is zero-padded
            # to a longer duration, the decay is smooth by windowing
            start = int(n_fft // 2)
        else:
            # If we're not centering, start at 0 and trim/pad as necessary
            start = 0

        y = util.fix_length(y[start:], length)

    return y
import time
import tqdm
n_list = np.arange(0, len(y), 65536)[1:].astype('int')
n_fft = 2048
elapsed_v06 = []
elapsed_v07 = []
y = y[:n_list[-1]]
y_pad = librosa.util.fix_length(y, len(y) + n_fft // 2)
D = librosa.stft(y_pad, n_fft=n_fft)
for n in tqdm.tqdm(n_list):
elapsed = %timeit -q -r 10 -o istft(D, length=n)
elapsed_v06.append(elapsed)
elapsed = %timeit -q -r 10 -o istft898(D, length=n)
elapsed_v07.append(elapsed)
import matplotlib
%matplotlib inline
from matplotlib import pyplot as plt
# Plot mean run time (with a +/- one standard deviation band) of both
# inverters, as a percentage of the baseline's time at the longest length.
plt.figure(figsize=(6, 5))

# x axis: requested length as a percentage of the longest one
n_times = 100 * np.array(n_list) / n_list[-1]

# 1% of the baseline's average time for the longest requested length
normalizer = 0.01 * elapsed_v06[-1].average

for timings, label in ((elapsed_v06, 'librosa v0.6'),
                       (elapsed_v07, 'librosa v0.7')):
    mean = np.array([t.average for t in timings])
    spread = np.array([t.stdev for t in timings])
    plt.plot(n_times, mean / normalizer, '-o', label=label)
    plt.fill_between(n_times,
                     (mean - spread) / normalizer,
                     (mean + spread) / normalizer,
                     alpha=0.5)

# Axis ticks and limits
axes = plt.gca()
axes.set_xticks(np.arange(0, 110, 10))
axes.set_yticks(np.arange(0, 600, 100))
plt.xlim(0, 100)
plt.ylim(0.0, 600.0)

# Decorations
plt.grid(linestyle='--')
plt.legend(loc='upper right')
plt.xlabel("Input duration (% of total)")
plt.ylabel("Computation time (% of total)")
from scipy.signal import bartlett, hann, hamming, blackman, blackmanharris
import os
import soundfile as sf
import six
def srand(seed=628318530):
    """Seed NumPy's global random number generator for reproducible runs.

    Parameters
    ----------
    seed : int
        Value handed to ``np.random.seed``.
    """
    np.random.seed(seed)
def __test(x, n_fft, hop_length, window, atol, length):
    """Check that ``istft898`` undoes ``librosa.core.stft`` on signal ``x``.

    Parameters
    ----------
    x : np.ndarray
        Signal to transform and reconstruct.
    n_fft : int
        FFT size passed to the forward transform.
    hop_length : int
        Hop length used for both the forward and inverse transforms.
    window : window specification
        Window used for both the forward and inverse transforms.
    atol : float
        Absolute tolerance for the sample-wise comparison.
    length : int or None
        Requested output length for the inverse transform; when not None,
        the reconstruction must have exactly this many samples.

    Raises
    ------
    AssertionError
        If the length is wrong, the reconstruction contains non-finite
        values, or it differs from ``x`` by more than ``atol``.
    """
    spectrogram = librosa.core.stft(
        x, n_fft=n_fft, hop_length=hop_length, window=window)
    x_reconstructed = istft898(
        spectrogram, hop_length=hop_length, window=window, length=length)

    # An explicit length request must be honoured exactly
    if length is not None:
        assert len(x_reconstructed) == length

    # Compare only the samples that both signals have
    n_common = min(len(x), len(x_reconstructed))
    reference = np.resize(x, n_common)
    estimate = np.resize(x_reconstructed, n_common)

    # The inversion must never produce NaN or +/-Inf
    assert np.all(np.isfinite(estimate))

    # Log the configuration so a failing case can be identified
    print(n_fft, hop_length, atol, length)

    # The reconstruction must match the input within the tolerance
    assert np.allclose(reference, estimate, atol=atol)
# Reconstruction tests for istft898: run stft -> istft898 over several
# signals, window types, FFT sizes, hop lengths and requested lengths.
srand()

# White noise
x1 = np.random.randn(2 ** 15)

# Sine wave
x2 = np.sin(np.linspace(-np.pi, np.pi, 2 ** 15))

# Real music signal, loaded from the 'tests/data' directory located next to
# the installed librosa package directory
x3, sr = librosa.load(os.path.join(os.path.split(os.path.split(
    librosa.__file__)[0])[0], 'tests', 'data', 'test1_44100.wav'),
    sr=None, mono=True)
assert sr == 44100

for x, atol in [(x1, 1.0e-6), (x2, 1.0e-7), (x3, 1.0e-7)]:
    for window_func in [bartlett, hann, hamming, blackman, blackmanharris]:
        for n_fft in [512, 1024, 2048, 4096]:
            win = window_func(n_fft, sym=False)
            symwin = window_func(n_fft, sym=True)

            # tests with pre-computed window functions
            for hop_length_denom in six.moves.range(2, 9):
                hop_length = n_fft // hop_length_denom

                # Exercise: no length, a shorter target, and a longer target.
                # The longer case must be len(x) + 1000; len(x + 1000) would
                # add 1000 to every sample and still evaluate to len(x).
                for length in [None, len(x) - 1000, len(x) + 1000]:
                    __test(x, n_fft, hop_length, win, atol, length)
                    __test(x, n_fft, hop_length, symwin, atol, length)

            # also tests with passing the window function itself
            __test(x, n_fft, n_fft // 9, window_func, atol, None)

    # test with default parameters
    x_reconstructed = istft898(librosa.core.stft(x))
    L = min(len(x), len(x_reconstructed))
    x = np.resize(x, L)
    x_reconstructed = np.resize(x_reconstructed, L)
    assert np.allclose(x, x_reconstructed, atol=atol)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment