Last active
October 22, 2015 03:22
-
-
Save benjamin-chan/9be70cc905872768712c to your computer and use it in GitHub Desktop.
Example of parallelization of glm()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
if (!require(devtools)) {install.packages("devtools")} | |
library(devtools) | |
source_gist("https://gist.github.com/benjamin-chan/3b59313e8347fffea425") | |
loadPkg("doParallel") | |
loadPkg("data.table") | |
# ---- Simulate the data -------------------------------------------------
J <- 30   # number of models to fit within each stratum
N <- 2E5  # number of distinct ids in the simulated dataset

# One row per (id, repetition) pair: every id appears J times.
i <- rep(seq_len(N), each = J)
D <- data.table(i = i,                                        # id
                j = rep(seq_len(J), times = N),               # repetition index
                x1 = rep(rbinom(N, 1, 0.5), each = J),        # group membership, adult/child
                x2 = rnorm(N * J),                            # fake risk factor
                x3 = rnorm(N * J),                            # fake risk factor
                x4 = rbinom(N * J, 1, 0.5),                   # fake risk factor
                x5 = rbinom(N * J, 1, 0.5))                   # fake risk factor

# `:=` adds columns by reference, so no reassignment of D is needed.
D[, logitp := -5 + x1 + x2 + x3 + x4 + x5]      # true linear predictor
D[, p := exp(logitp) / (1 + exp(logitp))]       # true event probability
D[, y := rbinom(N * J, 1, p)]                   # simulated binary outcome

# ---- Stack the three strata --------------------------------------------
# k identifies the model each row belongs to:
#   k in 1..J        all rows
#   k in J+1..2J     rows with x1 == 0 only
#   k in 2J+1..3J    rows with x1 == 1 only
D[, k := j]
D <- rbind(D,
           D[x1 == 0][, k := as.integer(J + j)],
           D[x1 == 1][, k := as.integer(2 * J + j)])
setkeyv(D, c("k", "i", "j"))

# Row counts per model, as a quick sanity check of the stacking.
D[, .N, by = .(k, j)][order(k, j)]

# N and i are no longer needed once D is built.
rm(N, i)
#' Fit one logistic regression and return its fitted probabilities.
#'
#' @param D A data.table containing the rows for a single model, with the
#'   identifier columns `k`, `i`, `j`, the outcome `y` and the predictors
#'   `x1` to `x5`.
#' @return A data.table with the columns `k`, `i`, `j` taken from `D` plus
#'   `phat`, the fitted values of the logistic model.
phat <- function (D) {
  # Attach data.table before `D` is first used. The argument is evaluated
  # lazily, and when this function runs under %dopar% the worker process may
  # not have the package loaded yet. library() stops with a clear error if
  # the package is missing, whereas require() would only return FALSE and
  # let the data.table syntax below fail in a confusing way.
  library(data.table)
  fx <- y ~ x1 + x2 + x3 + x4 + x5
  M <- glm(fx, data = D, family = binomial)
  cbind(D[, .(k, i, j)],
        phat = fitted(M))
}
# ---- Parallel fit ------------------------------------------------------
# Use one CPU fewer than is available, but never fewer than one:
# detectCores() can return 1 (giving 0 workers) or NA, and makeCluster()
# cannot start a cluster of that size.
cores <- max(1L, detectCores() - 1L, na.rm = TRUE)
cl <- makeCluster(cores)
registerDoParallel(cl)
# tryCatch(finally = ) guarantees the worker processes are shut down even if
# one of the model fits fails.
ptime <- tryCatch(
  system.time(
    # .packages loads data.table on every worker so that the D[k == K, ]
    # subset does not depend on phat() having loaded it first.
    phatP <- foreach(K = seq_len(J * 3), .packages = "data.table") %dopar% {
      phat(D[k == K, ])
    }
  ),
  finally = stopCluster(cl)
)

# ---- Serial fit --------------------------------------------------------
# The same loop run sequentially with %do%, for a like-for-like comparison.
stime <- system.time(
  phatS <- foreach(K = seq_len(J * 3)) %do% {
    phat(D[k == K, ])
  }
)
# ---- Report ------------------------------------------------------------
# all.equal() returns TRUE when the results match but a character vector
# describing the differences when they do not. isTRUE() reduces that to a
# single logical so the report below is always exactly one message.
isEqual <- isTRUE(all.equal(phatP, phatS))
lines <- paste("Parallel (%d core%s): %.3g min (elapsed), %.3g min (user)",
               "Serial: %.3g min (elapsed), %.3g min (user)",
               "Relative difference: %.3g%% %s (elapsed), %.3g%% %s (user)",
               "Equal phats? %s",
               sep = "\n")
# system.time() results are indexed by name: "elapsed" is wall-clock time
# and "user.self" is the user CPU time of this R process.
message(sprintf(lines,
                cores,
                if (cores > 1) "s" else "",
                ptime[["elapsed"]] / 60,
                ptime[["user.self"]] / 60,
                stime[["elapsed"]] / 60,
                stime[["user.self"]] / 60,
                abs(ptime[["elapsed"]] - stime[["elapsed"]]) /
                  stime[["elapsed"]] * 100,
                if (ptime[["elapsed"]] < stime[["elapsed"]]) {
                  "faster... Yay!"
                } else {
                  "slower... Boo!"
                },
                abs(ptime[["user.self"]] - stime[["user.self"]]) /
                  stime[["user.self"]] * 100,
                if (ptime[["user.self"]] < stime[["user.self"]]) {
                  "faster... Yay!"
                } else {
                  "slower... Boo!"
                },
                isEqual))
# ---- Attach fitted values and summarise --------------------------------
# Stack the per-model results into one table and join them back onto D by
# the shared identifier columns.
ps <- rbindlist(phatP)
setkeyv(ps, c("k", "i", "j"))
D <- merge(D, ps, by = c("k", "i", "j"))

# Mean true probability, mean fitted probability and row count for every
# combination of group (x1) and outcome (y).
groupMeans <- D[, .(Ep = mean(p), Ephat = mean(phat)), by = .(x1, y)]
groupSizes <- D[, .N, by = .(x1, y)]
merge(groupMeans, groupSizes, by = c("x1", "y"))

# Correlation between the true and the fitted probabilities.
cor(D[, c("p", "phat"), with = FALSE])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment