Created
October 24, 2024 21:44
-
-
Save louisaslett/135aaea1eb066ea62598810376d1bdf7 to your computer and use it in GitHub Desktop.
#rstats matrix inversion
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Script to compare matrix inversion speeds of base R, the Matrix package and
# the torch package. Also compares double- vs single-precision torch speeds
# (see output.txt for results on an M1 Mac).
library("Matrix")
library("torch")

# Side length of the square test matrices. Every benchmark and the accuracy
# check below use this value, so the problem size is changed in one place.
n <- 1600L

# ---- Speed benchmarks ----
# In each benchmark the `setup` expression is run (untimed) before every
# evaluation of the timed expression. It draws a fresh standard-normal n x n
# matrix and calls gc(), so that input generation and collection of garbage
# left by the previous iteration stay outside the timed region.

# Base R solve() on an ordinary numeric matrix.
base <- microbenchmark::microbenchmark(
  b <- solve(a),
  setup = {
    a <- matrix(rnorm(n * n), n)
    invisible(gc())
  }
)

# Matrix package solve() on a dense general "dgeMatrix".
Matrix_pkg <- microbenchmark::microbenchmark(
  b <- solve(a),
  setup = {
    a <- new("dgeMatrix", x = rnorm(n * n), Dim = as.integer(c(n, n)))
    invisible(gc())
  }
)

# torch linalg_inv() in double precision (same precision as the two above).
torch_double_pkg <- microbenchmark::microbenchmark(
  b <- linalg_inv(a),
  setup = {
    a <- torch_randn(c(n, n), dtype = "double")
    invisible(gc())
  }
)

# torch linalg_inv() in single precision: trades accuracy for speed.
torch_float_pkg <- microbenchmark::microbenchmark(
  b <- linalg_inv(a),
  setup = {
    a <- torch_randn(c(n, n), dtype = "float")
    invisible(gc())
  }
)

# ---- Print timings ----
# Explicit print() so the results are shown even when the script is run via
# source(), which does not auto-print top-level values.
print(base)
print(Matrix_pkg)
print(torch_double_pkg)
print(torch_float_pkg)

# ---- Check accuracy on the same matrix ----
# Seed R's RNG so the test matrix, and hence the reported differences, are
# reproducible between runs. The torch inputs below are converted from this
# same matrix, so no separate torch seed is needed.
set.seed(1L)
base_a <- matrix(rnorm(n * n), n)
base_b <- solve(base_a)

# Coerce to a dense, general, double-precision Matrix object; this should
# yield the same "dgeMatrix" class used in the benchmark above.
Matrix_a <- as(as(as(base_a, "dMatrix"), "generalMatrix"), "unpackedMatrix")
Matrix_b <- solve(Matrix_a)

torch_double_a <- torch_tensor(base_a, dtype = "double")
torch_double_b <- linalg_inv(torch_double_a)

# Converting to float rounds the input itself, so both the converted input and
# its inverse are compared against the double-precision base R results.
torch_float_a <- torch_tensor(base_a, dtype = "float")
torch_float_b <- linalg_inv(torch_float_a)

# all.equal() returns TRUE when the values agree within its default tolerance
# (sqrt(.Machine$double.eps), about 1.5e-8); otherwise it returns a character
# string describing the mean relative difference.
print(all.equal(base_a, as.matrix(Matrix_a)))
print(all.equal(base_b, as.matrix(Matrix_b)))
print(all.equal(base_a, as.matrix(torch_double_a)))
print(all.equal(base_b, as.matrix(torch_double_b)))
print(all.equal(base_a, as.matrix(torch_float_a)))
print(all.equal(base_b, as.matrix(torch_float_b)))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Output observed on an M1 Mac
# Note: the Accelerate BLAS was enabled for these runs
# (https://cran.r-project.org/bin/macosx/RMacOSX-FAQ.html#Which-BLAS-is-used-and-how-can-it-be-changed_003f)
# Take-home:
#   torch gets the gains of an accelerated BLAS without R itself needing to be
#   configured to use one. If you are willing to sacrifice some accuracy by
#   using single precision (float), it is faster still: about 3.5x faster than
#   double precision here (mean ~30 ms vs ~105 ms), at the cost of a mean
#   relative difference of about 3e-4 in the computed inverse.
base
#> Unit: milliseconds
#> expr min lq mean median uq max neval
#> b <- solve(a) 142.5731 147.032 148.4055 147.9406 149.636 163.1964 100
Matrix_pkg
#> Unit: milliseconds
#> expr min lq mean median uq max neval
#> b <- solve(a) 106.8548 108.7661 110.0113 109.8662 110.8723 119.1256 100
torch_double_pkg
#> Unit: milliseconds
#> expr min lq mean median uq max neval
#> b <- linalg_inv(a) 103.1852 104.0808 104.7773 104.2915 104.8387 117.1828 100
torch_float_pkg
#> Unit: milliseconds
#> expr min lq mean median uq max neval
#> b <- linalg_inv(a) 29.74267 29.99056 30.27734 30.09615 30.24777 40.33801 100
all.equal(base_a, as.matrix(Matrix_a))
#> [1] TRUE
all.equal(base_b, as.matrix(Matrix_b))
#> [1] TRUE
all.equal(base_a, as.matrix(torch_double_a))
#> [1] TRUE
all.equal(base_b, as.matrix(torch_double_b))
#> [1] TRUE
all.equal(base_a, as.matrix(torch_float_a))
#> [1] "Mean relative difference: 2.149769e-08"
all.equal(base_b, as.matrix(torch_float_b))
#> [1] "Mean relative difference: 0.0003306122"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment