|
import numpy as np |
|
from scipy.linalg import lstsq |
|
import matplotlib.pyplot as plt |
|
|
|
def generateData(n=30):
    """Sample a peaks-like test surface on an n-by-n grid over [-3, 3] x [-3, 3].

    The surface follows the same formula as MATLAB's peaks() function.

    Returns:
        X, Y, Z: three column vectors of shape (n*n, 1) holding the grid
        coordinates and the surface height at each grid point.
    """
    axis = np.linspace(-3.0, 3.0, n)
    grid_x, grid_y = np.meshgrid(axis, axis)

    # flatten the grid into column vectors, one row per sample point
    X = grid_x.reshape(-1, 1)
    Y = grid_y.reshape(-1, 1)

    # the three terms of the peaks surface, combined left to right
    first = 3 * (1 - X)**2 * np.exp(- X**2 - (Y+1)**2)
    second = 10 * (X/5 - X**3 - Y**5) * np.exp(- X**2 - Y**2)
    third = 1/3 * np.exp(- (X+1)**2 - Y**2)
    Z = first - second - third

    return X, Y, Z
|
|
|
def exp2model(e):
    """Render a list of exponent pairs as a readable polynomial model string.

    The pair (x, y) at position i becomes the term C[i]*X^x*Y^y.  A factor
    with exponent 1 is written without '^1'; a factor whose exponent is not
    1 or greater than 1 is omitted.  Terms are joined with ' + '.
    """
    def power(symbol, k):
        # 'symbol^k' for k > 1, bare symbol for k == 1, otherwise nothing
        if k > 1:
            return f'{symbol}^{k}'
        return symbol if k == 1 else ''

    terms = []
    for i, (x, y) in enumerate(e):
        pieces = [f'C[{i}]']
        if x > 0 or y > 0:
            # separator between the coefficient and the first factor
            pieces.append('*')
        pieces.append(power('X', x))
        if x > 0 and y > 0:
            # separator between the X factor and the Y factor
            pieces.append('*')
        pieces.append(power('Y', y))
        terms.append(''.join(pieces))
    return ' + '.join(terms)
|
|
|
# Fit a polynomial surface Z ~ sum_i C[i] * X^x_i * Y^y_i to sample data by
# linear least squares, report the fit quality, and plot data and surface.

# sample data: column vectors of shape (N, 1) from the peaks-like test surface
# (deterministic grid samples, not random points)
X, Y, Z = generateData()

# 1=linear, 2=quadratic, 3=cubic, ..., nth degree
order = 11

# exponent pairs (x, y) of every monomial with x + y <= order, grouped by
# total degree n; within one degree y ascends, so x = n - y descends
e = [(n - y, y) for n in range(order + 1) for y in range(n + 1)]
eX = np.asarray([[x] for x, _ in e]).T  # shape (1, K)
eY = np.asarray([[y] for _, y in e]).T  # shape (1, K)

# best-fit polynomial surface: broadcasting (N, 1) ** (1, K) gives the
# (N, K) design matrix, one column per monomial
A = (X ** eX) * (Y ** eY)
# `resid` is kept for reference only: scipy returns an EMPTY array here when
# it judges A rank-deficient, so it must not be indexed unconditionally
C, resid, _, _ = lstsq(A, Z)  # coefficients, shape (K, 1)

# calculate R-squared from the residual sum of squares, computed directly
# from the fit so it is defined regardless of the rank reported by lstsq
ss_res = float(np.sum((Z - A @ C) ** 2))
r2 = 1 - ss_res / (Z.size * Z.var())

# print summary
print(f'data = {Z.size}x3')
print(f'model = {exp2model(e)}')
print(f'coefficients =\n{C}')
print(f'R2 = {r2}')

# uniform grid covering the domain of the data
XX, YY = np.meshgrid(np.linspace(X.min(), X.max(), 20),
                     np.linspace(Y.min(), Y.max(), 20))

# evaluate model on grid (A is rebound to the grid's design matrix)
A = (XX.reshape(-1, 1) ** eX) * (YY.reshape(-1, 1) ** eY)
ZZ = np.dot(A, C).reshape(XX.shape)

# plot points and fitted surface
ax = plt.figure().add_subplot(projection='3d')
ax.scatter(X, Y, Z, c='r', s=2)
ax.plot_surface(XX, YY, ZZ, rstride=1, cstride=1, alpha=0.2, linewidth=0.5, edgecolor='b')
ax.axis('tight')
ax.view_init(azim=-60.0, elev=30.0)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
plt.show()
Would it be possible to use this to detrend my data, and how would I approach this?