Skip to content

Instantly share code, notes, and snippets.

@lukegre
Last active December 17, 2018 18:58
Show Gist options
  • Save lukegre/2353257680fe50ef96b2b0dd53d7a8b0 to your computer and use it in GitHub Desktop.
Save lukegre/2353257680fe50ef96b2b0dd53d7a8b0 to your computer and use it in GitHub Desktop.
Empirical mode decomposition (EMD) for 1D data with plotting function for output IMFs
import numpy as np
### Implementation of our EMD() function
def emd(data, stop_limit=0.001, spline_order=3, max_sifts=None):
    """
    Decompose a 1-D signal into intrinsic mode functions (IMFs).

    EMD as explained by Scott Cole (https://srcole.github.io/2016/01/18/emd/),
    but modified by Luke Gregor (https://github.com/luke-gregor) to stop
    automatically when there are no more IMFs in the dataset. The residual
    between the original data and the sum of the extracted IMFs is appended
    as the final IMF, so the returned columns sum to the input data. The
    first IMF is the noisiest and the last is the smoothest.

    Parameters
    ----------
    data : array-like or pandas.Series
        The 1-D signal to decompose.
    stop_limit : float, default 0.001
        A component is accepted as an IMF once the error ratio returned by
        ``_component_error`` falls below this value.
    spline_order : int, default 3
        Degree of the envelope splines; must satisfy ``1 < spline_order < 6``.
        Decomposition stops when the peaks or the troughs number
        ``spline_order`` or fewer.
    max_sifts : int or None, default None
        Optional cap on the number of sifting iterations per IMF. When the cap
        is reached the current component is accepted as an IMF even though
        ``stop_limit`` has not been met. ``None`` means no cap.

    Returns
    -------
    pandas.DataFrame
        One column per IMF, with the residual as the last column. If ``data``
        is a pandas.Series, its index is reused.
    """
    from pandas import DataFrame, Series
    from scipy import signal

    assert 1 < spline_order < 6, "spline_order must be between 2 and 5"

    r = np.array(data)
    time = np.arange(r.size)
    # Empty (n, 0) array; each accepted IMF is appended as a new column.
    # (The former ``np.ndarray(...) * np.NaN`` fails on NumPy >= 2.0, which
    # removed the ``np.NaN`` alias.)
    imfs = np.empty((time.size, 0))

    enough_extrema = True
    while enough_extrema:
        # Sift the current residual ``r`` until it qualifies as an IMF.
        r_t = r
        n_sifts = 0
        while True:
            # Identification of peaks and troughs in the data
            peak_idx = signal.argrelmax(r_t)[0]
            trof_idx = signal.argrelmin(r_t)[0]
            # Too few extrema to fit the envelope splines: no more IMFs.
            if min(peak_idx.size, trof_idx.size) <= spline_order:
                enough_extrema = False
                break
            # Interpolated upper (peak) and lower (trough) envelopes
            peak_t = _cubic_spline(peak_idx, r_t[peak_idx], time, k=spline_order)
            trof_t = _cubic_spline(trof_idx, r_t[trof_idx], time, k=spline_order)
            # Mean of the envelopes, with the ends outside the extrema range replaced
            mean_t = _boundary_replacement((peak_t + trof_t) / 2, peak_idx, trof_idx)
            # Assess whether r_t is an IMF (only between peaks and troughs)
            err = _component_error(r_t, mean_t, peak_idx, trof_idx)
            n_sifts += 1
            capped = max_sifts is not None and n_sifts >= max_sifts
            if err < stop_limit or capped:
                imfs = np.c_[imfs, r_t]
                r = r - imfs[:, -1]
                break
            # Not an IMF yet: remove the envelope mean and sift again.
            r_t = r_t - mean_t

    # Whatever is left over becomes the last (smoothest) IMF, so that the
    # columns sum back to the input data.
    imfs = np.c_[imfs, data - imfs.sum(axis=1)]
    if isinstance(data, Series):
        return DataFrame(imfs, index=data.index)
    return DataFrame(imfs)
def _cubic_spline(x, y, xi, k):
from scipy.interpolate import InterpolatedUnivariateSpline
# k is set to 3rd degree for cubic spline
func = InterpolatedUnivariateSpline(x, y, k=k)
return func(xi)
def _boundary_replacement(arr, peak_indices, trough_indices):
"""
Finds the first and last trough and assigns the first/last
value at the index to the head/tail of the time series.
"""
import numpy as np
# find the start and end boundaries of the EMD peaks/troughs
si = np.max([peak_indices.min(), trough_indices.min()]) # start index
ei = np.min([peak_indices.max(), trough_indices.max()]) + 1 # end index
arr[:si] = arr[si] # first value
arr[ei:] = arr[ei] # last value
return arr
def _component_error(orig, extrema_avg, peak_indices, trough_indices):
"""
Calculates the normalized error of the current component.
"""
# find the start and end boundaries of the EMD peaks/troughs
si = np.max([peak_indices.min(), trough_indices.min()]) # start index
ei = np.min([peak_indices.max(), trough_indices.max()]) + 1 # end index
# the square sum of the current and original component
current_component = (extrema_avg[si:ei]**2).mean()
original_component = (orig[si:ei]**2).mean()
return current_component / original_component
def plot_imfs(imf):
    """
    Create a timeseries plot of the IMFs in a pandas.DataFrame.

    Each column is treated as one IMF and drawn in black on its own subplot.
    The sum of all columns (the original data when ``imf`` comes from ``emd``),
    shifted to zero mean, is drawn in light grey behind every IMF. The last
    column is also shifted to zero mean before plotting. ``imf`` itself is not
    modified.

    Parameters
    ----------
    imf : pandas.DataFrame
        IMFs as columns.

    Returns
    -------
    numpy.ndarray
        1-D array of matplotlib Axes, one per column, top to bottom.
    """
    n_imfs = imf.shape[1]
    # Work on a float copy. Otherwise the in-place mean removal below would
    # write floats into an integer column, which pandas >= 2.1 warns about.
    imf_plot = imf.astype(float)

    # Normalise the original data (sum of all IMFs) to mean 0.
    orig = imf_plot.sum(axis=1)
    orig -= orig.mean()
    # Normalise the last column to mean 0.
    imf_plot.iloc[:, -1] -= imf_plot.iloc[:, -1].mean()

    # Create the plot: one subplot per IMF.
    ax = imf_plot.plot(
        legend=False,
        subplots=True,
        layout=[n_imfs, 1],
        figsize=[8, n_imfs * 2.3],
        sharey=True,
        sharex=True,
        color='k',
        zorder=5,
    ).reshape(-1)

    for i, a in enumerate(ax):
        # Plot the original data on every axis. Going through pandas converts
        # the x-values exactly as for the IMF lines; the previous manual
        # ``index.to_period()`` raised for a DatetimeIndex without a frequency.
        orig.plot(ax=a, color='#CCCCCC')
        a.set_ylabel('IMF {}'.format(i + 1))

    # Label the figure with a note about the normalising.
    ax[0].set_title("EMD for data with {} IMFs\n(original data"
                    " and last IMF have been normalised with mean 0)"
                    "".format(imf.shape[1]))
    ax[-1].set_xlabel('')

    # Get the figure and set the dpi.
    fig = ax[0].get_figure()
    fig.set_dpi(120)
    fig.tight_layout()
    return ax
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment