Created
January 8, 2026 09:54
-
-
Save thierrymoudiki/cd3a21f5498797aa57b8e931642c7c63 to your computer and use it in GitHub Desktop.
formulas vs matrix interface
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#' Convert formula-based model function to matrix interface
#'
#' Wraps a fitting function that is called as `fit_func(formula, data = ...)`
#' so it can be called as `fit(X, y, ...)`, and optionally wraps a prediction
#' function so it can be called as `predict(model, newX, ...)`.
#'
#' @param fit_func Function accepting formula and data
#' @param predict_func Optional prediction function; it is called as
#'   `predict_func(model, newdata = <data frame>, ...)`
#' @param intercept Include intercept in formula (default TRUE)
#' @return List of class "model_adapter" with `fit` and `predict` methods.
#'   The fitted model returned by `fit` carries the attributes `._x_cols`
#'   (predictor column names) and `._y_name` (generated response name).
formula_to_matrix <- function(fit_func, predict_func = NULL, intercept = TRUE) {
  # Backtick-quote names that are not syntactically valid R names (spaces,
  # leading digits, reserved words, ...); pasted verbatim they would make the
  # formula text unparseable.
  quote_names <- function(nm) {
    needs_quote <- make.names(nm) != nm
    nm[needs_quote] <- paste0("`", nm[needs_quote], "`")
    nm
  }

  fit_wrapper <- function(X, y, ...) {
    # Input validation
    if (!is.matrix(X) && !is.data.frame(X)) {
      stop("X must be a matrix or data frame")
    }
    if (length(y) != nrow(X)) {
      stop("Length of y must match number of rows in X")
    }

    # Convert to data frame and ensure column names exist
    df <- as.data.frame(X)
    if (is.null(colnames(X))) {
      names(df) <- paste0("V", seq_len(ncol(X)))
    }
    x_cols <- unique(names(df))

    # Create unique response name to avoid conflicts with predictor columns
    y_name <- ".response"
    while (y_name %in% x_cols) {
      y_name <- paste0(y_name, "_")
    }
    df[[y_name]] <- y

    # Build the right-hand side: a leading "0" suppresses the intercept.
    # With no predictor columns this degrades to "1" (intercept only) or
    # "0" (empty model) instead of an unparseable empty string.
    rhs_terms <- c(if (!intercept) "0", quote_names(x_cols))
    if (length(rhs_terms) == 0L) {
      rhs_terms <- "1"
    }
    formula <- as.formula(
      paste(y_name, "~", paste(rhs_terms, collapse = " + "))
    )

    # Fit model
    model <- fit_func(formula, data = df, ...)

    # Store metadata (column names only, not full data)
    attr(model, "._x_cols") <- x_cols
    attr(model, "._y_name") <- y_name
    model
  }

  predict_wrapper <- if (!is.null(predict_func)) {
    function(model, newX, ...) {
      if (!is.matrix(newX) && !is.data.frame(newX)) {
        stop("newX must be a matrix or data frame")
      }
      newdf <- as.data.frame(newX)
      x_cols <- attr(model, "._x_cols")

      # Ensure column names match training
      if (is.null(names(newdf))) {
        names(newdf) <- paste0("V", seq_len(ncol(newdf)))
      }
      if (!all(x_cols %in% names(newdf))) {
        stop("newX missing required columns: ",
             paste(setdiff(x_cols, names(newdf)), collapse = ", "))
      }
      predict_func(model, newdata = newdf, ...)
    }
  } else {
    # No prediction function supplied: fail loudly when predict is called
    function(model, newX, ...) {
      stop("No predict function provided")
    }
  }

  structure(
    list(fit = fit_wrapper, predict = predict_wrapper),
    class = "model_adapter"
  )
}
#' Convert matrix-based model function to formula interface
#'
#' Wraps a fitting function that is called as `fit_func(X, y, ...)` so it can
#' be called as `fit(formula, data, ...)`, and optionally wraps a prediction
#' function so it can be called as `predict(model, newdata, ...)`.
#'
#' @param fit_func Function accepting X matrix and y vector
#' @param predict_func Optional prediction function; it is called as
#'   `predict_func(model, newX, ...)` with the rebuilt design matrix
#' @param drop_intercept Remove intercept column from model.matrix
#'   (default TRUE)
#' @return List of class "model_adapter" with `fit` and `predict` methods.
#'   The fitted model returned by `fit` carries the attributes `._formula`,
#'   `._terms`, `._xlev`, `._contrasts`, `._x_cols` and `._intercept_removed`,
#'   which `predict` uses to rebuild the design matrix.
matrix_to_formula <- function(fit_func, predict_func = NULL,
                              drop_intercept = TRUE) {
  fit_wrapper <- function(formula, data, ...) {
    # Input validation
    if (!inherits(formula, "formula")) {
      stop("formula must be a formula object")
    }
    if (!is.data.frame(data)) {
      stop("data must be a data frame")
    }

    # Extract response and design matrix
    mf <- model.frame(formula, data = data)
    if (nrow(mf) == 0) {
      stop("model.frame produced empty result")
    }
    terms_obj <- terms(mf)
    y <- model.response(mf)
    X <- model.matrix(terms_obj, data = mf)

    # Remember factor levels and contrasts now: the column subsetting below
    # drops the "contrasts" attribute, and prediction needs both to produce
    # the same dummy columns even when new data shows only some levels.
    x_levels <- .getXlevels(terms_obj, mf)
    x_contrasts <- attr(X, "contrasts")

    # Handle intercept: most matrix-based learners don't want it
    intercept_removed <- FALSE
    if (drop_intercept && "(Intercept)" %in% colnames(X)) {
      X <- X[, colnames(X) != "(Intercept)", drop = FALSE]
      intercept_removed <- TRUE
    }

    # Fit model
    model <- fit_func(X, y, ...)

    # Store metadata for prediction
    attr(model, "._formula") <- formula
    attr(model, "._terms") <- terms_obj
    attr(model, "._xlev") <- x_levels
    attr(model, "._contrasts") <- x_contrasts
    attr(model, "._x_cols") <- colnames(X)
    attr(model, "._intercept_removed") <- intercept_removed
    model
  }

  predict_wrapper <- if (!is.null(predict_func)) {
    function(model, newdata, ...) {
      if (!is.data.frame(newdata)) {
        stop("newdata must be a data frame")
      }

      # Reconstruct design matrix using stored terms, factor levels and
      # contrasts so its columns are coded exactly as during training
      rhs_terms <- delete.response(attr(model, "._terms"))
      new_mf <- model.frame(
        rhs_terms,
        data = newdata,
        xlev = attr(model, "._xlev")
      )
      newX <- model.matrix(
        rhs_terms,
        data = new_mf,
        contrasts.arg = attr(model, "._contrasts")
      )

      # Apply same intercept handling as training
      if (isTRUE(attr(model, "._intercept_removed")) &&
          "(Intercept)" %in% colnames(newX)) {
        newX <- newX[, colnames(newX) != "(Intercept)", drop = FALSE]
      }

      # Ensure column alignment with training
      x_cols <- attr(model, "._x_cols")
      if (!all(x_cols %in% colnames(newX))) {
        stop("Prediction data missing required columns")
      }
      newX <- newX[, x_cols, drop = FALSE]
      predict_func(model, newX, ...)
    }
  } else {
    # No prediction function supplied: fail loudly when predict is called
    function(model, newdata, ...) {
      stop("No predict function provided")
    }
  }

  structure(
    list(fit = fit_wrapper, predict = predict_wrapper),
    class = "model_adapter"
  )
}
#' Print method for model adapters
#'
#' Writes a short header followed by the argument lists of the adapter's
#' `fit` and `predict` functions.
#'
#' @param x A "model_adapter" object (a list with `fit` and `predict`).
#' @param ... Unused.
#' @return `x`, invisibly.
#' @export
print.model_adapter <- function(x, ...) {
  # Emit one "<label>function(arg1, arg2, ...)" line for a member function
  show_signature <- function(label, fn) {
    arg_list <- toString(names(formals(fn)))
    cat(label, "function(", arg_list, ")\n", sep = "")
  }

  cat("Model Interface Adapter\n")
  show_signature(" fit: ", x$fit)
  show_signature(" predict: ", x$predict)
  invisible(x)
}
| # ============================================================================ | |
| # EXAMPLES | |
| # ============================================================================ | |
# Example 1: drive lm() through the matrix (X, y) interface.
# Fits mpg on wt and hp from mtcars, then prints predictions for the first
# five rows followed by the fitted coefficients.
demo_lm_matrix <- function() {
  cat("\n=== Example 1: lm with matrix interface ===\n")

  adapter <- formula_to_matrix(lm, predict)
  predictors <- as.matrix(mtcars[, c("wt", "hp")])
  fitted_model <- adapter$fit(predictors, mtcars$mpg)
  first_five <- adapter$predict(fitted_model, predictors[1:5, ])

  cat("Predictions:\n")
  print(first_five)
  cat("Coefficients:\n")
  print(coef(fitted_model))
}
# Example 2: drive glmnet through the formula interface.
# Skips with a message when the glmnet package is not installed; otherwise
# fits mpg ~ wt + hp on mtcars and prints predictions for the first five rows.
demo_glmnet_formula <- function() {
  cat("\n=== Example 2: glmnet with formula interface ===\n")

  glmnet_available <- requireNamespace("glmnet", quietly = TRUE)
  if (!glmnet_available) {
    cat("glmnet package not installed, skipping example\n")
    return(invisible(NULL))
  }

  # Prediction wrapper that supplies a default s parameter
  predict_at_fixed_s <- function(model, newX, ...) {
    glmnet::predict.glmnet(model, newx = newX, s = 0.01, ...)
  }

  adapter <- matrix_to_formula(
    fit_func = glmnet::glmnet,
    predict_func = predict_at_fixed_s
  )
  fitted_model <- adapter$fit(mpg ~ wt + hp, data = mtcars)
  first_five <- adapter$predict(fitted_model, newdata = mtcars[1:5, ])

  cat("Predictions:\n")
  print(first_five)
}
| # Run examples. These calls are active, so both demos execute whenever this file is run or sourced; comment them out to skip. | |
| demo_lm_matrix() | |
| demo_glmnet_formula() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment