Last active
December 17, 2018 18:58
-
-
Save lukegre/2353257680fe50ef96b2b0dd53d7a8b0 to your computer and use it in GitHub Desktop.
Empirical mode decomposition (EMD) for 1D data with plotting function for output IMFs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
### Implementation of our EMD() function | |
def emd(data, stop_limit=0.001, spline_order=3):
    """
    Decompose a 1D signal into intrinsic mode functions (IMFs).

    EMD as explained by Scott Cole (https://srcole.github.io/2016/01/18/emd/),
    but modified by Luke Gregor (https://github.com/luke-gregor) to
    automatically stop when there are no more IMFs in the dataset.
    Whatever is left of the input once the extracted IMFs are subtracted is
    appended as the final column, so the columns always sum to the input.
    The first IMF is the most noisy with the last being the smoothest.

    Parameters
    ----------
    data : array-like or pandas.Series
        The 1D signal to decompose.
    stop_limit : float
        A sifting candidate is accepted as an IMF once the value returned by
        ``_component_error`` drops below this threshold.
    spline_order : int
        Degree of the univariate spline used for the peak and trough
        envelopes; must be between 2 and 5.

    Returns
    -------
    pandas.DataFrame
        One column per IMF followed by the residual column. The index of
        ``data`` is reused when ``data`` is a pandas.Series.
    """
    from pandas import DataFrame, Series
    from scipy import signal

    assert 1 < spline_order < 6, "spline_order must be between 2 and 5"

    original = np.asarray(data)
    time = np.arange(original.size)
    # Start with zero columns; every accepted IMF is appended as a new column.
    # np.empty is used because the np.NaN alias was removed in NumPy 2.0.
    imfs = np.empty((time.size, 0))

    r = original  # what remains of the signal after removing accepted IMFs
    enough_extrema = True
    while enough_extrema:
        r_t = r  # current sifting candidate
        while True:
            # Identification of peaks and troughs in the candidate
            peak_idx = signal.argrelmax(r_t)[0]
            trof_idx = signal.argrelmin(r_t)[0]

            # The envelope splines need more points than their degree; once
            # the extrema run out the candidate is dropped and the
            # decomposition ends.
            if min(peak_idx.size, trof_idx.size) <= spline_order:
                enough_extrema = False
                break

            # Upper and lower envelopes evaluated over the whole record
            peak_t = _cubic_spline(
                peak_idx, r_t[peak_idx], time, k=spline_order)
            trof_t = _cubic_spline(
                trof_idx, r_t[trof_idx], time, k=spline_order)

            # Mean of the two envelopes with the ends replaced
            mean_t = _boundary_replacement(
                (peak_t + trof_t) / 2, peak_idx, trof_idx)

            # Assess if this is an IMF (only between the shared extrema)
            err = _component_error(r_t, mean_t, peak_idx, trof_idx)
            if err < stop_limit:
                # Accept the candidate and remove it from the remainder
                imfs = np.c_[imfs, r_t]
                r = r - imfs[:, -1]
                break

            # Not an IMF yet: subtract the envelope mean and sift again
            r_t = r_t - mean_t

    # The residual is the input minus everything extracted so far
    imfs = np.c_[imfs, original - imfs.sum(axis=1)]

    if isinstance(data, Series):
        return DataFrame(imfs, index=data.index)
    return DataFrame(imfs)
def _cubic_spline(x, y, xi, k):
    """
    Fit an interpolating univariate spline of degree ``k`` through the
    points ``(x, y)`` and return it evaluated at ``xi``.
    """
    from scipy.interpolate import InterpolatedUnivariateSpline

    # Despite the function name, the degree is whatever the caller passes
    # as ``k``.
    return InterpolatedUnivariateSpline(x, y, k=k)(xi)
def _boundary_replacement(arr, peak_indices, trough_indices):
    """
    Flatten the head and tail of ``arr`` outside the range that is covered
    by both the peaks and the troughs.

    The range starts at the later of the first peak / first trough and ends
    at the earlier of the last peak / last trough. Everything before the
    start is overwritten with the value at the start, and everything after
    the end is overwritten with the value at the end.

    ``arr`` is modified in place and also returned.
    """
    # first and last indices that lie between extrema of both kinds
    first = max(peak_indices.min(), trough_indices.min())
    last = min(peak_indices.max(), trough_indices.max())

    arr[:first] = arr[first]  # head takes the first in-range value
    # The tail takes the last in-range value; indexing with ``last + 1``
    # here would pick up a sample beyond the shared extrema instead.
    arr[last + 1:] = arr[last]
    return arr
def _component_error(orig, extrema_avg, peak_indices, trough_indices):
    """
    Return the mean square of ``extrema_avg`` divided by the mean square of
    ``orig``, both measured only between the shared extrema.

    The window runs from the later of the first peak / first trough up to
    and including the earlier of the last peak / last trough.
    """
    window_start = max(peak_indices.min(), trough_indices.min())
    # +1 so that the slice includes the last shared extremum itself
    window_stop = min(peak_indices.max(), trough_indices.max()) + 1
    window = slice(window_start, window_stop)

    envelope_power = np.mean(extrema_avg[window] ** 2)
    signal_power = np.mean(orig[window] ** 2)
    return envelope_power / signal_power
def plot_imfs(imf):
    """
    Draw each column of a pandas.DataFrame of IMFs as its own subplot,
    stacked vertically, and return the flattened array of axes.

    The row-wise sum of all columns is drawn in grey behind every subplot.
    Both that sum and the last column are shifted to mean 0 before plotting;
    the input frame itself is left untouched.
    """
    import pandas as pd

    n_imfs = imf.shape[1]
    centred = imf.copy()

    # rebuild the summed signal and centre it on zero
    total = centred.sum(1)
    total -= total.mean()

    # centre the last column on zero as well
    centred.iloc[:, -1] -= centred.iloc[:, -1].mean()

    # one subplot per column, sharing both axes
    axes = centred.plot(
        legend=False,
        subplots=True,
        layout=[n_imfs, 1],
        figsize=[8, n_imfs * 2.3],
        sharey=True,
        sharex=True,
        color='k',
        zorder=5,
    ).reshape(-1)

    # x values for the overlay; a DatetimeIndex is converted to periods
    if isinstance(imf.index, pd.DatetimeIndex):
        x = imf.index.to_period()
    else:
        x = imf.index.values

    # overlay the summed signal and label every subplot
    for number, axis in enumerate(axes, start=1):
        axis.plot(x, total.values, c='#CCCCCC')
        axis.set_ylabel('IMF {}'.format(number))

    # title the figure with a note about the normalising
    axes[0].set_title("EMD for data with {} IMFs\n(original data"
                      " and last IMF have been normalised with mean 0)"
                      "".format(n_imfs))
    axes[-1].set_xlabel('')

    # figure-level settings
    fig = axes[0].get_figure()
    fig.set_dpi(120)
    fig.tight_layout()

    return axes
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment