Last active
January 6, 2022 20:32
The Kabsch algorithm in R for aligning one point set over another
#' Kabsch Algorithm
#'
#' Aligns two sets of points via rotations and translations.
#'
#' Given two sets of points, with one specified as the reference set,
#' the other set will be rotated so that the RMSD between the two is minimized.
#' The format of the matrix is that there should be one row for each of
#' n observations, and the number of columns, d, specifies the dimensionality
#' of the points. The point sets must be of equal size and with the same
#' ordering, i.e. point one of the second matrix is mapped to point one of
#' the reference matrix, point two of the second matrix is mapped to point two
#' of the reference matrix, and so on.
#'
#' @param pm n x d matrix of points to align to \code{qm}.
#' @param qm n x d matrix of reference points.
#' @return Matrix \code{pm} rotated and translated so that the ith point
#'  is aligned to the ith point of \code{qm} in the least-squares sense.
#' @references
#' \url{https://en.wikipedia.org/wiki/Kabsch_algorithm}
kabsch <- function(pm, qm) {
  pm_dims <- dim(pm)
  if (!all(dim(qm) == pm_dims)) {
    stop(call. = TRUE, "Point sets must have the same dimensions")
  }
  # The rotation matrix will have (ncol - 1) leading ones in the diagonal
  diag_ones <- rep(1, pm_dims[2] - 1)

  # center the points
  pm <- scale(pm, center = TRUE, scale = FALSE)
  qm <- scale(qm, center = TRUE, scale = FALSE)

  am <- crossprod(pm, qm)

  svd_res <- svd(am)
  # use the sign of the determinant to ensure a right-hand coordinate system
  d <- determinant(tcrossprod(svd_res$v, svd_res$u))$sign
  dm <- diag(c(diag_ones, d))

  # rotation matrix
  um <- svd_res$v %*% tcrossprod(dm, svd_res$u)

  # Rotate and then translate to the original centroid location of qm
  sweep(t(tcrossprod(um, pm)), 2, -attr(qm, "scaled:center"))
}
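The determinant check in the function above is the least obvious step, so here is a minimal sketch of what it guards against. The point sets are made up for illustration: qm is a mirror image of pm, so the best orthogonal map without the correction is a reflection rather than a rotation. The variable names pc, qc and raw are not part of the gist.

```r
# Hypothetical 2D example: qm is pm with its x-coordinates flipped,
# i.e. a mirror image that no proper rotation can reproduce exactly.
pm <- matrix(c(0, 0,
               2, 0,
               0, 1), ncol = 2, byrow = TRUE)
qm <- pm %*% diag(c(-1, 1))

# Center both sets, as kabsch() does
pc <- scale(pm, center = TRUE, scale = FALSE)
qc <- scale(qm, center = TRUE, scale = FALSE)

svd_res <- svd(crossprod(pc, qc))

# Without the correction: orthogonal, but here it is a reflection
raw <- tcrossprod(svd_res$v, svd_res$u)
d <- determinant(raw)$sign

# With the correction: flip the last singular direction when d is -1
um <- svd_res$v %*% tcrossprod(diag(c(1, d)), svd_res$u)

det(raw)  # negative for this mirrored example
det(um)   # positive: a proper rotation
```

For point sets that are already related by a proper rotation, d should come out as 1 and the correction leaves the matrix unchanged.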
I believe there's still a bug. According to the wiki, the rotation matrix um is meant to rotate Pm onto Qm, so the last line should be:
sweep(t(tcrossprod(um, pm)), 2, -attr(qm, "scaled:center"))
Here is an example (sorry, it's a bit long):
library(ggplot2)

kabsch <- function(pm, qm) {
  pm_dims <- dim(pm)
  if (!all(dim(qm) == pm_dims)) {
    stop(call. = TRUE, "Point sets must have the same dimensions")
  }
  # The rotation matrix will have (ncol - 1) leading ones in the diagonal
  diag_ones <- rep(1, pm_dims[2] - 1)

  # center the points
  pm <- scale(pm, center = TRUE, scale = FALSE)
  qm <- scale(qm, center = TRUE, scale = FALSE)

  am <- crossprod(pm, qm)

  svd_res <- svd(am)
  # use the sign of the determinant to ensure a right-hand coordinate system
  d <- determinant(tcrossprod(svd_res$v, svd_res$u))$sign
  dm <- diag(c(diag_ones, d))

  # rotation matrix
  um <- svd_res$v %*% tcrossprod(dm, svd_res$u)

  # Rotate and then translate to the original centroid location of qm
  sweep(t(tcrossprod(um, pm)), 2, -attr(qm, "scaled:center"))
}

pm2 <- data.frame(
  x = c(2, 1, 1),
  y = c(2, 2, 1)
)

qm2 <- data.frame(
  x = c(0, 0, -1.5),
  y = c(0, 1.5, 1.5)
)

# Before alignment: pm2 in red, qm2 in blue
ggplot(pm2, aes(x, y)) +
  geom_point(color = "red") +
  geom_path(color = "red") +
  geom_point(data = qm2, aes(x = x, y = y), color = "blue") +
  geom_path(data = qm2, aes(x = x, y = y), color = "blue") +
  xlim(-3, 3) +
  ylim(-3, 3)

pm2.t <- kabsch(pm2, qm2)
pm2.t <- as.data.frame(pm2.t)
names(pm2.t) <- c("x", "y")

# After alignment: qm2 in red, transformed pm2 in blue
ggplot(qm2, aes(x, y)) +
  geom_point(color = "red") +
  geom_path(color = "red") +
  geom_point(data = pm2.t, aes(x = x, y = y), color = "blue") +
  geom_path(data = pm2.t, aes(x = x, y = y), color = "blue") +
  xlim(-3, 3) +
  ylim(-3, 3)
Created on 2018-08-21 by the reprex package (v0.2.0.9000).
Thank you @tcgriffith. I don't get alerted to comments, so apologies for not acknowledging this sooner.
This must be a personal record for the most errors in the least amount of code.
Minor change to use crossprod and tcrossprod rather than %*% and t() directly.
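For readers unfamiliar with these base R helpers, the following small check shows the equivalence that the change relies on. The matrices a and b are arbitrary random data, used only for illustration.

```r
set.seed(1)  # arbitrary seed so the check is repeatable
a <- matrix(rnorm(6), nrow = 3)  # 3 x 2
b <- matrix(rnorm(6), nrow = 3)  # 3 x 2

# crossprod(a, b) computes t(a) %*% b
cp_same <- all.equal(crossprod(a, b), t(a) %*% b)

# tcrossprod(a, b) computes a %*% t(b)
tcp_same <- all.equal(tcrossprod(a, b), a %*% t(b))

cp_same
tcp_same
```

Both forms give the same result; the helper versions avoid writing out the explicit transpose.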