Created
January 23, 2020 14:21
-
-
Save ground0state/2e535fc1caf01495b2d107e4aa8a0280 to your computer and use it in GitHub Desktop.
# This code is adapted from https://www.kaggle.com/jackvial/dwt-signal-denoising/data.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from scipy import signal | |
from scipy.signal import butter | |
import pywt | |
import matplotlib.pyplot as plt | |
def high_pass_filter(x, low_cutoff=10, sample_rate=None):
    """
    Apply a 10th-order Butterworth high-pass filter to a signal.

    From @randxie https://github.com/randxie/Kaggle-VSB-Baseline/blob/master/src/utils/util_signal.py
    Modified to work with scipy version 1.1.0 which does not have the fs parameter

    Parameters
    ----------
    x : array_like
        Input signal to filter.
    low_cutoff : float
        Cutoff frequency, in the same unit as ``sample_rate``.
    sample_rate : float, optional
        Sampling rate of ``x``. When omitted, a module-level ``sample_rate``
        is looked up at call time; if none exists a ValueError is raised.

    Returns
    -------
    numpy.ndarray
        The filtered signal, as returned by ``scipy.signal.sosfilt``.

    Raises
    ------
    ValueError
        If no sample rate is given and no module-level ``sample_rate`` exists.
    """
    # The default is resolved lazily: binding a module global in the signature
    # fails with NameError when that global is assigned after this definition.
    if sample_rate is None:
        sample_rate = globals().get("sample_rate")
        if sample_rate is None:
            raise ValueError(
                "sample_rate must be passed explicitly or defined at module level"
            )

    # nyquist frequency is half the sample rate https://en.wikipedia.org/wiki/Nyquist_frequency
    nyquist = 0.5 * sample_rate
    # butter() without `fs` expects the cutoff as a fraction of the Nyquist frequency.
    norm_low_cutoff = low_cutoff / nyquist

    # Fault pattern usually exists in high frequency band. According to literature,
    # the pattern is visible above 10^4 Hz.
    # The order of the filter
    N = 10

    # sos is [N/2] rows, 6 columns of filter coefficients
    # (https://jp.mathworks.com/help/signal/ref/sosfilt.html).
    # A high-pass design takes a single scalar critical frequency.
    sos = butter(N, Wn=norm_low_cutoff, btype='highpass', output='sos')

    # Run the signal through the digital IIR filter defined by sos.
    filtered_sig = signal.sosfilt(sos, x)
    return filtered_sig
def maddest(d, axis=None):
    """
    Mean Absolute Deviation

    Computes the mean of the absolute deviations of ``d`` from its mean.

    Parameters
    ----------
    d : array_like
        Input values.
    axis : int or None
        Axis along which to compute the deviation. ``None`` (the default)
        reduces over all elements and returns a scalar.

    Returns
    -------
    numpy.floating or numpy.ndarray
        The mean absolute deviation, reduced along ``axis``.
    """
    d = np.asarray(d)
    # keepdims keeps the reduced axis as size 1 so the mean broadcasts against
    # `d` along the requested axis rather than against the trailing axis.
    center = np.mean(d, axis=axis, keepdims=True)
    return np.mean(np.absolute(d - center), axis=axis)
def denoise_signal(x, wavelet='db4', level=1):
    """
    Denoise a 1-D signal by hard-thresholding its wavelet detail coefficients.

    1. Adapted from waveletSmooth function found here:
    http://connor-johnson.com/2016/01/24/using-pywavelets-to-remove-high-frequency-noise/
    2. Threshold equation and using hard mode in threshold as mentioned
    in section '3.2 denoising based on optimized singular values' from paper by Tomas Vantuch:
    http://dspace.vsb.cz/bitstream/handle/10084/133114/VAN431_FEI_P1807_1801V001_2018.pdf

    Parameters
    ----------
    x : array_like
        Input signal; ``len(x)`` is used as the sample count in the threshold.
    wavelet : str
        Wavelet name passed to ``pywt.wavedec`` / ``pywt.waverec``.
    level : int
        Selects ``coeff[-level]`` as the detail band used to estimate sigma.
        It is not passed to ``pywt.wavedec`` as the decomposition depth.

    Returns
    -------
    numpy.ndarray
        The reconstructed signal, trimmed to at most ``len(x)`` samples.
    """
    n_samples = len(x)

    # Decompose to get the wavelet coefficients vector
    coeff = pywt.wavedec(x, wavelet, mode="per")

    # Calculate sigma for threshold as defined in
    # http://dspace.vsb.cz/bitstream/handle/10084/133114/VAN431_FEI_P1807_1801V001_2018.pdf p.20
    # As noted by @harshit92 MAD referred to in the paper is Mean Absolute Deviation
    # not Median Absolute Deviation
    sigma = (1 / 0.6745) * maddest(coeff[-level])

    # Calculate the universal threshold
    uthresh = sigma * np.sqrt(2 * np.log(n_samples))

    # Detail coefficients below the threshold are replaced with 0; the
    # approximation coefficients (coeff[0]) are left untouched.
    coeff[1:] = [
        pywt.threshold(c, value=uthresh, mode='hard', substitute=0)
        for c in coeff[1:]
    ]

    # Reconstruct the signal using the thresholded coefficients
    rec = pywt.waverec(coeff, wavelet, mode='per')

    # Periodization pads odd-length inputs, so the reconstruction can be one
    # sample longer than the input; trim it back to the original length.
    return rec[:n_samples]
# Create noisy data: a sine wave plus Gaussian noise.
n_samples = 100
x = np.linspace(0, 10, n_samples)
yorg = np.sin(x)
y = yorg + np.random.randn(n_samples) * 0.2

# Plot the clean sine against the noisy series.
plt.plot(x, yorg, 'r', label='オリジナルsin')
plt.plot(x, y, 'k-', label='元系列')
plt.legend()
plt.show()

# Noise removal: high-pass filter, then wavelet denoising.
sample_duration = 1
sample_rate = n_samples * (1 / sample_duration)

x_hp = high_pass_filter(y, low_cutoff=0.05, sample_rate=sample_rate)
x_dn = denoise_signal(x_hp, wavelet='haar', level=1)

# Plot the clean sine, the noisy series and the denoised result. The denoised
# curve gets its own colour so it can be told apart from the noisy series.
plt.plot(x, yorg, 'r', label='オリジナルsin')
plt.plot(x, y, 'k-', label='元系列')
plt.plot(x, x_dn, 'b-', label='ノイズ除去')
plt.legend()
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment