Created March 27, 2023 11:18
-
-
Save vankesteren/1926052bf1c5c18cc1569f2f03bc813b to your computer and use it in GitHub Desktop.
Kernel Gram matrices with RcppArmadillo
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <RcppArmadillo.h>
#include <cmath>
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::export]] | |
double rbf_kern(const arma::vec xi, const arma::vec xj, const double& gamma) { | |
return exp(-gamma*norm(xi-xj)); | |
} | |
// Gram matrix of rbf_kern over all pairs of columns of X.
// nb: expects the transposed data matrix (one observation per column).
// Returns a symmetric N x N matrix (N = X.n_cols) whose diagonal is
// rbf_kern(x, x) = exp(0) = 1.
// [[Rcpp::export]]
arma::mat rbf_gram_cpp(const arma::mat& X, const double& gamma) {
  const arma::uword N = X.n_cols;
  // Fill with ones so the diagonal is set explicitly instead of relying on
  // arma::mat(N, N) being zero-initialised (only the default since
  // Armadillo 10.5) and then adding an N x N identity temporary.
  arma::mat K(N, N, arma::fill::ones);
  for (arma::uword i = 0; i < N; i++) {
    // The kernel is symmetric: compute the strict lower triangle and mirror it.
    for (arma::uword j = 0; j < i; j++) {
      K(i, j) = K(j, i) = rbf_kern(X.col(i), X.col(j), gamma);
    }
  }
  return K;
}
// [[Rcpp::export]] | |
double tps_kern(const arma::vec xi, const arma::vec xj) { | |
// thin-plate spline kernel | |
double r = norm(xi-xj); | |
return r*r*log(r); | |
} | |
// Gram matrix of tps_kern over all pairs of columns of X.
// nb: expects the transposed data matrix (one observation per column).
// Returns a symmetric N x N matrix (N = X.n_cols). The diagonal is 0, the
// value of the thin-plate kernel r^2 log(r) in the limit r -> 0. If a
// regularised matrix is wanted, add a ridge term (e.g. lambda * I) at the
// call site.
// [[Rcpp::export]]
arma::mat tps_gram_cpp(const arma::mat& X) {
  const arma::uword N = X.n_cols;
  // Explicit zero fill: it provides the diagonal, and arma::mat(N, N) is only
  // zero-initialised by default from Armadillo 10.5 onward.
  arma::mat K(N, N, arma::fill::zeros);
  for (arma::uword i = 0; i < N; i++) {
    // The kernel is symmetric: compute the strict lower triangle and mirror it.
    for (arma::uword j = 0; j < i; j++) {
      K(i, j) = K(j, i) = tps_kern(X.col(i), X.col(j));
    }
  }
  return K;
}
// Thin-plate spline basis matrix: entry (row, col) is tps_kern between
// column `row` of X and column `col` of C.
// nb: expects transposed matrices (one point per column in both X and C).
// Every entry is written, so no initial fill is needed.
// [[Rcpp::export]]
arma::mat tps_basis_cpp(const arma::mat& X, const arma::mat& C) {
  const arma::uword n_points = X.n_cols;
  const arma::uword n_centers = C.n_cols;
  arma::mat basis(n_points, n_centers);
  // Outer loop over output columns so writes follow column-major storage.
  for (arma::uword col = 0; col < n_centers; ++col) {
    for (arma::uword row = 0; row < n_points; ++row) {
      basis(row, col) = tps_kern(X.col(row), C.col(col));
    }
  }
  return basis;
}
/*** R | |
# Kernel Gram matrix over the rows of X.
#   X     numeric matrix, one observation per row; it is transposed before the
#         C++ call, which expects one observation per column.
#   gamma kernel parameter passed to rbf_gram_cpp(); unused for "tps".
#   type  "rbf" (default) or "tps". Unrecognised values raise an error
#         instead of silently returning NULL.
kernel_gram <- function(X, gamma = 0.5, type = c("rbf", "tps")) {
  type <- match.arg(type)
  switch(
    type,
    rbf = rbf_gram_cpp(t(X), gamma),
    tps = tps_gram_cpp(t(X))
  )
}
tps_basis <- function(X, C) {
  # Thin-plate spline basis between the rows of X and the rows of C.
  # tps_basis_cpp() expects points as columns, so both inputs are transposed.
  points_by_col <- t(X)
  centers_by_col <- t(C)
  tps_basis_cpp(points_by_col, centers_by_col)
}
*/ |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.