Skip to content

Instantly share code, notes, and snippets.

@bwiernik
Last active September 4, 2020 07:43
Show Gist options
  • Save bwiernik/722f000787cc746c1ee12cdb84c7722f to your computer and use it in GitHub Desktop.
Save bwiernik/722f000787cc746c1ee12cdb84c7722f to your computer and use it in GitHub Desktop.
Posterior predictive checks for lm() models
## Functions to draw samples from distributions implied by lm() models
##
## Arguments for simulate_model():
## a) scale, df, na.action, pred.var, weights, and ... are passed on to
## stats::predict.lm();
## b) nsim and seed play the same role as in stats::simulate();
## c) nsamples is an alias for nsim (nsim takes precedence when both are given);
## d) dist_type specifies the type of distribution to sample from:
## 1) "ml": the maximum likelihood distribution,
## i.e., the distribution implied by the estimated model parameters
## (equivalent to stats::simulate());
## 2) "confidence": the confidence distribution for the fitted values (regression line);
## 3) "prediction": the (posterior) predictive distribution.
##
## Written by Brenton M. Wiernik
## Last updated 2020-08-20
## License GPL 3.0
##
## Please cite as:
## Wiernik, B. M. (2020, August 20).
## _Posterior predictive distribution for lm() models._
## [R functions] <URL where you retrieved this file>
# Generic function: draw simulated response values from a fitted model.
#
# Dispatches on the class of `object`; every argument is handed to the
# selected method unchanged. The defaults listed here only document the
# interface -- the generic itself never evaluates them.
simulate_model <- function(object, newdata, nsamples = 1,
                           dist_type = c("ml", "confidence", "prediction"),
                           nsim = NULL, seed = NULL, scale = NULL, df = Inf,
                           na.action = na.pass, pred.var = res.var/weights,
                           weights = 1, ...) {
  UseMethod("simulate_model")
}
# Simulate response values from the normal distributions implied by a fitted
# lm() model (or a gaussian-family glm(), which also inherits from "lm").
#
# Arguments:
#   object     Fitted model. Multivariate ("mlm") fits and non-gaussian glm
#              families are rejected with an error.
#   newdata    Optional data to predict for; handed to predict() untouched, so
#              leaving it out means "use the data the model was fitted to".
#   nsamples   Number of simulated response vectors to draw.
#   dist_type  Which standard deviation the draws use:
#                "ml"         - the residual scale reported by predict();
#                "confidence" - the standard error of the fitted values;
#                "prediction" - sqrt(se.fit^2 + residual scale^2).
#   nsim       Alias for nsamples; overrides nsamples when not NULL.
#   seed       NULL, or a value for set.seed(). When given, the caller's
#              global RNG state is restored on exit.
#   scale, df, na.action, pred.var, weights, ...
#              Forwarded to predict().
#   bayesplot  Not used by this method; kept so existing calls keep working.
#
# Returns a data frame with one row per fitted/predicted value and one column
# per draw (named sim_1, sim_2, ...). The RNG state (or the seed) used is
# stored in attribute "seed" and the distribution type in attribute
# "dist_type".
simulate_model.lm <- function(
  object,
  newdata,
  nsamples = 1,
  dist_type = c("ml", "confidence", "prediction"),
  nsim = NULL,
  seed = NULL,
  scale = NULL,
  df = Inf,
  na.action = na.pass,
  pred.var = res.var/weights,
  weights = 1,
  bayesplot = FALSE,
  ...
) {
  if (!is.null(nsim)) {
    nsamples <- nsim
  }
  if (!is.numeric(nsamples) || length(nsamples) != 1L ||
      !is.finite(nsamples) || nsamples < 0 || nsamples != round(nsamples)) {
    stop("'nsamples' (or 'nsim') must be a single non-negative whole number")
  }
  dist_type <- match.arg(dist_type, c("ml", "confidence", "prediction"))

  # Reject unsupported models before touching the RNG or calling predict().
  # The class has to be checked directly: predict() on an "mlm" fit never
  # yields the list with se.fit/residual.scale that the code below relies on.
  if (inherits(object, "mlm")) {
    stop("simulate_model() is not yet implemented for multivariate lm()")
  }
  fam <- if (inherits(object, "glm")) {
    object$family$family
  } else {
    "gaussian"
  }
  if (!identical(fam, "gaussian")) {
    stop(gettextf("family '%s' not implemented", fam), domain = NA)
  }

  # Seed handling follows the same scheme as stats::simulate(): make sure a
  # global RNG state exists, remember what was used, and put the caller's
  # state back afterwards when an explicit seed was requested.
  if (!exists(".Random.seed", envir = .GlobalEnv, inherits = FALSE)) {
    runif(1)
  }
  if (is.null(seed)) {
    RNGstate <- get(".Random.seed", envir = .GlobalEnv)
  } else {
    R.seed <- get(".Random.seed", envir = .GlobalEnv)
    set.seed(seed)
    RNGstate <- structure(seed, kind = as.list(RNGkind()))
    on.exit(assign(".Random.seed", R.seed, envir = .GlobalEnv), add = TRUE)
  }

  # newdata and pred.var are passed along without being evaluated here, so a
  # missing newdata stays missing for predict().
  ftd <- predict(object,
                 newdata,
                 se.fit = TRUE,
                 scale = scale,
                 df = df,
                 type = "response",
                 na.action = na.action,
                 pred.var = pred.var,
                 weights = weights,
                 ...)
  .sigma <- ftd$residual.scale
  .mean <- ftd$fit
  .sd <- switch(dist_type,
    ml = .sigma,
    confidence = ftd$se.fit,
    prediction = sqrt(ftd$se.fit^2 + .sigma^2)
  )

  nm <- names(.mean)
  n <- length(.mean)
  # .mean (and a per-observation .sd) recycle across the nsamples blocks of
  # n draws, so column j of the reshaped matrix is the j-th simulated vector.
  val <- .mean + rnorm(n * nsamples, sd = .sd)
  dim(val) <- c(n, nsamples)
  val <- as.data.frame(val)
  names(val) <- paste0("sim_", seq_len(nsamples))
  if (!is.null(nm)) {
    row.names(val) <- nm
  }
  attr(val, "seed") <- RNGstate
  attr(val, "dist_type") <- dist_type
  val
}
# pp_check() method for lm() fits: overlays the density of the observed
# response with the densities of responses simulated by simulate_model().
#
# All arguments other than `object` are handed to simulate_model() by name;
# see that function for their meaning. Multivariate responses are rejected.
# Returns whatever bayesplot::ppc_dens_overlay() returns.
pp_check.lm <- function(
  object,
  newdata,
  nsamples = 10,
  dist_type = c("ml", "confidence", "prediction"),
  nsim = NULL,
  seed = NULL,
  scale = NULL,
  df = Inf,
  na.action = na.pass,
  pred.var = res.var/weights,
  weights = 1,
  ...
) {
  # Observed response, taken from the model frame of the fitted object.
  observed <- model.extract(model.frame(object), "response")
  if (length(dim(observed)) > 0L) {
    stop("pp_check() is not yet implemented for multivariate lm()")
  }

  # newdata and pred.var are only passed along, never evaluated here, so a
  # missing newdata is still missing when it reaches simulate_model().
  draws <- simulate_model(
    object = object, newdata = newdata,
    nsamples = nsamples, nsim = nsim, dist_type = dist_type,
    seed = seed, scale = scale, df = df,
    na.action = na.action, pred.var = pred.var, weights = weights,
    ...
  )

  # simulate_model() returns one column per draw; the plot is given the
  # transpose, i.e. one row per draw.
  draws_by_row <- t(as.matrix(draws))
  bayesplot::ppc_dens_overlay(y = observed, yrep = draws_by_row)
}
# Example
# The pp_check() generic is not defined in this file; it comes from bayesplot.
# Calling it through the namespace lets the example run without attaching the
# package first. Dispatch still reaches the pp_check.lm() method defined above.
mod_cars <- lm(speed ~ dist, data = cars)
bayesplot::pp_check(mod_cars)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment