Last active
March 15, 2021 17:15
-
-
Save phil8192/6f1804a6fc9fc3c0ed9218601dea4837 to your computer and use it in GitHub Desktop.
matrix multiplication with rotations
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Input row vector w (length 4).
w <- c(10, 20, 30, 40)
# Input 4x3 matrix x, filled row by row:
#       [,1] [,2] [,3]
#  [1,]    1    2    3
#  [2,]    4    5    6
#  [3,]    7    8    9
#  [4,]   10   11   12
x <- matrix(1:12, ncol = 3, byrow = TRUE)
# Goal: compute the vector-matrix product using only element-wise operations
# (multiply, add) and vector rotations, reproducing:
#   > w %*% x
#        [,1] [,2] [,3]
#   [1,]  700  800  900
#
# rot(x, n) cyclically rotates the vector x left by n positions; a negative n
# rotates right. n is reduced modulo length(x), so n = 0 or n = length(x)
# returns x unchanged, and an empty x is returned as-is.
#   > rot(c(1, 2, 3, 4), 1)
#   [1] 2 3 4 1
#   > rot(c(1, 2, 3, 4), 2)
#   [1] 3 4 1 2
#   > rot(c(1, 2, 3, 4), -1)
#   [1] 4 1 2 3
rot <- function(x, n) {
  stopifnot(is.numeric(n), length(n) == 1L)
  len <- length(x)
  if (len == 0L) {
    return(x)
  }
  # Result position i takes element ((i - 1 + n) mod len) + 1 of x. R's %%
  # always yields a non-negative remainder, so negative shifts work too.
  x[(seq_len(len) - 1L + n) %% len + 1L]
}
# Step 1: multiply w element-wise with each column of x, giving one vector
# of partial products per output column.
m1 <- x[, 1] * w
m2 <- x[, 2] * w
m3 <- x[, 3] * w
# Step 2: add each product vector to its rotations by 1, 2 and 3 positions
# (left fold starting from the unrotated vector), so every element ends up
# holding the sum of all four partial products, e.g.
#   > c(1, 2, 3) + c(2, 3, 1) + c(3, 1, 2)
#   [1] 6 6 6
s1 <- Reduce(`+`, lapply(1:3, function(k) rot(m1, k)), m1)
s2 <- Reduce(`+`, lapply(1:3, function(k) rot(m2, k)), m2)
s3 <- Reduce(`+`, lapply(1:3, function(k) rot(m3, k)), m3)
# Step 3: keep one copy of each column sum using a one-hot mask shifted
# right by 0, 1 and 2 positions, then add the masked vectors. Result:
#   [1] 700 800 900   0
mask <- c(1, 0, 0, 0)
s1 * mask +
  s2 * rot(mask, -1) +
  s3 * rot(mask, -2)
# Faster full sum when length(x) is a power of 2: rotate-and-add with
# doubling shifts 1, 2, 4, ... After step k every element holds the sum of
# 2^k cyclically consecutive elements, so after log2(length(x)) steps each
# element holds the total -- log2(n) rotations instead of n - 1. For other
# lengths the windows overlap and elements get counted more than once.
x <- 1:128
sum(x)
# [1] 8256
n_steps <- round(log2(length(x)))
stopifnot(length(x) > 0, 2^n_steps == length(x))
for (shift in 2^(seq_len(n_steps) - 1)) {
  x <- x + rot(x, shift)
}
head(x, 1)
# [1] 8256
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment