Last active
September 22, 2021 01:36
-
-
Save fangzhou-xie/0918123ba17897399ceb0608b921f645 to your computer and use it in GitHub Desktop.
Eigen Sinkhorn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
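// Sinkhorn algorithm (entropic-regularized optimal transport) written with Eigen; its output is meant to be compared against the result from POT (Python Optimal Transport).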
#include "Eigen/Eigen" | |
// #include "unsupported/Eigen/MatrixFunctions" | |
#include <iostream> | |
void sinkhorn(Eigen::VectorXd a, Eigen::VectorXd b, Eigen::MatrixXd M, | |
const double reg, const int numItermax = 1000, | |
const double stopThr = 1e-9) { | |
// only compute 1d to 1d case | |
// init u, v | |
Eigen::VectorXd u(a.rows()); | |
Eigen::VectorXd v(b.rows()); | |
// Eigen::MatrixXd K(M.rows(), M.cols()); | |
// Eigen::VectorXd uprev; | |
// Eigen::VectorXd vprev; | |
u.setOnes(); | |
v.setOnes(); | |
u = u / u.rows(); | |
v = v / b.rows(); | |
Eigen::MatrixXd K = (M / (-reg)).array().exp(); | |
Eigen::MatrixXd Kp = K.array().colwise() * a.array().cwiseInverse(); | |
unsigned int cpt = 0; | |
double err = 1.0; | |
Eigen::VectorXd temp1(v.rows()); | |
temp1.setOnes(); | |
while (err > stopThr & cpt < numItermax) { | |
Eigen::VectorXd uprev = u; | |
Eigen::VectorXd vprev = v; | |
Eigen::MatrixXd KtransposeU = K.transpose() * u; // n*1 | |
// std::cout << "shape of KtransposeU:" << KtransposeU.rows() << " " | |
// << KtransposeU.cols() << std::endl; | |
v = b.cwiseQuotient(K.transpose() * u); | |
// std::cout << "shape of v:" << v.rows() << " " << v.cols() << std::endl; | |
// std::cout << a.cwiseInverse() << std::endl; | |
// v = b / KtransposeU; | |
// v = KtransposeU.cwiseInverse() * b; // need broadcast here | |
// TODO: this is a broadcast multiply | |
// Eigen::MatrixXd Kp = K * a.cwiseInverse(); | |
// Kp.array().colwise() / | |
// Kp.colwise() *= a.cwiseInverse(); | |
// = a.cwiseInverse() * K.colwise(); | |
u = (Kp * v).cwiseInverse(); | |
if (u.array().isNaN().any() | v.array().isNaN().any() | | |
u.array().isInf().any() | v.array().isInf().any() | | |
(KtransposeU.array() == 0.0).any()) { | |
std::cout << "numerical error" << std::endl; | |
u = uprev; | |
v = vprev; | |
break; | |
} | |
// check error | |
if (cpt % 10 == 0) { | |
Eigen::VectorXd temp = | |
(u.asDiagonal() * K * v.asDiagonal()).transpose() * temp1; | |
err = (temp - b).norm(); | |
} | |
cpt += 1; | |
} | |
// u.resize(u.rows(), 1); | |
// v.resize(1, v.rows()); | |
Eigen::MatrixXd mat = | |
(K.array().colwise() * u.array()).rowwise() * v.array().transpose(); | |
// Eigen::MatrixXd mat = (u.array() * K.array().colwise()).array().rowwise() * | |
// v.array().transpose(); | |
std::cout << mat << std::endl; | |
// Eigen::MatrixXd mat = u.resize(u.rows(), 1) * K * v.resize(1, v.cols()); | |
} | |
int main() { | |
Eigen::VectorXd a(2); | |
a << 0.5, 0.5; | |
Eigen::VectorXd b(2); | |
b << 0.5, 0.5; | |
Eigen::MatrixXd M(2, 2); | |
M << 0.0, 1.0, 1.0, 0.0; | |
double reg = 1.0; | |
// sinkhorn(a, b, M, reg); | |
// try out broadcasting | |
// std::cout << M.array().colwise() * a.array() << std::endl; | |
// std::cout << a.array().transpose() * M.array().colwise() << std::endl; | |
// try the elementwise power | |
// Eigen::VectorXd c = a.array().pow(b.array()); | |
// std::cout << c << std::endl; | |
// Eigen::MatrixXd K = (M / (-reg)).array().exp(); | |
// std::cout << K << std::endl; | |
// std::cout << a << std::endl; | |
// a.resize(1, 2); | |
// std::cout << a << std::endl; | |
// Eigen::MatrixXd b = a.cwiseInverse(); | |
// std::cout << b << std::endl; | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment