Skip to content

Instantly share code, notes, and snippets.

@topepo
Created April 1, 2023 18:54
Show Gist options
  • Save topepo/6b9c8e6240f69c719c137fc41614e1a2 to your computer and use it in GitHub Desktop.
Save topepo/6b9c8e6240f69c719c137fc41614e1a2 to your computer and use it in GitHub Desktop.
example code for a model wrapper for the fuser package
# Allow fuser functions to work in tidymodels
# See https://stackoverflow.com/questions/75871678/how-can-i-pass-an-extra-variable-to-a-tidymodels-fit-function
# Parse a formula that marks the grouping column with a special (non-existent)
# `groups()` function, e.g. `y ~ x1 + x2 + groups(g)` or `y ~ groups(g) + .`.
# Returns a list of character vectors naming the columns in each role:
#   y      - variables on the left-hand side of the formula
#   x      - predictor variables ('.' is expanded against `data`)
#   groups - the single grouping column named inside `groups()`
# Aborts when the formula has no left-hand side, or when it does not contain
# exactly one `groups()` term naming a single column.
fuser_variables <- function(f, data) {
  trms <- terms(f, data = data, specials = "groups")
  # `variables` is the call `list(var1, var2, ...)`. Its first element is the
  # `list` symbol itself, so indices from the terms attributes are shifted by 1.
  form_terms <- attr(trms, "variables")
  groups_ind <- attr(trms, "specials")$groups + 1
  response <- attr(trms, "response")

  # Without an outcome, `response + 1` would point at the `list` symbol and
  # silently return nonsense, so require a left-hand side.
  if (response == 0) {
    rlang::abort(
      paste(
        "The formula must have an outcome on the left-hand side",
        "(e.g. `y ~ x + groups(groups_col)`)."
      )
    )
  }
  outcome_ind <- response + 1

  # Find the column named inside the (single) `groups()` term
  groups <- if (length(groups_ind) == 1) {
    all.vars(form_terms[[groups_ind]])
  } else {
    character(0)
  }
  if (length(groups) != 1) {
    rlang::abort(
      paste(
        "There should be a single 'groups' column specified using the `groups()`",
        "function (e.g. `y ~ x + groups(groups_col)`)."
      )
    )
  }

  outcomes <- all.vars(form_terms[[outcome_ind]])
  # Predictors are every remaining variable. When '.' is used, it also expands
  # to the grouping column, so remove that column explicitly.
  rhs <- all.vars(form_terms[-c(outcome_ind, groups_ind)])
  rhs <- setdiff(rhs, groups)
  list(y = outcomes, x = rhs, groups = groups)
}
# Simulate a small example data set: `nfeats` numeric predictors, a grouping
# column drawn from `ngroups` letters, and a numeric outcome.
set.seed(1)
nfeats <- 5
nsamples <- 10
ngroups <- 2
group <- sample(letters[seq_len(ngroups)], nsamples, replace = TRUE)
predictors <- matrix(
  rnorm(nfeats * nsamples),
  nrow = nsamples,
  ncol = nfeats,
  dimnames = list(
    paste("Sample", seq_len(nsamples)),
    paste("Feature", seq_len(nfeats))
  )
)
outcome <- rnorm(nsamples)
## tidymodels wants a dataframe input
input <- data.frame(
  predictors,
  group = group,
  # Store the simulated outcome vector (not the constant sample count)
  outcome = outcome
)
# Show how the example formula is split into outcome / predictor / group roles
fuser_variables(f = outcome ~ groups(group) + ., data = input)
# ------------------------------------------------------------------------------
# Fit the model using the variables from the formula
# Fit fuser::fusedLassoProximal() from a formula whose grouping column is
# marked with `groups()`, e.g. `y ~ groups(g) + .`.
#   formula: model formula parsed by fuser_variables()
#   data:    data frame (or tibble) containing every referenced column
#   ...:     passed on to fuser::fusedLassoProximal() (e.g. `G`, `lambda`, `gamma`)
# Returns whatever fuser::fusedLassoProximal() returns.
fuser_fit <- function(formula, data, ...) {
  data_roles <- fuser_variables(formula, data)
  group_vec <- data[[data_roles$groups]]
  # `[[` always yields a plain vector, even when `data` is a tibble
  outcomes <- data[[data_roles$y]]
  # drop = FALSE keeps a named matrix when there is only one predictor
  predictors <- as.matrix(data[, data_roles$x, drop = FALSE])
  fuser::fusedLassoProximal(predictors, outcomes, group_vec, ...)
}
# `G` holds pairwise information-sharing weights between *groups*, so it is
# ngroups x ngroups (not nfeats x nfeats).
fuser_fit(
  outcome ~ groups(group) + .,
  data = input,
  G = matrix(1, ngroups, ngroups),
  lambda = .1,
  gamma = 1
)
# ------------------------------------------------------------------------------
# Also fit, but add an option for the information-sharing weights.
# `info_weights` is an expression evaluated with `k` bound to the number of
# distinct groups. It should produce the k x k matrix of pairwise group
# information-sharing weights, which is passed to fuser::fusedLassoProximal()
# as `G`. The default gives equal (all-ones) weights.
fuser_fit_alt <- function(formula, data,
                          info_weights = rlang::expr(matrix(1, k, k)), ...) {
  data_roles <- fuser_variables(formula, data)
  group_vec <- data[[data_roles$groups]]
  # `[[` always yields a plain vector, even when `data` is a tibble
  outcomes <- data[[data_roles$y]]
  # drop = FALSE keeps a named matrix when there is only one predictor
  predictors <- as.matrix(data[, data_roles$x, drop = FALSE])
  # G weights pairs of groups, so size it by the number of groups rather than
  # the number of predictors.
  k <- length(unique(group_vec))
  G <- rlang::eval_tidy(info_weights, data = list(k = k))
  fuser::fusedLassoProximal(predictors, outcomes, group_vec, G = G, ...)
}
# Same model, letting the default `info_weights` expression build G
fuser_fit_alt(
  formula = outcome ~ groups(group) + .,
  data = input,
  lambda = 0.1,
  gamma = 1
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment