#!/usr/bin/env python
|
|
|
import numpy as np |
|
import scipy.linalg |
|
from mpl_toolkits.mplot3d import Axes3D |
|
import matplotlib.pyplot as plt |
|
|
|
# Fit a linear plane or a quadratic surface z = f(x, y) to scattered 3-D points
# by linear least squares, then plot the points together with the fitted surface.


def _design_matrix(x, y, order):
    """Return the least-squares design matrix for a surface of the given order.

    Shared by fitting and evaluation so the column layout (and therefore the
    meaning of the fitted coefficients C) can never get out of sync:

    * order 1: columns ``[x, y, 1]``
      -> z = C0*x + C1*y + C2
    * order 2: columns ``[1, x, y, x*y, x**2, y**2]``
      -> z = C0 + C1*x + C2*y + C3*x*y + C4*x**2 + C5*y**2

    *x* and *y* are 1-D arrays of equal length.

    Raises:
        ValueError: if *order* is not 1 or 2.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    ones = np.ones_like(x)
    if order == 1:
        return np.column_stack([x, y, ones])
    if order == 2:
        return np.column_stack([ones, x, y, x * y, x ** 2, y ** 2])
    raise ValueError(
        "order must be 1 (linear) or 2 (quadratic), got {!r}".format(order))


def fit_surface(points, order=1):
    """Fit z = f(x, y) to *points* by linear least squares.

    Args:
        points: array-like of shape (n, 3); columns are x, y, z.
        order: 1 for a best-fit plane, 2 for a full quadratic surface.

    Returns:
        1-D coefficient array, laid out as documented in ``_design_matrix``.

    Raises:
        ValueError: if *points* is not (n, 3) or *order* is unsupported.
    """
    pts = np.asarray(points, dtype=float)
    if pts.ndim != 2 or pts.shape[1] != 3:
        raise ValueError(
            "points must have shape (n, 3), got {}".format(pts.shape))
    design = _design_matrix(pts[:, 0], pts[:, 1], order)
    # lstsq returns (solution, residues, rank, singular values); keep the solution.
    return scipy.linalg.lstsq(design, pts[:, 2])[0]


def evaluate_surface(coeffs, X, Y, order=1):
    """Evaluate a surface fitted by ``fit_surface`` on coordinate arrays X, Y.

    *X* and *Y* may have any (equal) shape, e.g. the output of ``np.meshgrid``;
    the returned Z has that same shape.
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    design = _design_matrix(X.ravel(), Y.ravel(), order)
    return np.dot(design, coeffs).reshape(X.shape)


def make_sample_data(n_points=50, seed=None):
    """Draw *n_points* correlated 3-D samples from a zero-mean normal distribution.

    Pass an integer *seed* for reproducible data; ``None`` draws fresh entropy.
    Returns an array of shape (n_points, 3).
    """
    mean = np.array([0.0, 0.0, 0.0])
    cov = np.array([[1.0, -0.5, 0.8],
                    [-0.5, 1.1, 0.0],
                    [0.8, 0.0, 1.0]])
    rng = np.random.RandomState(seed)
    return rng.multivariate_normal(mean, cov, n_points)


def make_grid(points, num=20):
    """Return a ``num x num`` meshgrid (X, Y) spanning the x/y extent of *points*.

    Derived from the data (inclusive of both ends) rather than hard-coded, so
    the plotted surface always covers every sample.
    """
    pts = np.asarray(points, dtype=float)
    xs = np.linspace(pts[:, 0].min(), pts[:, 0].max(), num)
    ys = np.linspace(pts[:, 1].min(), pts[:, 1].max(), num)
    return np.meshgrid(xs, ys)


def _set_equal_limits(ax, xs, ys, zs):
    """Give the three axes of a 3-D plot a common scale around the given values.

    Replaces ``ax.axis('equal')``, which raises NotImplementedError on 3-D
    axes in matplotlib 3.1-3.5. Each axis gets limits of the same width
    (the largest data span), centred on that axis's own data.
    """
    spans = [(np.min(v), np.max(v)) for v in (xs, ys, zs)]
    # Fall back to a unit-wide window if every value is identical (zero span).
    half = max(hi - lo for lo, hi in spans) / 2.0 or 0.5
    for setter, (lo, hi) in zip((ax.set_xlim, ax.set_ylim, ax.set_zlim), spans):
        mid = (lo + hi) / 2.0
        setter(mid - half, mid + half)
    # Equal data ranges only mean equal scaling if the drawn box is a cube;
    # set_box_aspect is only available on newer matplotlib versions.
    if hasattr(ax, "set_box_aspect"):
        ax.set_box_aspect((1, 1, 1))


def main(order=1):
    """Generate sample data, fit a surface (1: linear, 2: quadratic), and plot both."""
    data = make_sample_data()
    coeffs = fit_surface(data, order)
    X, Y = make_grid(data)
    Z = evaluate_surface(coeffs, X, Y, order)

    fig = plt.figure()
    # fig.gca(projection=...) was removed in matplotlib 3.6; add_subplot works
    # across versions (the Axes3D import registers the '3d' projection).
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(X, Y, Z, rstride=1, cstride=1, alpha=0.2)
    ax.scatter(data[:, 0], data[:, 1], data[:, 2], c='r', s=50)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    _set_equal_limits(ax,
                      np.concatenate([data[:, 0], X.ravel()]),
                      np.concatenate([data[:, 1], Y.ravel()]),
                      np.concatenate([data[:, 2], Z.ravel()]))
    plt.show()


if __name__ == "__main__":
    main(order=1)  # 1: linear, 2: quadratic