Skip to content

Instantly share code, notes, and snippets.

@vankesteren
Created March 27, 2023 11:18
Show Gist options
  • Save vankesteren/1926052bf1c5c18cc1569f2f03bc813b to your computer and use it in GitHub Desktop.
Save vankesteren/1926052bf1c5c18cc1569f2f03bc813b to your computer and use it in GitHub Desktop.
Kernel gram matrices with RcppArmadillo
#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::export]]
double rbf_kern(const arma::vec xi, const arma::vec xj, const double& gamma) {
return exp(-gamma*norm(xi-xj));
}
// Symmetric Gram matrix of rbf_kern evaluated over all pairs of columns of X.
//
// X     : data matrix with one observation per COLUMN (i.e. already transposed).
// gamma : forwarded unchanged to rbf_kern.
// Returns an N x N matrix (N = number of columns of X) with ones on the diagonal.
// [[Rcpp::export]]
arma::mat rbf_gram_cpp(const arma::mat& X, const double& gamma) {
    // nb: expect transposed matrix!!
    const arma::uword N = X.n_cols;

    // Start from the identity: the loop below never writes the diagonal, so it
    // must not depend on how a freshly constructed matrix happens to be filled.
    arma::mat K = arma::eye(N, N);

    // Only the strict lower triangle is computed; each value is mirrored.
    for (arma::uword i = 0; i < N; ++i) {
        for (arma::uword j = 0; j < i; ++j) {
            const double k_ij = rbf_kern(X.col(i), X.col(j), gamma);
            K(i, j) = k_ij;
            K(j, i) = k_ij;
        }
    }
    return K;
}
// [[Rcpp::export]]
double tps_kern(const arma::vec xi, const arma::vec xj) {
// thin-plate spline kernel
double r = norm(xi-xj);
return r*r*log(r);
}
// Symmetric matrix of tps_kern evaluated over all pairs of columns of X, with
// the diagonal set to 1.
//
// X : data matrix with one observation per COLUMN (i.e. already transposed).
// Returns an N x N matrix (N = number of columns of X). The kernel is never
// evaluated for i == j; those entries are fixed at 1.
// [[Rcpp::export]]
arma::mat tps_gram_cpp(const arma::mat& X) {
    // nb: expect transposed matrix!!
    const arma::uword N = X.n_cols;

    // Start from the identity: the loop below never writes the diagonal, so it
    // must not depend on how a freshly constructed matrix happens to be filled.
    arma::mat K = arma::eye(N, N);

    // Only the strict lower triangle is computed; each value is mirrored.
    for (arma::uword i = 0; i < N; ++i) {
        for (arma::uword j = 0; j < i; ++j) {
            const double k_ij = tps_kern(X.col(i), X.col(j));
            K(i, j) = k_ij;
            K(j, i) = k_ij;
        }
    }
    return K;
}
// Thin-plate spline basis: tps_kern between every column of X and every
// column of C.
//
// X : evaluation points, one per COLUMN (i.e. already transposed).
// C : second set of points, one per COLUMN (i.e. already transposed).
// Returns a matrix with one row per column of X and one column per column of C.
// [[Rcpp::export]]
arma::mat tps_basis_cpp(const arma::mat& X, const arma::mat& C) {
    // nb: expect transposed matrix!!
    const int n_points = X.n_cols;
    const int n_centres = C.n_cols;

    arma::mat basis(n_points, n_centres);

    int row = 0;
    while (row < n_points) {
        int col = 0;
        while (col < n_centres) {
            const double value = tps_kern(X.col(row), C.col(col));
            basis(row, col) = value;
            ++col;
        }
        ++row;
    }
    return basis;
}
/*** R
# Kernel Gram matrix for the rows of X.
#
# X     : numeric matrix, one observation per row.
# gamma : passed to rbf_gram_cpp; only used when type = "rbf".
# type  : "rbf" or "tps"; any other string raises an error.
kernel_gram <- function(X, gamma = 0.5, type = "rbf") {
  # The C++ routines expect observations in columns, hence t(X).
  switch(
    type,
    rbf = rbf_gram_cpp(t(X), gamma),
    tps = tps_gram_cpp(t(X)),
    # Without a default arm, switch() silently returns NULL for unknown types.
    stop("Unknown kernel type '", type, "'; expected \"rbf\" or \"tps\".", call. = FALSE)
  )
}
# Thin-plate spline basis between the rows of X and the rows of C.
# Both matrices are transposed because tps_basis_cpp works on columns.
tps_basis <- function(X, C) {
  X_cols <- t(X)
  C_cols <- t(C)
  tps_basis_cpp(X_cols, C_cols)
}
*/
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment