Last active
June 26, 2019 21:54
-
-
Save lostanlen/01816b8e3852e783b00aeb479dfc8364 to your computer and use it in GitHub Desktop.
more efficient length trimming for librosa inverters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import librosa | |
from librosa import * | |
from librosa.core.spectrum import __overlap_add | |
from librosa.filters import get_window, window_sumsquare | |
import numpy as np | |
from numba import jit | |
def istft(stft_matrix, hop_length=None, win_length=None, window='hann',
          center=True, dtype=np.float32, length=None):
    """Invert a short-time Fourier transform by windowed overlap-add.

    Baseline implementation (plotted as "librosa v0.6" in the benchmark
    below): every column of ``stft_matrix`` is inverted, even when ``length``
    requests a shorter output.

    Parameters
    ----------
    stft_matrix : 2-D array, shape ``(1 + n_fft // 2, n_frames)``
        STFT coefficients; ``n_fft`` is inferred from the first axis.
    hop_length : int or None
        Number of samples between successive frames.
        Defaults to ``win_length // 4``.
    win_length : int or None
        Length of the synthesis window. Defaults to ``n_fft``.
    window : object
        Window specification, forwarded to ``get_window`` and
        ``window_sumsquare``.
    center : bool
        If true, the first ``n_fft // 2`` samples are cropped from the
        output (and the last ``n_fft // 2`` as well when ``length`` is None).
    dtype : numpy dtype
        Data type of the output buffer.
    length : int or None
        If given, the output is trimmed or padded to exactly this many
        samples with ``util.fix_length``.

    Returns
    -------
    y : 1-D array of ``dtype``
        The reconstructed time-domain signal.
    """
    n_fft = 2 * (stft_matrix.shape[0] - 1)

    # By default, use the entire frame
    if win_length is None:
        win_length = n_fft

    # Set the default hop, if it's not already specified
    if hop_length is None:
        hop_length = int(win_length // 4)

    ifft_window = get_window(window, win_length, fftbins=True)

    # Pad out to match n_fft, and add a broadcasting axis
    ifft_window = util.pad_center(ifft_window, n_fft)[:, np.newaxis]

    n_frames = stft_matrix.shape[1]
    expected_signal_len = n_fft + hop_length * (n_frames - 1)
    y = np.zeros(expected_signal_len, dtype=dtype)

    n_columns = int(util.MAX_MEM_BLOCK // (stft_matrix.shape[0] *
                                           stft_matrix.itemsize))
    # A single column can be larger than MAX_MEM_BLOCK (very large n_fft);
    # always process at least one column per block, otherwise range() below
    # would be called with a zero step and raise ValueError.
    n_columns = max(n_columns, 1)

    fft = get_fftlib()

    frame = 0
    for bl_s in range(0, n_frames, n_columns):
        bl_t = min(bl_s + n_columns, n_frames)

        # invert the block and apply the window function
        ytmp = ifft_window * fft.irfft(stft_matrix[:, bl_s:bl_t], axis=0)

        # Overlap-add the istft block starting at the i'th frame
        __overlap_add(y[frame * hop_length:], ytmp, hop_length)

        frame += (bl_t - bl_s)

    # Normalize by sum of squared window
    ifft_window_sum = window_sumsquare(window,
                                       n_frames,
                                       win_length=win_length,
                                       n_fft=n_fft,
                                       hop_length=hop_length,
                                       dtype=dtype)

    # Only divide where the window energy is non-negligible
    approx_nonzero_indices = ifft_window_sum > util.tiny(ifft_window_sum)
    y[approx_nonzero_indices] /= ifft_window_sum[approx_nonzero_indices]

    if length is None:
        # If we don't need to control length, just do the usual center trimming
        # to eliminate padded data
        if center:
            y = y[int(n_fft // 2):-int(n_fft // 2)]
    else:
        if center:
            # If we're centering, crop off the first n_fft//2 samples
            # and then trim/pad to the target length.
            # We don't trim the end here, so that if the signal is zero-padded
            # to a longer duration, the decay is smooth by windowing
            start = int(n_fft // 2)
        else:
            # If we're not centering, start at 0 and trim/pad as necessary
            start = 0

        y = util.fix_length(y[start:], length)

    return y
def istft898(stft_matrix, hop_length=None, win_length=None, window='hann',
             center=True, dtype=np.float32, length=None):
    """Invert a short-time Fourier transform, skipping frames beyond ``length``.

    Same overlap-add procedure as ``istft``, except that when ``length`` is
    given only the leading frames that can contribute to the requested output
    are inverted (plotted as "librosa v0.7" in the benchmark below).

    Parameters
    ----------
    stft_matrix : 2-D array, shape ``(1 + n_fft // 2, n_frames)``
        STFT coefficients; ``n_fft`` is inferred from the first axis.
    hop_length : int or None
        Number of samples between successive frames.
        Defaults to ``win_length // 4``.
    win_length : int or None
        Length of the synthesis window. Defaults to ``n_fft``.
    window : object
        Window specification, forwarded to ``get_window`` and
        ``window_sumsquare``.
    center : bool
        If true, the first ``n_fft // 2`` samples are cropped from the
        output (and the last ``n_fft // 2`` as well when ``length`` is None).
    dtype : numpy dtype
        Data type of the output buffer.
    length : int or None
        If given, the output is trimmed or padded to exactly this many
        samples with ``util.fix_length``. A truthy value also limits how
        many frames are inverted.

    Returns
    -------
    y : 1-D array of ``dtype``
        The reconstructed time-domain signal.
    """
    n_fft = 2 * (stft_matrix.shape[0] - 1)

    # By default, use the entire frame
    if win_length is None:
        win_length = n_fft

    # Set the default hop, if it's not already specified
    if hop_length is None:
        hop_length = int(win_length // 4)

    ifft_window = get_window(window, win_length, fftbins=True)

    # Pad out to match n_fft, and add a broadcasting axis
    ifft_window = util.pad_center(ifft_window, n_fft)[:, np.newaxis]

    # For efficiency, trim STFT frames according to signal length if available.
    # A length of None (or 0) falls through to using every frame.
    if length:
        if center:
            # Leave room for the n_fft // 2 samples cropped from the start,
            # plus a margin of the same size at the end.
            padded_length = length + int(n_fft)
        else:
            padded_length = length
        n_frames = min(stft_matrix.shape[1],
                       int(np.ceil(padded_length / hop_length)))
    else:
        n_frames = stft_matrix.shape[1]

    expected_signal_len = n_fft + hop_length * (n_frames - 1)
    y = np.zeros(expected_signal_len, dtype=dtype)

    n_columns = int(util.MAX_MEM_BLOCK // (stft_matrix.shape[0] *
                                           stft_matrix.itemsize))
    # A single column can be larger than MAX_MEM_BLOCK (very large n_fft);
    # always process at least one column per block, otherwise range() below
    # would be called with a zero step and raise ValueError.
    n_columns = max(n_columns, 1)

    fft = get_fftlib()

    frame = 0
    for bl_s in range(0, n_frames, n_columns):
        bl_t = min(bl_s + n_columns, n_frames)

        # invert the block and apply the window function
        ytmp = ifft_window * fft.irfft(stft_matrix[:, bl_s:bl_t], axis=0)

        # Overlap-add the istft block starting at the i'th frame
        __overlap_add(y[frame * hop_length:], ytmp, hop_length)

        frame += (bl_t - bl_s)

    # Normalize by sum of squared window
    ifft_window_sum = window_sumsquare(window,
                                       n_frames,
                                       win_length=win_length,
                                       n_fft=n_fft,
                                       hop_length=hop_length,
                                       dtype=dtype)

    # Only divide where the window energy is non-negligible
    approx_nonzero_indices = ifft_window_sum > util.tiny(ifft_window_sum)
    y[approx_nonzero_indices] /= ifft_window_sum[approx_nonzero_indices]

    if length is None:
        # If we don't need to control length, just do the usual center trimming
        # to eliminate padded data
        if center:
            y = y[int(n_fft // 2):-int(n_fft // 2)]
    else:
        if center:
            # If we're centering, crop off the first n_fft//2 samples
            # and then trim/pad to the target length.
            # We don't trim the end here, so that if the signal is zero-padded
            # to a longer duration, the decay is smooth by windowing
            start = int(n_fft // 2)
        else:
            # If we're not centering, start at 0 and trim/pad as necessary
            start = 0

        y = util.fix_length(y[start:], length)

    return y
import time | |
import tqdm | |
n_list = np.arange(0, len(y), 65536)[1:].astype('int') | |
n_fft = 2048 | |
elapsed_v06 = [] | |
elapsed_v07 = [] | |
y = y[:n_list[-1]] | |
y_pad = librosa.util.fix_length(y, len(y) + n_fft // 2) | |
D = librosa.stft(y_pad, n_fft=n_fft) | |
for n in tqdm.tqdm(n_list): | |
elapsed = %timeit -q -r 10 -o istft(D, length=n) | |
elapsed_v06.append(elapsed) | |
elapsed = %timeit -q -r 10 -o istft898(D, length=n) | |
elapsed_v07.append(elapsed) | |
import matplotlib | |
%matplotlib inline | |
from matplotlib import pyplot as plt | |
plt.figure(figsize=(6, 5)) | |
n_times = 100 * np.array(n_list)/n_list[-1] | |
normalizer = 0.01 * elapsed_v06[-1].average | |
plt.plot(n_times, | |
np.array([x.average for x in elapsed_v06]) / normalizer, | |
'-o', label='librosa v0.6') | |
plt.fill_between(n_times, | |
np.array([x.average - x.stdev for x in elapsed_v06]) / normalizer, | |
np.array([x.average + x.stdev for x in elapsed_v06]) / normalizer, | |
alpha=0.5) | |
plt.plot(n_times, | |
np.array([x.average for x in elapsed_v07]) / normalizer, | |
'-o', label='librosa v0.7') | |
plt.fill_between(n_times, | |
np.array([x.average - x.stdev for x in elapsed_v07])/ normalizer, | |
np.array([x.average + x.stdev for x in elapsed_v07])/ normalizer, | |
alpha=0.5) | |
plt.gca().set_xticks(np.arange(0, 110, 10)) | |
plt.gca().set_yticks(np.arange(0, 600, 100)) | |
plt.xlim(0, 100) | |
plt.ylim(0.0, 600.0) | |
plt.grid(linestyle='--') | |
plt.legend(loc='upper right') | |
plt.xlabel("Input duration (% of total)") | |
plt.ylabel("Computation time (% of total)") | |
from scipy.signal import bartlett, hann, hamming, blackman, blackmanharris | |
import os | |
import soundfile as sf | |
import six | |
def srand(seed=628318530):
    """Seed NumPy's global random number generator so test signals repeat."""
    np.random.seed(seed)
def __test(x, n_fft, hop_length, window, atol, length):
    """Round-trip ``x`` through ``librosa.core.stft`` and ``istft898``.

    Asserts that the reconstruction has exactly ``length`` samples when a
    length is requested, contains only finite values, and matches ``x``
    within ``atol`` over the samples the two signals have in common.
    """
    spectrum = librosa.core.stft(
        x, n_fft=n_fft, hop_length=hop_length, window=window)
    x_hat = istft898(
        spectrum, hop_length=hop_length, window=window, length=length)

    if length is not None:
        assert len(x_hat) == length

    # Compare only the leading samples present in both signals
    n_common = min(len(x), len(x_hat))
    x = np.resize(x, n_common)
    x_hat = np.resize(x_hat, n_common)

    # The reconstruction must not contain NaN or +/-Inf
    assert np.all(np.isfinite(x_hat))

    # Log the configuration, then check the approximate reconstruction
    print(n_fft, hop_length, atol, length)
    assert np.allclose(x, x_hat, atol=atol)
# Round-trip tests for istft898 over several signals, windows, FFT sizes,
# hop lengths and requested output lengths.
srand()

# White noise
x1 = np.random.randn(2 ** 15)

# Sine wave
x2 = np.sin(np.linspace(-np.pi, np.pi, 2 ** 15))

# Real music signal, read from the test data next to the librosa package
x3, sr = librosa.load(os.path.join(os.path.split(os.path.split(
    librosa.__file__)[0])[0], 'tests', 'data', 'test1_44100.wav'),
    sr=None, mono=True)
assert sr == 44100

for x, atol in [(x1, 1.0e-6), (x2, 1.0e-7), (x3, 1.0e-7)]:
    for window_func in [bartlett, hann, hamming, blackman, blackmanharris]:
        for n_fft in [512, 1024, 2048, 4096]:
            win = window_func(n_fft, sym=False)
            symwin = window_func(n_fft, sym=True)

            # tests with pre-computed window functions
            for hop_length_denom in six.moves.range(2, 9):
                hop_length = n_fft // hop_length_denom

                # No explicit length, shorter than the signal, and longer
                # than the signal. The last case was `len(x + 1000)`, which
                # is just `len(x)` and never exercised padding.
                for length in [None, len(x) - 1000, len(x) + 1000]:
                    __test(x, n_fft, hop_length, win, atol, length)
                    __test(x, n_fft, hop_length, symwin, atol, length)

            # also tests with passing window function itself
            __test(x, n_fft, n_fft // 9, window_func, atol, None)

    # test with default parameters
    x_reconstructed = istft898(librosa.core.stft(x))
    L = min(len(x), len(x_reconstructed))
    # Use fresh names so the loop variable `x` is not rebound
    x_trimmed = np.resize(x, L)
    x_reconstructed = np.resize(x_reconstructed, L)
    assert np.allclose(x_trimmed, x_reconstructed, atol=atol)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment