Last active
January 16, 2019 23:07
-
-
Save dantonnoriega/5cda5ce9c048893dbb98985932e4e87c to your computer and use it in GitHub Desktop.
Translate the Stan code from https://arxiv.org/pdf/1808.06399.pdf into greta code (https://greta-dev.github.io/greta/index.html), using the new imultilogit() function instead of simplex_mat() — about a 15% speed increase.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Recreate the Dirichlet regression from https://arxiv.org/pdf/1808.06399.pdf
# with greta (https://greta-dev.github.io/greta/).
library("DirichletReg")

# Keep complete cases only. The first four columns form the compositional
# response; wrap them with DR_data() and store the result as column `Smp`.
Bld <- na.omit(BloodSamples)
Bld$Smp <- DR_data(Bld[, 1:4])
# using greta
# !! requires the development version to run!
# devtools::install_github("greta-dev/greta@dev")

# Response: the compositional matrix as a plain numeric matrix
# (one row per sample, one column per component).
Y <- matrix(Bld$Smp, ncol = ncol(Bld$Smp))

# Design matrix (intercept + Disease indicator), built directly from the
# formula rather than by fitting a throwaway lm() on Albumin. That way the
# rows of X are not tied to the NA handling of an unrelated response.
X <- model.matrix(~ Disease, data = Bld)
# Drop model.matrix's extra attributes (assign, contrasts, dimnames) so that
# X is a bare numeric matrix.
X <- matrix(as.numeric(X), nrow = nrow(X), ncol = ncol(X))

# The likelihood pairs row i of X with row i of Y.
stopifnot(nrow(X) == nrow(Y))

# Convert the response to greta data.
Y <- greta::as_data(Y)
# Standard deviation shared by the normal priors on theta and beta.
sd_prior <- 1
# theta ~ normal(0, sd_prior); exp(theta) is the scaling constant applied to
# the Dirichlet parameters below.
theta <- greta::normal(0, sd_prior)
# greta lets you build data arrays and then fill them with variable/operation
# values, but not the other direction.
# beta excludes the last component (ncol(Y) - 1 columns): that component is
# the reference column that imultilogit() adds back to normalise.
beta <- greta::normal(0, sd_prior, dim = c(ncol(X), ncol(Y) - 1))
# Linear predictor: one row per observation.
alpha <- greta::`%*%`(X, beta)
# Append the normalising column and map each row onto the simplex.
alpha <- greta::imultilogit(alpha)
# Likelihood: Dirichlet with the simplex rows scaled by exp(theta).
greta::distribution(Y) <- greta::dirichlet(alpha * exp(theta))
# Only beta is traced, so the draws contain beta columns only.
m <- greta::model(beta)
# 4 chains, each with 1000 warmup iterations followed by 2000 kept draws.
draws <- greta::mcmc(m, chains = 4, n_samples = 2000, warmup = 1000)
# split up draws --------------
# greta returns the draws chain by chain; stack them into a single matrix
# (one row per draw, one column per parameter) and split beta by row index.
P <- do.call(rbind, draws) # stack all chains
nms <- colnames(P) # parameter (column) names, indexed like [i,j]
# Select columns by fixed name prefix rather than an unanchored regex: a
# pattern such as '1,.' would also match beta[11,j] or beta[21,j].
# drop = FALSE keeps a matrix even when beta has a single column.
# Row 1 of beta (untransformed): columns beta[1,j].
B1 <- P[, startsWith(nms, "beta[1,"), drop = FALSE]
# Row 2 of beta (untransformed): columns beta[2,j].
B2 <- P[, startsWith(nms, "beta[2,"), drop = FALSE]
# plots ----------------------------------------------------------------------
# Four hues from scales::hue_pal(); the plots below use the first two.
my_colors <- scales::hue_pal()(n = 4L)
# Simplex function: applies the softmax to each row of a numeric matrix, so
# every row of the result is non-negative and sums to 1.
# Each row's maximum is subtracted before exponentiating. This leaves the
# result mathematically unchanged but stops exp() overflowing to Inf (which
# would turn the row into Inf / Inf = NaN) when entries are large.
# Uses base R only; no greta operator is needed for plain matrices.
simplex_mat <- function(x) {
  # apply(x, 1, max) has one value per row; it recycles down the columns,
  # so each row has its own maximum subtracted.
  shifted <- exp(x - apply(x, 1L, max))
  shifted / rowSums(shifted)
}
# Draw one comparison panel for a disease group:
#   - every observed composition as grey points, and the observed mean in black;
#   - the posterior mean composition (solid) and its 2.5% / 97.5% posterior
#     quantiles (dashed), all in `col`.
# Arguments:
#   obs  : observed proportions, one row per sample (data frame or matrix);
#          its column names label the x axis
#   post : posterior simplex draws, one row per draw, with columns in the
#          same order as `obs`
#   col  : colour for the posterior summaries
#   main : panel title
#   ylim : vertical range; dotted grey reference lines are drawn every 0.1
# Called for its plotting side effect; returns NULL invisibly.
plot_disease_panel <- function(obs, post, col, main, ylim = c(0, 0.6)) {
  obs <- as.matrix(obs)
  n_parts <- ncol(obs)
  at <- seq_len(n_parts)

  plot(NA, type = "n", xlim = c(0.6, n_parts + 0.4), ylim = ylim,
       xaxt = "n", las = 1, xlab = "", ylab = "Proportion", main = main)
  abline(h = seq(ylim[1], ylim[2], by = 0.1), col = "grey", lty = 3)
  axis(1, at = at, labels = colnames(obs), las = 2)

  # All samples in one vectorised call: c(obs) flattens column by column,
  # so each x position is repeated once per sample.
  points(rep(at, each = nrow(obs)), c(obs), pch = 16, col = "grey")
  lines(colMeans(obs), type = "b", pch = 16, cex = 1.2, lwd = 2)

  lines(apply(post, MARGIN = 2, FUN = quantile, probs = 0.975), type = "b",
        pch = 4, lty = 2, col = col)
  lines(apply(post, MARGIN = 2, FUN = quantile, probs = 0.025), type = "b",
        pch = 4, lty = 2, col = col)
  lines(colMeans(post), lwd = 2, col = col, type = "b", pch = 16)

  invisible(NULL)
}

layout(matrix(1:2, ncol = 2))

# Disease A: the linear predictor is row 1 of beta. Append a zero column to
# normalise; ORDER MATTERS (the last column is the one normalised against).
aux <- simplex_mat(cbind(B1, 0))
plot_disease_panel(subset(Bld, Disease == "A")[, 1:4], aux,
                   col = my_colors[1], main = "Disease A")

# Disease B: row 1 plus row 2 of beta, with the same trailing zero column.
aux <- simplex_mat(cbind(B1 + B2, 0))
plot_disease_panel(subset(Bld, Disease == "B")[, 1:4], aux,
                   col = my_colors[2], main = "Disease B")

layout(1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment