Created
November 22, 2021 04:32
-
-
Save BeBeBerr/e1b25e9f93ceb8260bb0868f1110a11f to your computer and use it in GitHub Desktop.
Curve fitting using g2o and Ceres
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
#include <opencv2/core/core.hpp> | |
#include <ceres/ceres.h> | |
using namespace std; | |
/// Cost functor for fitting the model y = exp(a*x^2 + b*x + c) to one sample.
///
/// Each instance stores a single observed (x, y) pair. Calling it writes the
/// scalar residual  y - exp(a*x^2 + b*x + c)  for the supplied parameters.
struct Curve_Fit {
    Curve_Fit(double x, double y) : _x(x), _y(y) {}

    /// Evaluates the residual for this sample.
    ///
    /// @param pa       pointer to the current value of a
    /// @param pb       pointer to the current value of b
    /// @param pc       pointer to the current value of c
    /// @param residual output; receives the single residual value
    /// @return always true
    template <typename T>
    bool operator()(const T *const pa, const T *const pb, const T *const pc, T *residual) const {
        // Unqualified call: std::exp serves the built-in floating-point types,
        // while overloads for other scalar types (e.g. the one in namespace
        // ceres) are picked up through argument-dependent lookup.
        using std::exp;

        const T sample_x = T(_x);
        const T exponent = (*pa) * sample_x * sample_x + (*pb) * sample_x + (*pc);
        residual[0] = T(_y) - exp(exponent);
        return true;
    }

private:
    // The observed sample this functor was built for.
    const double _x, _y;
};
int main() { | |
double ar = 1.0, br = 2.0, cr = 1.0; | |
double ae = 2.0, be = -1.0, ce = 5.0; | |
int N = 100; | |
double w_sigma = 1.0; | |
double inv_sigma = 1.0 / w_sigma; | |
cv::RNG rng; | |
vector<double> x_data, y_data; | |
for (int i = 0; i < N; i++) { | |
double x = i / 100.0; | |
x_data.push_back(x); | |
double y = exp(ar * x * x + br * x + cr) + rng.gaussian(w_sigma * w_sigma); | |
y_data.push_back(y); | |
cout << "[" << x << ", " << y << "], " << endl; | |
} | |
ceres::Problem problem; | |
for (int i = 0; i < N; i++) { | |
problem.AddResidualBlock( | |
new ceres::AutoDiffCostFunction<Curve_Fit, 1, 1, 1, 1>(new Curve_Fit(x_data[i], y_data[i])), | |
nullptr, | |
&ae, &be, &ce | |
); | |
} | |
ceres::Solver::Options options; | |
options.linear_solver_type = ceres::DENSE_NORMAL_CHOLESKY; | |
options.minimizer_progress_to_stdout = true; | |
ceres::Solver::Summary summary; | |
ceres::Solve(options, &problem, &summary); | |
cout << summary.BriefReport() << endl; | |
cout << ae << " " << be << " " << ce << endl; | |
return 0; | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# cmake version to be used
# The <min>...<max> form keeps 3.0 as the minimum while letting newer CMake
# releases (which reject a bare pre-3.5 policy version) configure the project.
cmake_minimum_required( VERSION 3.0...3.10 )

# message("system: ${CMAKE_LIBRARY_PATH}")

# project name
project(helloworld)

# target
add_executable( g2o_curve g2o_curvefit.cc )

# external libs
find_package(Ceres REQUIRED)
find_package(OpenCV REQUIRED)
find_package(G2O REQUIRED)

# Locate the two g2o libraries by name instead of hard-coding absolute
# .dylib paths; the original Homebrew directory is kept as a search hint.
find_library(G2O_CORE_LIBRARY
    NAMES g2o_core
    HINTS /opt/homebrew/Cellar/g2o/20201223/lib
)
find_library(G2O_STUFF_LIBRARY
    NAMES g2o_stuff
    HINTS /opt/homebrew/Cellar/g2o/20201223/lib
)
if(NOT G2O_CORE_LIBRARY OR NOT G2O_STUFF_LIBRARY)
    message(FATAL_ERROR "Could not find the g2o_core / g2o_stuff libraries")
endif()

target_include_directories(g2o_curve
    PRIVATE
        ${CERES_INCLUDE_DIRS}
        ${OpenCV_INCLUDE_DIRS}
        ${G2O_INCLUDE_DIRS}
)

target_link_libraries(g2o_curve
    PRIVATE
        ${CERES_LIBRARIES}
        ${OpenCV_LIBRARIES}
        ${G2O_CORE_LIBRARY}
        ${G2O_STUFF_LIBRARY}
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
#include <g2o/core/base_vertex.h> | |
#include <g2o/core/base_unary_edge.h> | |
#include <g2o/core/block_solver.h> | |
#include <g2o/core/optimization_algorithm_levenberg.h> | |
#include <g2o/core/optimization_algorithm_gauss_newton.h> | |
#include <g2o/core/optimization_algorithm_dogleg.h> | |
#include <g2o/solvers/dense/linear_solver_dense.h> | |
#include <Eigen/Core> | |
#include <opencv2/core/core.hpp> | |
#include <cmath> | |
using namespace std; | |
/// Optimization vertex whose estimate is the 3-vector of curve coefficients
/// (a, b, c).
class CurveFittingVertex : public g2o::BaseVertex<3, Eigen::Vector3d> {
public:
    EIGEN_MAKE_ALIGNED_OPERATOR_NEW

    /// Resets the estimate to the zero vector.
    virtual void setToOriginImpl() override {
        _estimate << 0, 0, 0;
    }

    /// Applies an additive update: the three doubles at `update` are added
    /// component-wise to the current estimate.
    virtual void oplusImpl(const double *update) override {
        _estimate += Eigen::Vector3d(update);
    }

    /// Serialization is not implemented. The previous empty body flowed off
    /// the end of a bool-returning function (undefined behaviour); report
    /// failure explicitly instead.
    virtual bool read(istream & /*in*/) override { return false; }

    /// Serialization is not implemented; always returns false.
    virtual bool write(ostream & /*out*/) const override { return false; }
};
/// Unary edge tying one scalar measurement y to a CurveFittingVertex.
/// The error is  y - exp(a*x^2 + b*x + c)  for the x stored in this edge.
class CurveFittingEdge : public g2o::BaseUnaryEdge<1, double, CurveFittingVertex> {
public:
    EIGEN_MAKE_ALIGNED_OPERATOR_NEW

    /// @param x abscissa of the sample this edge represents
    CurveFittingEdge(double x): BaseUnaryEdge(), _x(x) {}

    /// Computes the residual between the stored measurement and the model
    /// evaluated with the attached vertex's current (a, b, c) estimate.
    virtual void computeError() override {
        const CurveFittingVertex *v = static_cast<const CurveFittingVertex *>(_vertices[0]);
        // Bind by reference: the coefficients are only read here.
        const Eigen::Vector3d &abc = v->estimate();
        _error(0, 0) = _measurement - std::exp(abc(0, 0) * _x * _x + abc(1, 0) * _x + abc(2, 0));
    }

    /// Serialization is not implemented. The previous empty body flowed off
    /// the end of a bool-returning function (undefined behaviour); report
    /// failure explicitly instead.
    virtual bool read(istream & /*in*/) override { return false; }

    /// Serialization is not implemented; always returns false.
    virtual bool write(ostream & /*out*/) const override { return false; }

public:
    double _x;  // abscissa of the sample; the ordinate is held in _measurement
};
int main() { | |
double ar = 1.0, br = 2.0, cr = 1.0; | |
double ae = 2.0, be = -1.0, ce = 5.0; | |
int N = 100; | |
double w_sigma = 1.0; | |
double inv_sigma = 1.0 / w_sigma; | |
cv::RNG rng; | |
vector<double> x_data, y_data; | |
for (int i = 0; i < N; i++) { | |
double x = i / 100.0; | |
x_data.push_back(x); | |
double y = exp(ar * x * x + br * x + cr) + rng.gaussian(w_sigma * w_sigma); | |
y_data.push_back(y); | |
cout << "[" << x << ", " << y << "], " << endl; | |
} | |
typedef g2o::BlockSolver<g2o::BlockSolverTraits<3, 1>> BlockSolverType; | |
typedef g2o::LinearSolverDense<BlockSolverType::PoseMatrixType> LinearSolverType; | |
auto solver = new g2o::OptimizationAlgorithmGaussNewton( | |
g2o::make_unique<BlockSolverType>(g2o::make_unique<LinearSolverType>()) | |
); | |
g2o::SparseOptimizer optimizer; | |
optimizer.setAlgorithm(solver); | |
optimizer.setVerbose(true); | |
CurveFittingVertex *v = new CurveFittingVertex(); | |
v->setEstimate(Eigen::Vector3d(ae, be, ce)); | |
v->setId(0); | |
optimizer.addVertex(v); | |
for (int i=0; i<N; i++) { | |
CurveFittingEdge *edge = new CurveFittingEdge(x_data[i]); | |
edge->setId(i); | |
edge->setVertex(0, v); | |
edge->setMeasurement(y_data[i]); | |
edge->setInformation(Eigen::Matrix<double, 1, 1>::Identity() * 1 / (w_sigma * w_sigma)); | |
optimizer.addEdge(edge); | |
} | |
optimizer.initializeOptimization(); | |
optimizer.optimize(10); | |
Eigen::Vector3d abc_estimate = v->estimate(); | |
cout << abc_estimate.transpose() << endl; | |
return 0; | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from matplotlib import pyplot as plt | |
data = [ | |
[0, 2.71828], | |
[0.01, 2.93161], | |
[0.02, 2.12942], | |
[0.03, 2.46037], | |
[0.04, 4.18814], | |
[0.05, 2.73368], | |
[0.06, 2.42751], | |
[0.07, 3.44729], | |
[0.08, 3.72543], | |
[0.09, 2.1358], | |
[0.1, 4.12333], | |
[0.11, 3.38199], | |
[0.12, 4.81164], | |
[0.13, 1.62582], | |
[0.14, 1.76862], | |
[0.15, 3.21555], | |
[0.16, 3.0922], | |
[0.17, 5.82752], | |
[0.18, 4.29855], | |
[0.19, 2.74081], | |
[0.2, 5.75724], | |
[0.21, 3.53729], | |
[0.22, 1.95514], | |
[0.23, 2.99195], | |
[0.24, 3.28739], | |
[0.25, 4.70749], | |
[0.26, 6.24365], | |
[0.27, 5.81645], | |
[0.28, 4.88402], | |
[0.29, 4.75991], | |
[0.3, 7.25246], | |
[0.31, 5.92933], | |
[0.32, 7.00306], | |
[0.33, 5.22286], | |
[0.34, 5.16179], | |
[0.35, 7.26191], | |
[0.36, 6.40545], | |
[0.37, 6.25549], | |
[0.38, 6.56094], | |
[0.39, 6.53523], | |
[0.4, 8.14891], | |
[0.41, 7.77616], | |
[0.42, 7.40141], | |
[0.43, 8.75638], | |
[0.44, 7.20606], | |
[0.45, 7.57795], | |
[0.46, 8.21564], | |
[0.47, 9.84032], | |
[0.48, 6.96725], | |
[0.49, 9.90619], | |
[0.5, 9.27125], | |
[0.51, 9.87567], | |
[0.52, 10.3412], | |
[0.53, 9.55315], | |
[0.54, 11.3635], | |
[0.55, 10.8815], | |
[0.56, 13.0648], | |
[0.57, 11.4756], | |
[0.58, 11.337], | |
[0.59, 13.2393], | |
[0.6, 13.5299], | |
[0.61, 14.0441], | |
[0.62, 13.31], | |
[0.63, 13.672], | |
[0.64, 14.8504], | |
[0.65, 14.2599], | |
[0.66, 14.7724], | |
[0.67, 17.4339], | |
[0.68, 17.4632], | |
[0.69, 17.7598], | |
[0.7, 16.8223], | |
[0.71, 19.9468], | |
[0.72, 20.5446], | |
[0.73, 21.3767], | |
[0.74, 20.1435], | |
[0.75, 20.3088], | |
[0.76, 23.2543], | |
[0.77, 23.4349], | |
[0.78, 22.8706], | |
[0.79, 24.094], | |
[0.8, 25.4183], | |
[0.81, 25.5237], | |
[0.82, 27.9738], | |
[0.83, 28.5861], | |
[0.84, 29.5703], | |
[0.85, 29.6744], | |
[0.86, 32.667], | |
[0.87, 34.2698], | |
[0.88, 33.5124], | |
[0.89, 36.1479], | |
[0.9, 39.2485], | |
[0.91, 40.988], | |
[0.92, 41.5716], | |
[0.93, 41.3686], | |
[0.94, 44.285], | |
[0.95, 42.8312], | |
[0.96, 47.7941], | |
[0.97, 48.5931], | |
[0.98, 51.8487], | |
[0.99, 51.0258], | |
] | |
# Turn the list of [x, y] samples into a 2-column array and scatter-plot it.
data = np.array(data)
sample_x, sample_y = data[:, 0], data[:, 1]
plt.scatter(sample_x, sample_y)

# Coefficients of the curve y = exp(a*x^2 + b*x + c) drawn over the samples.
a, b, c = 0.890908, 2.1719, 0.943628

# Evaluate the curve at 100 evenly spaced points on [0, 1] and draw it in red.
x = np.linspace(0, 1, 100)
exponent = a * x * x + b * x + c
y = np.exp(exponent)
plt.plot(x, y, 'r')

plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment