Skip to content

Instantly share code, notes, and snippets.

@jrsa
Created September 26, 2016 19:18
Show Gist options
Save jrsa/1873cd2f0f284174de2ec3203551c4c0 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
"""stft.py
Provides stft, a Short-Time Fourier Transform function.
Author: James Anderson
Email: [email protected]
Date: 09/24/2016
CalArts MTEC-480/680
Fall 2016
"""
from scipy.fftpack import fft
from scipy.signal import get_window
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import sys
import scipy.io.wavfile as wav
def stft(x, N=None, hop=None, win=None, K=None):
    """Short-time Fourier transform.

    Parameters
    ----------
    x : numpy.ndarray
        Vector of samples to analyze (1 channel). Any sequence accepted by
        numpy.asarray also works.
    N : int, optional
        Window size for each FFT. Default 1024.
    hop : int, optional
        Hop size between successive windows. Default N//2.
    win : string, float, tuple, or list, optional
        Window type and/or parameters (c.f. scipy.signal.get_window).
        Default 'hann'.
        The window function can be manually given as a list of N values.
    K : int, optional
        FFT size. Each windowed N-sample frame is zero-padded up to this
        size before the FFT. Default N; a value smaller than N is reported
        and replaced by N.

    Returns
    -------
    numpy.ndarray
        2D complex array of shape (len(x) // hop, K): one row of spectral
        values per frame, so time runs along axis 0.
    """
    if N is None:
        N = 1024
    if hop is None:
        hop = N // 2
    if win is None:
        win = 'hann'
    if K is None:
        K = N
    elif K < N:
        print("K cant be smaller than window size")
        K = N

    x = np.asarray(x)
    if isinstance(win, list):
        window_vector = np.asarray(win)
    else:
        window_vector = get_window(win, N)

    n_frames = len(x) // hop
    # Every row is written below, so an uninitialised buffer is fine.
    X = np.empty((n_frames, K), dtype=complex)
    for i in range(n_frames):
        offset = i * hop
        frame = x[offset:offset + N]
        if len(frame) != N:
            # Frames near the end of the input run past it; zero-pad them
            # to the window length so the window can be applied.
            frame = np.concatenate([frame, np.zeros(N - len(frame))])
        # Window the N real samples first, then let fft zero-pad the result
        # to K points, so a larger K only interpolates the spectrum instead
        # of stretching the window over padding.
        X[i] = fft(frame * window_vector, n=K)
    return X
def plot(X):
    """Draw the magnitude of an STFT, in dB, with the current pyplot figure.

    Parameters
    ----------
    X : numpy.ndarray
        2D complex array with one row per frame (as returned by stft).

    Returns
    -------
    matplotlib.image.AxesImage
        The image created by plt.imshow.
    """
    # Take the magnitude explicitly. Casting the complex log to float
    # discards the imaginary part and emits a ComplexWarning.
    magnitude = np.abs(X)
    # Exact-zero bins (digital silence) map to -inf; imshow masks
    # non-finite values, so only the divide-by-zero warning is silenced.
    with np.errstate(divide='ignore'):
        level_db = 20 * np.log10(magnitude)
    # Transpose so that time runs horizontally.
    return plt.imshow(level_db.T, cmap='jet', interpolation='nearest')
def main(argv=None):
    """Compute and display the spectrogram of a WAV file.

    Parameters
    ----------
    argv : list of str, optional
        Command-line style arguments; argv[1], if present, is the WAV file
        to analyze. Defaults to sys.argv. Without a file argument,
        '../portrait_dry.wav' is used.

    Returns
    -------
    int
        Exit status (always 0).
    """
    matplotlib.style.use('ggplot')
    if argv is None:
        argv = sys.argv
    # Read the file name from argv (not sys.argv) so callers can pass
    # their own argument list.
    fn = argv[1] if len(argv) > 1 else '../portrait_dry.wav'
    _sample_rate, x = wav.read(fn)
    if x.ndim > 1:
        # Multi-channel files come back as a 2D array, but stft analyzes a
        # single channel, so average the channels down to mono.
        x = x.mean(axis=1)
    X = stft(x, N=2048, hop=128)
    plot(X)
    plt.show()
    return 0
if __name__ == "__main__":
    # Run as a script: propagate main()'s return value as the exit status.
    raise SystemExit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment