#!/usr/bin/env python3 |
# -*- coding: utf-8 -*- |
import numpy as np |
from scipy import stats |
import matplotlib.pyplot as plt |
from fda.basis import FDataBasis, BSpline |
from fda.grid import FDataGrid |
from fda.registration import shift_registration |
# ---------------------------------------------------------------------------
# Data parameters
# ---------------------------------------------------------------------------
# Number of sample curves to generate
nsamples = 9
# Number of evaluation points per sample
nfine = 100
# Standard deviation of the underlying Gaussian density curve
sd = 0.1
# Standard deviation of the phase (shift) variation
shift_sd = 0.05
# Standard deviation of the amplitude variation (0 disables it; .1 was an
# alternative value tried by the author)
amp_sd = 0
# Standard deviation of the additive Gaussian noise
error_sd = 0.05
# Domain range
xlim = (-1, 1)

# ---------------------------------------------------------------------------
# Basis parameters
# ---------------------------------------------------------------------------
# Number of basis elements.
# NOTE(review): the original note said "fourier", but this value does not
# appear to be read by the script in this file -- confirm before relying on it.
nbasis = 7
# Number of knots of the B-spline basis
nknots = 20

# ---------------------------------------------------------------------------
# Registration parameters
# ---------------------------------------------------------------------------
# Values forwarded as `maxiter` and `tol` to shift_registration
maxiter = 20
tol = 1e-5

# ---------------------------------------------------------------------------
# Plot options
# ---------------------------------------------------------------------------
# Line width of the sample curves
width = 0.8
samples_color = 'teal'
mean_color = 'black'
curve_color = 'maroon'
# Optional y-axis limits; None leaves the automatic limits untouched
ylim = None
# Number of iterations shown in the step-by-step figure
iterations = 5
def noise_gaussian(t, nsamples, sd, shift_sd, amp_sd, error_sd):
    """Return noisy, randomly shifted and scaled Gaussian density curves.

    Each sample is the density of a normal distribution with mean 0 and
    standard deviation ``sd``, evaluated at ``t`` minus a per-sample random
    shift, multiplied by a per-sample random amplitude, plus pointwise
    Gaussian noise.

    Args:
        t (ndarray): One-dimensional array of evaluation times.
        nsamples (int): Number of samples to generate.
        sd (float): Standard deviation of the Gaussian density curve.
        shift_sd (float): Standard deviation of the normally distributed
            shift applied to each sample.
        amp_sd (float): Standard deviation of the amplitude factor, which
            is normally distributed around 1.
        error_sd (float): Standard deviation of the zero-mean Gaussian
            noise added to every evaluation.

    Returns:
        ndarray of shape (nsamples, len(t)). Each row is a sample and each
        column is a discrete time.
    """
    # The draw order (shift, error, amplitude) is deliberate: results under
    # a seeded global NumPy generator depend on it.
    shift = np.random.normal(0, shift_sd, nsamples)
    error = np.random.normal(0, error_sd, (nsamples, len(t)))
    amp = np.random.normal(1, amp_sd, nsamples)

    # Broadcasting the per-sample column vectors against t yields one row
    # per sample without building a dense nsamples x nsamples diagonal
    # matrix for the amplitude scaling.
    tsamples = np.asarray(t) - shift[:, np.newaxis]
    return amp[:, np.newaxis] * stats.norm.pdf(tsamples, 0, sd) + error
if __name__ == '__main__':
    # Matplotlib stylesheet. Recent Matplotlib releases ship the seaborn
    # style under the name 'seaborn-v0_8'; older ones only know 'seaborn'.
    if 'seaborn-v0_8' in plt.style.available:
        plt.style.use('seaborn-v0_8')
    else:
        plt.style.use('seaborn')

    # Fixing random state for reproducibility
    np.random.seed(1)

    # Times where each sample will be evaluated
    t = np.linspace(xlim[0], xlim[1], nfine)

    # Noisy gaussian data, with phase/amplitude variation and gaussian error
    data = noise_gaussian(t, nsamples, sd, shift_sd, amp_sd, error_sd)

    # Noise-free gaussian curve, drawn as a reference in every figure
    gaussian = stats.norm.pdf(t, 0, sd)

    # Plot the raw samples
    plt.figure()
    plt.title('Raw data')
    if ylim is not None:
        plt.ylim(ylim)
    plt.xlim(xlim)
    l1 = plt.plot(t, data.T, label='samples', c=samples_color,
                  linewidth=width)
    l2 = plt.plot(t, gaussian, label='gaussian', c=curve_color,
                  linestyle='dashed')
    l3 = plt.plot(t, np.mean(data.T, axis=1), label='mean', c=mean_color)
    # One legend entry per group: first sample line, mean, reference curve
    plt.legend(handles=[l1[0], l3[0], l2[0]], loc=1)

    # Equally spaced knots spanning the same domain handed to BSpline
    knots = np.linspace(xlim[0], xlim[1], nknots)

    # Basis representation of the discrete samples.
    # NOTE(review): the original comment described this as smoothing with
    # the "matrix penalty method" -- confirm against FDataBasis.from_data.
    fd = FDataBasis.from_data(data, t, BSpline(xlim, knots=knots))
    unregmean = fd.mean()  # Mean of unregistered curves

    # Plot the smoothed, still unregistered curves
    plt.figure()
    plt.title('Unregistered curves')
    if ylim is not None:
        plt.ylim(ylim)
    plt.xlim(xlim)
    l1 = fd.plot(label='samples', c=samples_color, linewidth=width)
    l2 = plt.plot(t, gaussian, label='gaussian', c=curve_color,
                  linestyle='dashed')
    l3 = unregmean.plot(label='mean', c=mean_color)
    plt.legend(handles=[l1[0], l3[0], l2[0]], loc=1)

    # Shift registered curves
    regbasis = shift_registration(fd, maxiter=maxiter, tol=tol)
    regmean = regbasis.mean()  # Registered mean

    # Plot the registered curves
    plt.figure()
    plt.title('Registered curves')
    if ylim is not None:
        plt.ylim(ylim)
    plt.xlim(xlim)
    l1 = regbasis.plot(label='samples', c=samples_color, linewidth=width)
    l2 = plt.plot(t, gaussian, label='gaussian', c=curve_color,
                  linestyle='dashed')
    l3 = regmean.plot(label='mean', c=mean_color)
    plt.legend(handles=[l1[0], l3[0], l2[0]], loc=1)

    # Plot the process step by step: one row for the unregistered curves
    # followed by one row per iteration count
    f, axarr = plt.subplots(iterations + 1, 1, sharex=True, sharey=True)
    axarr[0].title.set_text('Step by step registration')
    plt.xlim(xlim)
    fd.plot(ax=axarr[0], c=samples_color, linewidth=width)
    axarr[0].set_ylabel('Unregistered')

    for i in range(1, iterations + 1):
        # Registration is restarted from fd with maxiter=i each time;
        # tol=0 so that every requested iteration is carried out
        regfd = shift_registration(fd, maxiter=i, tol=0.)
        regfd.plot(ax=axarr[i], c=samples_color, linewidth=width)
        axarr[i].set_ylabel('%d iteration%s' % (i, '' if i == 1 else 's'))

    plt.show()