Skip to content

Instantly share code, notes, and snippets.

@alexlib
Forked from amroamroamro/README.md
Created August 27, 2020 13:58
Show Gist options
  • Select an option

  • Save alexlib/05abf4e42a0341047b1da12e69c2f3a3 to your computer and use it in GitHub Desktop.

Select an option

Save alexlib/05abf4e42a0341047b1da12e69c2f3a3 to your computer and use it in GitHub Desktop.
[Python] Fitting plane/surface to a set of data points

Python version of the MATLAB code in this Stack Overflow post: http://stackoverflow.com/a/18648210/97160

The example shows how to determine the best-fit plane/surface (1st or higher order polynomial) over a set of three-dimensional points.

Implemented in Python + NumPy + SciPy + matplotlib.

quadratic_surface

#!/usr/bin/env python
"""Fit a plane or a quadratic surface to scattered 3-D points and plot it.

Random points are drawn from a correlated 3-D Gaussian, ``z`` is regressed on
``x`` and ``y`` by linear least squares, and the fitted surface is drawn over
a regular grid together with the sample points.
"""
import numpy as np
import scipy.linalg
# Axes3D is not referenced directly; importing it registers the '3d'
# projection on older Matplotlib releases.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
import matplotlib.pyplot as plt

# some 3-dim points: 50 samples, one row per point, columns are x, y, z
mean = np.array([0.0, 0.0, 0.0])
cov = np.array([[1.0, -0.5, 0.8], [-0.5, 1.1, 0.0], [0.8, 0.0, 1.0]])
data = np.random.multivariate_normal(mean, cov, 50)

# regular grid on which the fitted surface is evaluated
X, Y = np.meshgrid(np.arange(-3.0, 3.0, 0.5), np.arange(-3.0, 3.0, 0.5))
XX = X.flatten()
YY = Y.flatten()

order = 1    # 1: linear, 2: quadratic
if order == 1:
    # best-fit linear plane: z = C[0]*x + C[1]*y + C[2]
    A = np.c_[data[:, 0], data[:, 1], np.ones(data.shape[0])]
    C, _, _, _ = scipy.linalg.lstsq(A, data[:, 2])    # coefficients

    # evaluate it on grid
    Z = C[0]*X + C[1]*Y + C[2]

    # or expressed using matrix/vector product
    # Z = np.dot(np.c_[XX, YY, np.ones(XX.shape)], C).reshape(X.shape)
elif order == 2:
    # best-fit quadratic surface; columns are [1, x, y, x*y, x^2, y^2]
    A = np.c_[np.ones(data.shape[0]), data[:, :2],
              np.prod(data[:, :2], axis=1), data[:, :2]**2]
    C, _, _, _ = scipy.linalg.lstsq(A, data[:, 2])

    # evaluate it on a grid, using the same column order as the fit
    Z = np.dot(np.c_[np.ones(XX.shape), XX, YY, XX*YY, XX**2, YY**2],
               C).reshape(X.shape)
else:
    # fail early with a clear message instead of a NameError on Z below
    raise ValueError(
        "order must be 1 (linear) or 2 (quadratic), got %r" % (order,))

# plot points and fitted surface
fig = plt.figure()
# fig.gca(projection='3d') no longer works on current Matplotlib;
# add_subplot creates the same single 3-D axes.
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, alpha=0.2)
ax.scatter(data[:, 0], data[:, 1], data[:, 2], c='r', s=50)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
try:
    ax.axis('equal')
except NotImplementedError:
    # Some Matplotlib releases cannot set an equal aspect on 3-D axes;
    # equal scaling is cosmetic here, so carry on without it.
    pass
ax.axis('tight')
plt.show()
@alexlib
Copy link
Copy Markdown
Author

alexlib commented Aug 27, 2020

#!/usr/bin/env python
"""Fit a polynomial surface z = f(x, y) to scattered 3-D points and plot it.

Random points are drawn from a correlated 3-D Gaussian, ``z`` is regressed on
the monomials of ``x`` and ``y`` up to total degree ``order`` by linear least
squares, and the fitted surface is drawn over a regular grid together with
the sample points.
"""

import numpy as np
import scipy.linalg
# Axes3D is not referenced directly; importing it registers the '3d'
# projection on older Matplotlib releases.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
import matplotlib.pyplot as plt


def poly_design_matrix(x, y, order):
    """Return the design matrix of all monomials x**i * y**j with i + j <= order.

    Columns are grouped by total degree, e.g. for ``order == 3``:
    ``[1, x, y, x^2, x*y, y^2, x^3, x^2*y, x*y^2, y^3]``.

    Parameters
    ----------
    x, y : array_like
        1-D coordinate arrays of equal length.
    order : int
        Maximum total degree of the polynomial (0 or greater).

    Returns
    -------
    numpy.ndarray
        Array with one row per point and ``(order + 1) * (order + 2) / 2``
        columns.

    Raises
    ------
    ValueError
        If ``order`` is negative.
    """
    if order < 0:
        raise ValueError("order must be non-negative, got %r" % (order,))
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    return np.column_stack([x**(degree - j) * y**j
                            for degree in range(order + 1)
                            for j in range(degree + 1)])


# some 3-dim points: 50 samples, one row per point, columns are x, y, z
mean = np.array([0.0, 0.0, 0.0])
cov = np.array([[1.0, -0.5, 0.8], [-0.5, 1.1, 0.0], [0.8, 0.0, 1.0]])
data = np.random.multivariate_normal(mean, cov, 50)

# regular grid on which the fitted surface is evaluated
X, Y = np.meshgrid(np.arange(-3.0, 3.0, 0.5), np.arange(-3.0, 3.0, 0.5))
XX = X.flatten()
YY = Y.flatten()

order = 3    # polynomial degree -- 1: linear, 2: quadratic, 3: cubic, ...

# Least-squares fit. The same helper builds the matrix for the fit and for
# the grid evaluation, so the two term lists cannot drift apart.
A = poly_design_matrix(data[:, 0], data[:, 1], order)
C, _, _, _ = scipy.linalg.lstsq(A, data[:, 2])    # coefficients

# evaluate the fitted polynomial on the grid
Z = np.dot(poly_design_matrix(XX, YY, order), C).reshape(X.shape)

# plot points and fitted surface
fig = plt.figure()
# fig.gca(projection='3d') no longer works on current Matplotlib;
# add_subplot creates the same single 3-D axes.
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, alpha=0.2)
ax.scatter(data[:, 0], data[:, 1], data[:, 2], c='r', s=50)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
# NOTE: ax.axis('equal') stays disabled, as in the original; some Matplotlib
# releases raise NotImplementedError for it on 3-D axes.
ax.axis('tight')
plt.show()

@ajd1212
Copy link
Copy Markdown

ajd1212 commented Dec 13, 2020

elif order == 3:
A = np.c_[np.ones(data.shape[0]), data[:,:2], data[:,0]**2, np.prod(data[:,:2], axis=1),
data[:,1]**2, data[:,0]**3, np.prod(np.c_[data[:,0]**2,data[:,1]],axis=1),
np.prod(np.c_[data[:,0],data[:,1]**2],axis=1), data[:,2]**3]

I think the last term should be data[:,1]**3].

@alexlib
Copy link
Copy Markdown
Author

alexlib commented Jan 26, 2022

Thanks. fixed

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment