# Convex Dirichlet Aggregation with brms and Parsnip
#
# GitHub gist ryanholbrook/b5c7d44c0c7642eeee1a3034b48f29d7,
# created October 3, 2019 08:18.
#
# Note from GitHub: this file may contain hidden or bidirectional Unicode text
# that could be interpreted or compiled differently than it appears. To review
# it, open the file in an editor that reveals hidden Unicode characters.
library(brms)
library(dials)
library(parsnip)
## Fits a brms model whose population-level coefficients are constrained to a
## convex combination (non-negative weights that sum to one).
##
## The weights are built in Stan as `b = b_raw / sum(b_raw)`, with a
## Gamma(alpha_K, 1) prior on each element of `b_raw`, where
## `alpha_K = alpha / K^gamma` and K is the number of predictor terms.
##
## Arguments:
##   formula  Model formula; the intercept is removed before fitting.
##   data     Data frame passed to `brm()` (and to `terms()` when counting
##            predictor terms).
##   family   Response family passed to `brm()`; default "gaussian".
##   alpha    Prior scale; a warning is issued unless alpha > 0.
##   gamma    Prior penalty exponent; a warning is issued unless gamma > 1.
##   verbose  If > 0, the generated Stan code is emitted with `message()`.
##   ...      Further arguments forwarded to `brm()`.
##
## Returns the object returned by `brm()`.
convex_regression <- function(formula, data,
                              family = "gaussian",
                              alpha = 1, gamma = 2, # Yang (2014) recommends alpha=1, gamma=2
                              verbose = 0,
                              ...) {
  if (gamma <= 1) {
    warning(paste("Parameter gamma should be greater than 1. Given:", gamma))
  }
  if (alpha <= 0) {
    warning(paste("Parameter alpha should be greater than 0. Given:", alpha))
  }

  ## Count predictor terms. `length(terms(formula))` is always 3 for a
  ## two-sided formula (the `~` call, lhs and rhs), not the number of
  ## predictors, so read the "term.labels" attribute instead.
  K <- length(attr(terms(formula, data = data), "term.labels"))
  if (K == 0) {
    stop("`formula` must contain at least one predictor term.", call. = FALSE)
  }
  alpha_K <- alpha / (K^gamma)

  ## Custom Stan code. The `[K]` inside the Stan strings refers to a Stan-side
  ## `K`, presumably the number of population-level coefficients that brms
  ## declares in its data block, not the R variable above (TODO confirm they
  ## agree, e.g. for factor predictors).
  ## NOTE(review): brms may itself declare a parameter named `b`; inspect the
  ## generated code (verbose > 0) for a duplicate declaration.
  stanvars <-
    stanvar(alpha_K,
      "alpha_K",
      block = "data",
      scode = " real<lower = 0> alpha_K; // dirichlet parameter"
    ) +
    stanvar(
      name = "b_raw",
      block = "parameters",
      scode = " vector<lower = 0>[K] b_raw; "
    ) +
    stanvar(
      name = "b",
      block = "tparameters",
      scode = " vector[K] b = b_raw / sum(b_raw);"
    )
  ## Raw `target +=` statement; `check = FALSE` skips brms's validation of
  ## the prior class, since `b_raw` is not a brms-native parameter class.
  prior <- prior("target += gamma_lpdf(b_raw | alpha_K, 1)",
    class = "b_raw", check = FALSE
  )

  ## Drop the intercept: the weights alone must combine the predictors.
  f <- update.formula(formula, . ~ . - 1)

  if (verbose > 0) {
    ## Pass `family` so the printed code matches the model actually fitted.
    message(make_stancode(f,
      prior = prior,
      family = family,
      data = data,
      stanvars = stanvars
    ))
  }

  brm(f,
    prior = prior,
    family = family,
    data = data,
    stanvars = stanvars,
    ...
  )
}
## Parsnip Definition
## Register a new parsnip model type, "convex_reg", that supports only the
## regression mode and is fit through a "brms" engine; fitting with that
## engine requires the brms package.
set_new_model(model = "convex_reg")
set_model_mode(model = "convex_reg", mode = "regression")
set_model_engine(model = "convex_reg", mode = "regression", eng = "brms")
set_dependency(model = "convex_reg", eng = "brms", pkg = "brms")
## Tuning parameters. parsnip's `scale` maps to `alpha` and `penalty` maps to
## `gamma` in `convex_regression()`.
##
## The `func` entry of `set_model_arg()` must name a *function* that builds
## the dials parameter object, so the parameters are defined as functions.
## They get distinct names because a top-level function called `gamma` would
## mask `base::gamma()`.
set_model_arg(
  model = "convex_reg",
  eng = "brms",
  parsnip = "scale",
  original = "alpha",
  func = list(fun = "convex_scale"),
  has_submodel = FALSE
)

## dials parameter for `scale` (the prior's `alpha`).
## NOTE(review): the default range keeps the original [1, Inf); an infinite
## upper bound is likely unusable for grid generation, so pass a finite
## `range` when tuning. `convex_regression()` itself only requires alpha > 0.
convex_scale <- function(range = c(1, Inf), trans = NULL) {
  new_quant_param(
    type = "double",
    range = range,
    inclusive = c(TRUE, FALSE),
    trans = trans,
    default = 1,
    label = c(scale = "scale")
  )
}

## Kept so existing references to the `alpha` parameter object still work.
alpha <- convex_scale()

set_model_arg(
  model = "convex_reg",
  eng = "brms",
  parsnip = "penalty",
  original = "gamma",
  func = list(fun = "convex_penalty"),
  has_submodel = FALSE
)

## dials parameter for `penalty` (the prior's `gamma`); same range caveat as
## `convex_scale()`. `convex_regression()` warns unless gamma > 1.
convex_penalty <- function(range = c(1, Inf), trans = NULL) {
  new_quant_param(
    type = "double",
    range = range,
    inclusive = c(TRUE, FALSE),
    trans = trans,
    default = 2,
    label = c(penalty = "penalty")
  )
}

## Kept so existing references to the `gamma` parameter object still work.
## (A non-function binding does not mask calls to `base::gamma()`.)
gamma <- convex_penalty()
## Constructor for a `convex_reg` model specification.
##
## `scale` and `penalty` are captured unevaluated as quosures, so they can be
## supplied later (e.g. as tuning placeholders). Through the `set_model_arg()`
## registrations they map to `alpha` and `gamma` of `convex_regression()`.
## Only `mode = "regression"` is accepted. The engine slot is left NULL, so an
## engine must be chosen separately.
convex_reg <- function(mode = "regression", scale = 1, penalty = 2) {
  # Guard: this model type is registered for regression only.
  if (mode != "regression") {
    stop("`mode` should be 'regression'", call. = FALSE)
  }
  spec <- list(
    args = list(
      scale = rlang::enquo(scale),
      penalty = rlang::enquo(penalty)
    ),
    eng_args = NULL,
    mode = mode,
    method = NULL,
    engine = NULL
  )
  # parsnip's helper supplies the class vector in the order it expects.
  structure(spec, class = make_classes("convex_reg"))
}
## Fitting hook for the "brms" engine: parsnip calls `convex_regression()`
## through its formula interface. `formula` and `data` are listed in
## `protect`, i.e. parsnip supplies them itself. `func` has no `pkg` entry,
## so `convex_regression` is looked up by name; presumably it must be visible
## from the global environment when `fit()` runs.
## NOTE(review): newer parsnip releases also expect a `set_encoding()`
## registration; confirm for the installed version.
set_fit(
  model = "convex_reg",
  mode = "regression",
  eng = "brms",
  value = list(
    interface = "formula",
    protect = c("formula", "data"),
    func = c(fun = "convex_regression"),
    defaults = list()
  )
)
## Numeric predictions for the "brms" engine: call `predict()` on the fitted
## model and keep only the "Estimate" column of its result as an unnamed
## vector. Base-R indexing replaces the original `as_tibble()`/`rename()`/
## `pull()` chain, which needs tibble/dplyr functions that the `library()`
## calls in this file do not load.
## NOTE(review): `type = "response"` is forwarded to brms's `predict()`;
## confirm it is accepted rather than silently absorbed by `...`.
num_info <-
  pred_value_template(
    pre = NULL,
    post = function(results, object) {
      unname(as.matrix(results)[, "Estimate"])
    },
    func = c(fun = "predict"),
    object = quote(object$fit),
    newdata = quote(new_data),
    type = "response"
  )
set_pred(
  model = "convex_reg",
  eng = "brms",
  mode = "regression",
  type = "numeric",
  value = num_info
)
## Raw predictions for the "brms" engine: whatever `predict()` returns for
## the fitted model is passed through unchanged (no pre/post-processing).
raw_info <- pred_value_template(
  func = c(fun = "predict"),
  pre = NULL,
  post = NULL,
  object = quote(object$fit),
  newdata = quote(new_data),
  type = "response"
)
set_pred(
  model = "convex_reg",
  mode = "regression",
  eng = "brms",
  type = "raw",
  value = raw_info
)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.