Alternative solution to the original gist for fitting a plane or surface to a set of data points in Python, with optional weights on the data points.
The polynomial features are generated with PolynomialFeatures from scikit-learn, so they need not be coded manually.
Additional information is in my answer to the original gist.
Gist jensleitloff/f8c253ca8fb68cfabfff5b0cf1353429 (last active June 16, 2023 08:34)
[Python] Fitting plane/surface to a set of data points with optional weights
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
import matplotlib.pyplot as plt

USE_WEIGHTS = True

# some 3-dim points
mean = np.array([0.0, 0.0, 0.0])
cov = np.array([[1.0, -0.5, 0.8], [-0.5, 1.1, 0.0], [0.8, 0.0, 1.0]])
data = np.random.multivariate_normal(mean, cov, 50)

if USE_WEIGHTS:
    # weights can't be negative
    w = np.abs(np.random.normal(loc=1, scale=1, size=50))
else:
    w = np.ones(shape=50)

# regular grid covering the domain of the data
X, Y = np.meshgrid(np.arange(-3.0, 3.0, 0.5), np.arange(-3.0, 3.0, 0.5))
XX = X.flatten()
YY = Y.flatten()

order = 2  # 1: linear, 2: quadratic
# fit z = f(x, y): the first two columns are the inputs, the last column is the target;
# the "linearregression__" prefix routes the weights to the LinearRegression step
model = make_pipeline(PolynomialFeatures(degree=order), LinearRegression())
model.fit(data[:, :2], data[:, -1], linearregression__sample_weight=w)
Z = model.predict(np.c_[XX, YY]).reshape(X.shape)

# plot points and fitted surface
ax = plt.figure().add_subplot(projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, alpha=0.2)
ax.scatter(data[:, 0], data[:, 1], data[:, 2], c='r', s=50)
plt.xlabel('X')
plt.ylabel('Y')
ax.set_zlabel('Z')
ax.axis('equal')  # Matplotlib versions before 3.6 do not support 'equal' on 3D axes
ax.axis('tight')
plt.show()
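The `linearregression__sample_weight=w` argument hands the weights to the `LinearRegression` step, which then minimizes the weighted sum of squared residuals, sum_i w_i (z_i - zhat_i)^2. As a hedged sanity check, the sketch below (synthetic data and a fixed seed, neither taken from the gist) compares the pipeline's predictions with a weighted least-squares fit done by hand: every row of the polynomial design matrix and every target value is scaled by sqrt(w_i), and the resulting ordinary least-squares problem is solved with `np.linalg.lstsq`.

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline

rng = np.random.default_rng(0)  # fixed seed so the check is reproducible
n = 50
x = rng.uniform(-3, 3, n)
y = rng.uniform(-3, 3, n)
# synthetic quadratic surface plus a little noise (illustrative only)
z = 1.0 + 0.5 * x - 2.0 * y + 0.3 * x * y + rng.normal(scale=0.1, size=n)
w = np.abs(rng.normal(loc=1, scale=1, size=n))  # non-negative weights
xy = np.column_stack([x, y])

# weighted fit through the pipeline, as in the script
model = make_pipeline(PolynomialFeatures(degree=2), LinearRegression())
model.fit(xy, z, linearregression__sample_weight=w)

# weighted least squares by hand: scale rows by sqrt(w_i), then solve plain least squares
A = PolynomialFeatures(degree=2).fit_transform(xy)  # includes the column of ones
sw = np.sqrt(w)
beta, *_ = np.linalg.lstsq(A * sw[:, None], z * sw, rcond=None)

# compare predictions on a few new points
grid = np.column_stack([np.linspace(-3, 3, 7), np.linspace(3, -3, 7)])
pred_pipeline = model.predict(grid)
pred_manual = PolynomialFeatures(degree=2).fit_transform(grid) @ beta
print(np.allclose(pred_pipeline, pred_manual))  # True
```

Predictions rather than coefficients are compared because `LinearRegression` fits its own intercept, so the coefficient of the bias column from `PolynomialFeatures` ends up at zero and the constant term appears in `intercept_` instead.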