Created
April 1, 2023 18:54
-
-
Save topepo/6b9c8e6240f69c719c137fc41614e1a2 to your computer and use it in GitHub Desktop.
example code for a model wrapper for the fuser package
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Allow fuser functions to work in tidymodels
# See https://stackoverflow.com/questions/75871678/how-can-i-pass-an-extra-variable-to-a-tidymodels-fit-function
# Parse a formula that uses a special (non-existent) `groups()` function.
#
# `f` is a two-sided formula such as `y ~ x + groups(groups_col)` and `data` is
# the data frame used to expand '.' on the right-hand side.
#
# Returns a list of character vectors giving the columns and their roles:
# `y` (outcome), `x` (predictors), and `groups` (the grouping column).
# Aborts when the formula has no outcome or does not mark exactly one grouping
# column with `groups()`.
fuser_variables <- function(f, data) {
  trms <- terms(f, data = data, specials = "groups")
  form_terms <- attr(trms, "variables")
  # `form_terms` is a call whose first element is `list`, so positions reported
  # by `terms()` are shifted by one.
  groups_ind <- attr(trms, "specials")$groups + 1
  response_pos <- attr(trms, "response")
  outcome_ind <- response_pos + 1

  # Without a left-hand side the outcome index would point at `list` itself.
  if (response_pos == 0) {
    rlang::abort("The formula should have an outcome on the left-hand side.")
  }

  groups_msg <- paste(
    "There should be a single 'groups' column specified using the `groups()`",
    "function (e.g. `y ~ x + groups(groups_col)`)"
  )

  # check length
  if (length(groups_ind) != 1) {
    rlang::abort(groups_msg)
  }

  # find column with groups variable
  groups_expr <- form_terms[[groups_ind]]
  groups <- all.vars(groups_expr)
  # `groups()` must name exactly one column; callers index `data` with it.
  if (length(groups) != 1) {
    rlang::abort(groups_msg)
  }
  outcomes <- all.vars(form_terms[[outcome_ind]])

  # repair formula: get predictors and remake
  rhs <- form_terms[-c(outcome_ind, groups_ind)]
  rhs <- all.vars(rhs)
  # If the '.' was used, it puts the group variable in twice
  rhs <- rhs[!rhs %in% groups]

  list(y = outcomes, x = rhs, groups = groups)
}
# Simulate a small grouped data set used by the examples in this file.
set.seed(1)
nfeats <- 5
nsamples <- 10
ngroups <- 2
group <- sample(letters[seq_len(ngroups)], nsamples, replace = TRUE)
predictors <- matrix(
  rnorm(nfeats * nsamples),
  nrow = nsamples,
  ncol = nfeats,
  dimnames = list(
    paste("Sample", seq_len(nsamples)),
    paste("Feature", seq_len(nfeats))
  )
)
outcome <- rnorm(nsamples)
## tidymodels wants a dataframe input
input <- data.frame(
  predictors,
  group = group,
  # Use the simulated outcome vector (not the sample count) as the response.
  outcome = outcome
)
# Show the column roles parsed from a formula that uses '.' and `groups()`.
fuser_variables(
  f = outcome ~ groups(group) + .,
  data = input
)
# ------------------------------------------------------------------------------ | |
# Fit the model using the variables from the formula
#
# `formula` marks the grouping column with `groups()`, `data` is a data frame
# containing every column the formula refers to, and `...` is passed unchanged
# to `fuser::fusedLassoProximal()` (the example below supplies `G`, `lambda`,
# and `gamma` this way). Returns whatever `fusedLassoProximal()` returns.
fuser_fit <- function(formula, data, ...) {
  data_roles <- fuser_variables(formula, data)
  group_vec <- data[[data_roles$groups]]
  outcomes <- data[, data_roles$y]
  # `drop = FALSE` keeps a single predictor as a one-column data frame so the
  # resulting matrix retains its column name.
  predictors <- data[, data_roles$x, drop = FALSE]
  predictors <- as.matrix(predictors)
  fuser::fusedLassoProximal(predictors, outcomes, group_vec, ...)
}
# Example fit: `G`, `lambda`, and `gamma` are forwarded through `...`.
fuser_fit(
  formula = outcome ~ groups(group) + .,
  data = input,
  lambda = 0.1,
  gamma = 1,
  G = matrix(1, nrow = nfeats, ncol = nfeats)
)
# ------------------------------------------------------------------------------ | |
# Also fit but add an option for the information sharing weights
# `info_weights` is some expression that is a function of 'k'; it is evaluated
# inside this function after `k` has been set, and the result is passed to
# `fuser::fusedLassoProximal()` as `G`. Other arguments in `...` are forwarded
# unchanged. Returns whatever `fusedLassoProximal()` returns.
fuser_fit_alt <- function(formula, data, info_weights = rlang::expr(matrix(1, k, k)), ...) {
  data_roles <- fuser_variables(formula, data)
  group_vec <- data[[data_roles$groups]]
  outcomes <- data[, data_roles$y]
  # `drop = FALSE` keeps a single predictor as a one-column data frame so the
  # resulting matrix retains its column name.
  predictors <- data[, data_roles$x, drop = FALSE]
  predictors <- as.matrix(predictors)
  # NOTE(review): `k` is the number of predictor columns here; fuser's own
  # examples appear to size `G` by the number of groups -- confirm.
  k <- ncol(predictors)
  G <- rlang::eval_tidy(info_weights)
  fuser::fusedLassoProximal(predictors, outcomes, group_vec, G = G, ...)
}
# Example fit relying on the default `info_weights` expression.
fuser_fit_alt(
  formula = outcome ~ groups(group) + .,
  data = input,
  gamma = 1,
  lambda = 0.1
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment