Created
August 21, 2017 16:56
-
-
Save paultsw/666711643442d8283252011a93b88241 to your computer and use it in GitHub Desktop.
Gaussian process regression for 1d signals with sklearn
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Fit a Gaussian process to a signal using SKLearn. | |
""" | |
import numpy as np | |
from sklearn import gaussian_process | |
from sklearn.gaussian_process import kernels as K | |
import matplotlib.pyplot as plt | |
from scipy.signal import resample | |
import argparse | |
def clip_signal(sig, tol):
    """Clip a signal to a symmetric window around its mean.

    Args:
        sig: array-like of numeric samples.
        tol: non-negative half-width of the window kept around the mean.

    Returns:
        A numpy array in which every sample is limited to the interval
        [mean(sig) - tol, mean(sig) + tol].

    Raises:
        ValueError: if `tol` is negative.
    """
    # With a negative tolerance the lower bound would exceed the upper bound,
    # and np.clip would then silently flatten the whole signal to one value.
    if tol < 0:
        raise ValueError("tol must be non-negative, got {!r}".format(tol))
    centre = np.mean(sig)
    return np.clip(sig, centre - tol, centre + tol)
def build_kernel():
    """
    Assemble the covariance kernel handed to the Gaussian process regressor.

    The kernel is the sum of an RBF term (length_scale=10) and a
    WhiteKernel term (noise_level=20).

    [N.B.: to experiment with different kernels, this is the only function
    that needs to change.]
    """
    rbf_term = K.RBF(length_scale=10.)
    white_term = K.WhiteKernel(noise_level=20.)
    return rbf_term + white_term
def _str2bool(text):
    """Parse a command-line token ('True', 'false', '1', 'no', ...) as a bool."""
    token = str(text).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (True/False), got {!r}".format(text))


def main(argv=None):
    """Fit a Gaussian process to a 1d signal and plot the predicted mean.

    Steps: parse the CLI arguments, load the signal with `np.load`,
    optionally clip it to 3 standard deviations around its mean, optionally
    resample it to `--subsample` points, fit a GaussianProcessRegressor with
    the kernel from `build_kernel()`, and plot the samples together with the
    predicted mean curve.

    Args:
        argv: optional list of argument strings; when None, argparse reads
            them from `sys.argv`.
    """
    ### parse CLI args:
    parser = argparse.ArgumentParser()
    # NOTE: `type=bool` would turn every non-empty string (including "False")
    # into True, so an explicit boolean parser is used instead.
    parser.add_argument("--clip", dest='clip', type=_str2bool, default=False,
                        help="If True, clip any samples that fall outside of 3 STDVs. [Default: False]")
    parser.add_argument("--subsample", dest='subsample', type=int, default=5000,
                        help="Number of points to subsample from read; if 0, use whole read. [Default: 5000]")
    parser.add_argument("signal_file")
    args = parser.parse_args(argv)

    ### load signal:
    signal = np.load(args.signal_file)

    ### optionally clip the open-pore signal at the start and end;
    ### (if this is done, limit everything to within 3 stdvs of the mean.)
    if args.clip:
        signal = clip_signal(signal, 3 * np.std(signal))

    ### optionally subsample:
    ### (if S := subsample > 0, resample the signal to S points)
    if args.subsample > 0:
        signal = resample(signal, args.subsample)

    ### compute signal statistics:
    sig_max = np.amax(signal)
    sig_min = np.amin(signal)
    # x positions span [0, 2*len(signal)]; the x-limits of the plot below
    # are chosen to match this range.
    x_ticks = np.linspace(start=0, stop=(2 * len(signal)), num=len(signal))

    ### construct kernel:
    kernel = build_kernel()

    ### perform kriging on the signal data:
    gpr = gaussian_process.GaussianProcessRegressor(
        kernel=kernel,
        optimizer='fmin_l_bfgs_b',
        n_restarts_optimizer=5,
        normalize_y=True)
    # the regressor is fed column vectors: one feature (x position) per sample
    _X = x_ticks.reshape(-1, 1)
    _y = signal.reshape(-1, 1)
    print("Fitting to dataset... Be patient, this may take a while.")
    gpr.fit(_X, _y)
    print("...Done. Generating predicted mean curve and plotting...")
    predictions = gpr.predict(_X).reshape(-1)

    ### plot GPR predictions:
    plt.plot(x_ticks, signal, 'o')
    # (un-)comment the next two lines to hide/show 3sigma bounding lines:
    #plt.plot(x_ticks, np.ones_like(x_ticks) * (np.mean(signal)+3*np.std(signal)), '-')
    #plt.plot(x_ticks, np.ones_like(x_ticks) * (np.mean(signal)-3*np.std(signal)), '-')
    plt.plot(x_ticks, predictions, '-')
    plt.xlim([-1, len(signal) * 2 + 1])
    plt.ylim([sig_min - 1, sig_max + 1])
    plt.show()
# Entry point: execute main() only when this file is run as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment