Last active
April 2, 2024 10:11
-
-
Save jgillis/e5eb8dd1450417cb308ae440ee69dfb4 to your computer and use it in GitHub Desktop.
bspline fitting
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Fit a 3-D tensor-product B-spline to gridded samples of (x + y) * z.
#
# Approach: build the numeric matrix that maps spline coefficients to spline
# outputs at the sample points (MX.bspline_dual), solve a least-squares
# problem for the coefficients, then wrap the fitted spline in a Function.
from casadi import *
# Imported explicitly (and after the star import) so that `np` is guaranteed
# to be numpy rather than depending on what `from casadi import *` exports.
import numpy as np

# Generate 3D data to fit
x = np.linspace(0, 2, 10)
y = np.linspace(2, 4, 10)
z = np.linspace(3, 5, 10)
X, Y, Z = np.meshgrid(x, y, z)
# One sample point per row: columns are the x, y and z coordinates.
# The points and the data D below are flattened from the same grids, so
# their ordering agrees.
xyz_flat = np.vstack((X.ravel(), Y.ravel(), Z.ravel())).T
D = (X + Y) * Z

# Define type of BSpline fit
degree = [1, 2, 3]
n_dims = len(degree)
knots_precursor = [[0, 0.5, 1, 1.5, 2], [2, 3, 4], [3, 4, 5]]
# Now with multiplicities: repeat the first and last knot `d` extra times
# in each dimension
knots = [
    ([p[0]] * d) + p + ([p[-1]] * d)
    for p, d in zip(knots_precursor, degree)
]

# The most efficient way to fit a large amount of data to a bspline
# First, figure out the sparse numeric mapping between coefficients and outputs
# (the sample points are passed as one flat vector: x0, y0, z0, x1, y1, z1, ...)
J = MX.bspline_dual(xyz_flat.ravel(), knots, degree)

# Construct and solve a fitting problem to figure out the coefficients
opti = Opti()
# One decision variable per column of J, i.e. per spline coefficient
coeff = opti.variable(J.size2())
opti.minimize(sumsqr(J @ coeff - D.ravel()))
# Note: if you want to incorporate spline derivatives into the
# objective/constraints, there is manual work needed
opti.solver("ipopt")
sol = opti.solve()
coeff_fitted = sol.value(coeff)
# Numeric fitting is finished

# Now construct a BSpline CasADi Function
# (x and y are rebound here: from this point on they are symbolic, not data)
x = MX.sym("x", n_dims)
y = bspline(x, DM(coeff_fitted), knots, degree, 1)
f = Function('spline', [x], [y])

# Test: print the spline value next to the exact value of (x + y) * z
print(f(vertcat(0, 2, 3)), (0 + 2) * 3)
print(f(vertcat(0.17, 2.8, 3.5)), (0.17 + 2.8) * 3.5)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Fit a 3-D tensor-product B-spline to gridded samples of (x + y) * z.
#
# Approach: build a BSpline Function that is differentiable with respect to
# its coefficients (symbolic coefficients, inlined evaluation), fit the
# coefficients with Opti, then rebuild the spline with the fitted numbers.
from casadi import *
# Imported explicitly (and after the star import) so that `np` is guaranteed
# to be numpy rather than depending on what `from casadi import *` exports.
import numpy as np

# Generate 3D data to fit
x = np.linspace(0, 2, 10)
y = np.linspace(2, 4, 10)
z = np.linspace(3, 5, 10)
X, Y, Z = np.meshgrid(x, y, z)
# One sample point per row: columns are the x, y and z coordinates.
# The points and the data D below are flattened from the same grids, so
# their ordering agrees.
xyz_flat = np.vstack((X.ravel(), Y.ravel(), Z.ravel())).T
D = (X + Y) * Z

# Define type of BSpline fit
degree = [1, 2, 3]
n_dims = len(degree)
knots_precursor = [[0, 0.5, 1, 1.5, 2], [2, 3, 4], [3, 4, 5]]
# Now with multiplicities: repeat the first and last knot `d` extra times
# in each dimension
knots = [
    ([p[0]] * d) + p + ([p[-1]] * d)
    for p, d in zip(knots_precursor, degree)
]

# Number of coefficients of the tensor-product spline: each dimension
# contributes (number of knots - degree - 1) basis functions.
# For the knots/degree above this is 5 * 4 * 5 = 100; deriving it keeps the
# script valid when knots or degree are changed.
n_coeff = int(np.prod([len(k) - d - 1 for k, d in zip(knots, degree)]))

# Construct a differentiable BSpline CasADi Function
x = MX.sym("x", n_dims)
C = MX.sym("C", n_coeff)
y = bspline(x, C, knots, degree, 1, {"inline": True})
f = Function("spline", [x, C], [y])

# Fit in opti
opti = Opti()
coeff = opti.variable(n_coeff)
# f is called with all sample points at once (one point per column)
opti.minimize(sumsqr(f(xyz_flat.T, coeff).T - D.ravel()))
# Note: if you want to incorporate spline derivatives into the
# objective/constraints, you can simply use CasADi AD on f
opti.solver("ipopt")
sol = opti.solve()
coeff_fitted = sol.value(coeff)
# Numeric fitting is finished

# Now construct a BSpline CasADi Function
# (x, y and f are rebound: f now takes only x and uses the fitted coefficients)
x = MX.sym("x", n_dims)
y = bspline(x, DM(coeff_fitted), knots, degree, 1)
f = Function('spline', [x], [y])

# Test: print the spline value next to the exact value of (x + y) * z
print(f(vertcat(0, 2, 3)), (0 + 2) * 3)
print(f(vertcat(0.17, 2.8, 3.5)), (0.17 + 2.8) * 3.5)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment