import matplotlib.pyplot as plt
import numpy as np
from scipy.linalg import lstsq


def generateData(n=30):
    """Sample a surface similar to MATLAB's ``peaks()`` on an n-by-n grid.

    The grid covers [-3, 3] along both axes.

    Parameters
    ----------
    n : int, optional
        Number of grid points per axis (default 30).

    Returns
    -------
    X, Y, Z : numpy.ndarray
        Three column vectors of shape ``(n * n, 1)``: the flattened grid
        coordinates and the surface height at each grid point.
    """
    g = np.linspace(-3.0, 3.0, n)
    X, Y = np.meshgrid(g, g)
    # Flatten to column vectors: one row per sample point.
    X, Y = X.reshape(-1, 1), Y.reshape(-1, 1)
    # Parentheses instead of backslash continuations, so stray blank lines
    # or trailing characters cannot silently break the expression.
    Z = (
        3 * (1 - X)**2 * np.exp(-X**2 - (Y + 1)**2)
        - 10 * (X / 5 - X**3 - Y**5) * np.exp(-X**2 - Y**2)
        - 1 / 3 * np.exp(-(X + 1)**2 - Y**2)
    )
    return X, Y, Z


def exp2model(e):
    """Render a list of exponent pairs as a human-readable polynomial model.

    Each pair ``(x, y)`` at position ``i`` becomes the term
    ``C[i]*X^x*Y^y``; a factor with exponent 1 is written without ``^1``
    and a factor with exponent 0 is omitted entirely.

    Parameters
    ----------
    e : iterable of (int, int)
        Exponents of X and Y for every term, in coefficient order.

    Returns
    -------
    str
        The terms joined by ``' + '`` (an empty string for no terms).
    """
    def power(symbol, exponent):
        # '' for exponent <= 0, bare symbol for 1, 'symbol^k' otherwise.
        if exponent > 1:
            return f'{symbol}^{exponent}'
        return symbol if exponent == 1 else ''

    def term(i, x, y):
        factors = [f'C[{i}]', power('X', x), power('Y', y)]
        # Joining only the non-empty factors places '*' exactly where needed.
        return '*'.join(f for f in factors if f)

    return ' + '.join(term(i, x, y) for i, (x, y) in enumerate(e))


# sample 3-D points from the peaks-like test surface (column vectors)
X, Y, Z = generateData()

# polynomial degree: 1=linear, 2=quadratic, 3=cubic, ..., nth degree
order = 11

# exponents of the design matrix: every (x, y) with x + y <= order, grouped
# by total degree and, within a degree, by increasing power of Y
e = [(n - y, y) for n in range(order + 1) for y in range(n + 1)]
eX = np.array([x for x, _ in e]).reshape(1, -1)   # row of X exponents
eY = np.array([y for _, y in e]).reshape(1, -1)   # row of Y exponents

# best-fit polynomial surface: column vectors broadcast against the exponent
# rows, giving one design-matrix column per monomial
A = (X ** eX) * (Y ** eY)
C, _, _, _ = lstsq(A, Z)   # coefficients

# R-squared from explicit residuals. The residues returned by lstsq are an
# empty array when the design matrix is rank deficient, which is plausible
# for high polynomial orders, so indexing them could raise IndexError.
ss_res = np.sum((Z - A @ C) ** 2)
ss_tot = np.sum((Z - Z.mean()) ** 2)
r2 = 1 - ss_res / ss_tot

# print summary
print(f'data = {Z.size}x3')
print(f'model = {exp2model(e)}')
print(f'coefficients =\n{C}')
print(f'R2 = {r2}')

# uniform grid covering the domain of the data
XX, YY = np.meshgrid(
    np.linspace(X.min(), X.max(), 20),
    np.linspace(Y.min(), Y.max(), 20),
)

# evaluate model on grid
A = (XX.reshape(-1, 1) ** eX) * (YY.reshape(-1, 1) ** eY)
ZZ = np.dot(A, C).reshape(XX.shape)

# plot points and fitted surface
ax = plt.figure().add_subplot(projection='3d')
ax.scatter(X, Y, Z, c='r', s=2)
ax.plot_surface(XX, YY, ZZ, rstride=1, cstride=1, alpha=0.2, linewidth=0.5, edgecolor='b')
ax.axis('tight')
ax.view_init(azim=-60.0, elev=30.0)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
plt.show()
Thanks MSchmidt99!
That's good. I added a bit so that users can run this code directly, as follows:
import numpy as np
from scipy.linalg import lstsq
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
def generateData(n=30):
    """Sample a surface similar to MATLAB's ``peaks()`` on an n-by-n grid.

    The grid covers [-3, 3] along both axes.

    Parameters
    ----------
    n : int, optional
        Number of grid points per axis (default 30).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n * n, 3)`` whose rows are ``[x, y, z]`` samples.
    """
    g = np.linspace(-3.0, 3.0, n)
    X, Y = np.meshgrid(g, g)
    X, Y = X.reshape(-1, 1), Y.reshape(-1, 1)
    # Parenthesised so the three terms form a single expression without
    # relying on backslash line continuations.
    Z = (
        3 * (1 - X)**2 * np.exp(-X**2 - (Y + 1)**2)
        - 10 * (X / 5 - X**3 - Y**5) * np.exp(-X**2 - Y**2)
        - 1 / 3 * np.exp(-(X + 1)**2 - Y**2)
    )
    # Stack the column vectors side by side instead of converting each
    # one-element row with float(), which newer NumPy versions deprecate.
    return np.hstack((X, Y, Z))
def curve(X, Y, coord, n):
    """Fit a degree-``n`` polynomial surface to points and evaluate it on a grid.

    Parameters
    ----------
    X, Y : numpy.ndarray
        Meshgrid arrays on which the fitted surface is evaluated.
    coord : numpy.ndarray
        Array of coordinate rows ``[[x1, y1, z1], [x2, y2, z2], ...]``;
        columns 0 and 1 are the inputs and column 2 is the fitted height.
    n : int
        Total degree of the polynomial.

    Returns
    -------
    numpy.ndarray
        Fitted heights ``Z`` with the same shape as ``X``.
    """
    XX = X.flatten()
    YY = Y.flatten()

    # Least-squares coefficients of the polynomial through the sample points.
    A = XYchooseN(coord[:, 0], coord[:, 1], n)
    C, _, _, _ = lstsq(A, coord[:, 2])

    # Evaluate the fitted polynomial on the flattened grid, then restore shape.
    return np.dot(XYchooseN(XX, YY, n), C).reshape(X.shape)


def XYchooseN(x, y, n):
    """Build the design matrix of all monomials ``x**k * y**i`` with ``k + i <= n``.

    Binomial coefficients are not applied; every column is a bare monomial.

    Parameters
    ----------
    x, y : array-like
        One-dimensional sequences of equal length with the point coordinates.
    n : int
        Maximum total degree.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(x), (n + 1) * (n + 2) // 2)``. Columns are
        ordered by increasing power of ``y`` and, within one power of ``y``,
        by increasing power of ``x``.
    """
    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y, dtype=float).ravel()
    # (x exponent, y exponent) for every column, in output order.
    exponents = [(k, i) for i in range(n + 1) for k in range(n + 1 - i)]
    x_exp = np.array([k for k, _ in exponents])
    y_exp = np.array([i for _, i in exponents])
    # Column vectors broadcast against the exponent rows: one row per point.
    return (x[:, np.newaxis] ** x_exp) * (y[:, np.newaxis] ** y_exp)
# sample 3-D points from the test surface, one [x, y, z] row per point
data = generateData()
x_data, y_data, z_data = data[:, 0], data[:, 1], data[:, 2]

# Surface meshgrid: a uniform grid over the domain of the data. Building the
# grid from the raw sample columns would give an unsorted grid with
# len(data)**2 nodes, many of them duplicates.
X, Y = np.meshgrid(
    np.linspace(x_data.min(), x_data.max(), 30),
    np.linspace(y_data.min(), y_data.max(), 30),
)
coord = data
Z = curve(X, Y, coord, 3)

# TODO: cmap by avg from (x1,y1) to (x2,y2) of |Z height - scatter height|

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')  # create a 3D axis
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, alpha=0.5, color='r')
ax.scatter(data[:, 0], data[:, 1], data[:, 2], c='green', s=50)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.axis('equal')
ax.axis('tight')
plt.show()