Skip to content

Instantly share code, notes, and snippets.

@thierrymoudiki
Created January 8, 2026 09:54
Show Gist options
  • Select an option

  • Save thierrymoudiki/cd3a21f5498797aa57b8e931642c7c63 to your computer and use it in GitHub Desktop.

Select an option

Save thierrymoudiki/cd3a21f5498797aa57b8e931642c7c63 to your computer and use it in GitHub Desktop.
formulas vs matrix interface
#' Convert formula-based model function to matrix interface
#'
#' Wraps a learner with a `fit_func(formula, data, ...)` signature so it can
#' be called as `fit(X, y, ...)`, and optionally wraps its predictor so it can
#' be called as `predict(model, newX, ...)`.
#'
#' The formula is built as a language object from column-name symbols rather
#' than by pasting strings, so non-syntactic column names (for example
#' `"my col"` or `"1x"`) work, and a zero-column `X` yields an intercept-only
#' (or empty, when `intercept = FALSE`) model instead of a parse error.
#'
#' @param fit_func Function accepting formula and data
#' @param predict_func Optional prediction function, called as
#'   `predict_func(model, newdata = <data frame>, ...)`
#' @param intercept Include intercept in formula (default TRUE). When FALSE
#'   the right-hand side is prefixed with `0 +`.
#' @return A `model_adapter`: a list with `fit(X, y, ...)` and
#'   `predict(model, newX, ...)`. Fitted models carry the training predictor
#'   names in attribute `"._x_cols"` and the response name in `"._y_name"`.
formula_to_matrix <- function(fit_func, predict_func = NULL, intercept = TRUE) {
  # Build the formula right-hand side as a call `v1 + v2 + ...` (prefixed by
  # `0 +` when the intercept is suppressed). With no predictors this is the
  # constant 1 (intercept only) or 0 (empty model).
  build_rhs <- function(vars) {
    symbols <- lapply(vars, as.name)
    add_term <- function(lhs, rhs) call("+", lhs, rhs)
    if (!intercept) {
      return(Reduce(add_term, symbols, init = 0))
    }
    if (length(symbols) == 0L) 1 else Reduce(add_term, symbols)
  }

  fit_wrapper <- function(X, y, ...) {
    # Input validation
    if (!is.matrix(X) && !is.data.frame(X)) {
      stop("X must be a matrix or data frame")
    }
    if (length(y) != nrow(X)) {
      stop("Length of y must match number of rows in X")
    }
    # Convert to data frame and ensure column names exist
    df <- as.data.frame(X)
    if (is.null(colnames(X))) {
      names(df) <- paste0("V", seq_len(ncol(X)))
    }
    # Create unique response name to avoid conflicts with predictor names
    y_name <- ".response"
    while (y_name %in% names(df)) {
      y_name <- paste0(y_name, "_")
    }
    df[[y_name]] <- y
    x_cols <- setdiff(names(df), y_name)
    # `~` is evaluated in this frame, so the formula's environment is this
    # wrapper's frame (where `df` lives), as with as.formula(<string>).
    formula <- eval(call("~", as.name(y_name), build_rhs(x_cols)))
    # Fit model
    model <- fit_func(formula, data = df, ...)
    # Store metadata (column names only, not full data)
    attr(model, "._x_cols") <- x_cols
    attr(model, "._y_name") <- y_name
    model
  }

  predict_wrapper <- if (!is.null(predict_func)) {
    function(model, newX, ...) {
      if (!is.matrix(newX) && !is.data.frame(newX)) {
        stop("newX must be a matrix or data frame")
      }
      newdf <- as.data.frame(newX)
      # Mirror the training-time naming of unnamed matrix columns
      if (is.null(colnames(newX))) {
        names(newdf) <- paste0("V", seq_len(ncol(newdf)))
      }
      # Ensure every training predictor is present
      missing_cols <- setdiff(attr(model, "._x_cols"), names(newdf))
      if (length(missing_cols) > 0L) {
        stop("newX missing required columns: ",
             paste(missing_cols, collapse = ", "))
      }
      predict_func(model, newdata = newdf, ...)
    }
  } else {
    function(model, newX, ...) {
      stop("No predict function provided")
    }
  }

  structure(
    list(fit = fit_wrapper, predict = predict_wrapper),
    class = "model_adapter"
  )
}
#' Convert matrix-based model function to formula interface
#'
#' Wraps a learner with a `fit_func(X, y, ...)` signature so it can be called
#' as `fit(formula, data, ...)`, and optionally wraps its predictor so it can
#' be called as `predict(model, newdata, ...)`.
#'
#' At fit time the factor levels (via `.getXlevels()`) and the contrasts of
#' the design matrix are recorded on the model. At prediction time they are
#' re-imposed, so `newdata` containing only some factor levels (e.g. a single
#' row, or a character column) yields the same design-matrix columns as
#' training. Rows of `newdata` with missing values are kept (`na.pass`, as in
#' `predict.lm`), so the prediction matrix has one row per row of `newdata`.
#'
#' @param fit_func Function accepting X matrix and y vector
#' @param predict_func Optional prediction function, called as
#'   `predict_func(model, newX, ...)` with the rebuilt design matrix
#' @param drop_intercept Remove intercept column from model.matrix (default TRUE)
#' @return A `model_adapter`: a list with `fit(formula, data, ...)` and
#'   `predict(model, newdata, ...)`.
matrix_to_formula <- function(fit_func, predict_func = NULL, drop_intercept = TRUE) {
  fit_wrapper <- function(formula, data, ...) {
    # Input validation
    if (!inherits(formula, "formula")) {
      stop("formula must be a formula object")
    }
    if (!is.data.frame(data)) {
      stop("data must be a data frame")
    }
    # Extract response and design matrix
    mf <- model.frame(formula, data = data)
    if (nrow(mf) == 0) {
      stop("model.frame produced empty result")
    }
    y <- model.response(mf)
    terms_obj <- terms(mf)
    X <- model.matrix(terms_obj, data = mf)
    # Record factor levels and contrasts now: subsetting X below drops the
    # "contrasts" attribute, and newdata may not contain every level.
    x_levels <- .getXlevels(terms_obj, mf)
    x_contrasts <- attr(X, "contrasts")
    # Handle intercept: most matrix-based learners don't want it
    intercept_removed <- FALSE
    if (drop_intercept && "(Intercept)" %in% colnames(X)) {
      X <- X[, colnames(X) != "(Intercept)", drop = FALSE]
      intercept_removed <- TRUE
    }
    # Fit model
    model <- fit_func(X, y, ...)
    # Store metadata for prediction
    attr(model, "._formula") <- formula
    attr(model, "._terms") <- terms_obj
    attr(model, "._xlevels") <- x_levels
    attr(model, "._contrasts") <- x_contrasts
    attr(model, "._x_cols") <- colnames(X)
    attr(model, "._intercept_removed") <- intercept_removed
    model
  }

  predict_wrapper <- if (!is.null(predict_func)) {
    function(model, newdata, ...) {
      if (!is.data.frame(newdata)) {
        stop("newdata must be a data frame")
      }
      # Rebuild the design matrix from the stored terms, re-imposing the
      # training factor levels (xlev) and contrasts; na.pass keeps one row
      # per input row instead of silently dropping incomplete rows.
      terms_obj <- delete.response(attr(model, "._terms"))
      new_mf <- model.frame(
        terms_obj,
        data = newdata,
        na.action = na.pass,
        xlev = attr(model, "._xlevels")
      )
      newX <- model.matrix(
        terms_obj,
        data = new_mf,
        contrasts.arg = attr(model, "._contrasts")
      )
      # Apply same intercept handling as training
      if (isTRUE(attr(model, "._intercept_removed")) &&
          "(Intercept)" %in% colnames(newX)) {
        newX <- newX[, colnames(newX) != "(Intercept)", drop = FALSE]
      }
      # Ensure column alignment with training
      x_cols <- attr(model, "._x_cols")
      missing_cols <- setdiff(x_cols, colnames(newX))
      if (length(missing_cols) > 0L) {
        stop("Prediction data missing required columns: ",
             paste(missing_cols, collapse = ", "))
      }
      newX <- newX[, x_cols, drop = FALSE]
      predict_func(model, newX, ...)
    }
  } else {
    function(model, newdata, ...) {
      stop("No predict function provided")
    }
  }

  structure(
    list(fit = fit_wrapper, predict = predict_wrapper),
    class = "model_adapter"
  )
}
#' Print method for model adapters
#'
#' Writes a header line followed by one line per wrapped function, showing
#' the argument names of the adapter's `fit` and `predict` closures.
#'
#' @param x A `model_adapter` object.
#' @param ... Ignored.
#' @return `x`, invisibly.
#' @export
print.model_adapter <- function(x, ...) {
  # Render " <label>: function(<arg1>, <arg2>, ...)" followed by a newline
  describe <- function(label, fn) {
    args_text <- paste(names(formals(fn)), collapse = ", ")
    sprintf(" %s: function(%s)\n", label, args_text)
  }
  cat("Model Interface Adapter\n")
  cat(describe("fit", x$fit))
  cat(describe("predict", x$predict))
  invisible(x)
}
# ============================================================================
# EXAMPLES
# ============================================================================
# Example 1: Use lm with matrix interface
#
# Fits mpg ~ wt + hp on mtcars through the matrix adapter, then prints the
# predictions for the first five rows and the fitted coefficients.
demo_lm_matrix <- function() {
  cat("\n=== Example 1: lm with matrix interface ===\n")
  adapter <- formula_to_matrix(lm, predict)
  features <- as.matrix(mtcars[, c("wt", "hp")])
  fitted_model <- adapter$fit(features, mtcars$mpg)
  first_five <- adapter$predict(fitted_model, features[1:5, ])
  cat("Predictions:\n")
  print(first_five)
  cat("Coefficients:\n")
  print(coef(fitted_model))
}
# Example 2: Use glmnet with formula interface
#
# Fits mpg ~ wt + hp on mtcars with glmnet through the formula adapter and
# prints predictions for the first five rows. Skips (returning invisible NULL)
# when glmnet is not installed.
#
# The predictor goes through the `predict()` generic so S3 dispatch selects
# the appropriate glmnet method, and exposes the penalty `s` as an argument
# (default 0.01) so callers can override it; hard-coding `s` while also
# forwarding `...` would fail with "formal argument 's' matched by multiple
# actual arguments" whenever a caller passed `s`.
demo_glmnet_formula <- function() {
  cat("\n=== Example 2: glmnet with formula interface ===\n")
  if (!requireNamespace("glmnet", quietly = TRUE)) {
    cat("glmnet package not installed, skipping example\n")
    return(invisible(NULL))
  }
  glmnet_formula <- matrix_to_formula(
    fit_func = glmnet::glmnet,
    predict_func = function(model, newX, s = 0.01, ...) {
      # glmnet predict methods take the design matrix as `newx`
      stats::predict(model, newx = newX, s = s, ...)
    }
  )
  model <- glmnet_formula$fit(mpg ~ wt + hp, data = mtcars)
  preds <- glmnet_formula$predict(model, newdata = mtcars[1:5, ])
  cat("Predictions:\n")
  print(preds)
}
# Run examples. These calls execute every time this file is sourced; comment
# them out to load only the adapter functions.
demo_lm_matrix()
demo_glmnet_formula()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment