Created
June 18, 2024 16:05
-
-
Save topepo/778eff2590df87702e1c82c9ba09af7d to your computer and use it in GitHub Desktop.
A function to produce a smooth linear predictor trend plot for survival models.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#' Plot a smooth linear predictor trend from a proportional hazards model
#'
#' Wraps the single predictor on the right-hand side of `formula` in
#' `splines2::naturalSpline()`, fits the result with `survival::coxph()`,
#' predicts the linear predictor over an evenly spaced grid spanning the
#' observed range of the predictor, and draws the trend as a line with a rug
#' marking the unique observed predictor values.
#'
#' @param formula A two-sided formula whose right-hand side references exactly
#'   one numeric column of `data`. The left-hand side is passed unchanged to
#'   `survival::coxph()`.
#' @param data A data frame containing the predictor column (and whatever the
#'   left-hand side of `formula` needs).
#' @param deg_free Degrees of freedom, passed as `df` to
#'   `splines2::naturalSpline()`. The predictor must have more unique
#'   non-missing values than this.
#' @param grid_size Number of equally spaced grid points at which the linear
#'   predictor is evaluated.
#' @return A ggplot object.
smooth_ph_linear_pred <- function(formula, data, deg_free = 5, grid_size = 500) {
  # check_installed() signals an error (or offers to install) when a package
  # is missing; is_installed() only returns a logical that is easy to ignore.
  rlang::check_installed(c("survival", "ggplot2", "splines2", "cli"))

  # check 1 pred and continuous
  pred_sym <- rlang::f_rhs(formula)
  pred_var <- all.vars(pred_sym)
  if (length(pred_var) != 1) {
    cli::cli_abort("Only a single numeric predictor is supported.")
  }
  if (!(pred_var %in% names(data))) {
    cli::cli_abort("The predictor {.var {pred_var}} was not found in {.arg data}.")
  }
  if (!is.numeric(data[[pred_var]])) {
    cli::cli_abort("The predictor {.var {pred_var}} should be numeric.")
  }

  # Work from the raw column: when the right-hand side is a transformation of
  # the column (e.g. `log(x)`), the model and the plot aesthetics re-apply
  # that transformation themselves. sort() also drops missing values, so NA
  # is not counted as a unique value.
  pred_data <- sort(unique(data[[pred_var]]))
  num_uniq <- length(pred_data)
  if (num_uniq <= deg_free) {
    cli::cli_abort(
      paste(
        "The predictor has {num_uniq} unique values; more are",
        "needed to fit a spline with {deg_free} degrees of freedom."
      )
    )
  }

  # ----------------------------------------------------------------------------
  # make prop haz fit

  # Build `splines2::naturalSpline(<predictor>, df = deg_free)` and use it as
  # the new right-hand side of the user's formula.
  new_term <- rlang::call2(
    "naturalSpline",
    .ns = "splines2",
    pred_sym,
    df = deg_free
  )
  model_form <- formula
  rlang::f_rhs(model_form) <- new_term

  ph_fit <- tryCatch(
    survival::coxph(model_form, data = data),
    error = function(cnd) cnd
  )
  if (inherits(ph_fit, "error")) {
    cli::cli_abort(
      "The model fit failed with error {conditionMessage(ph_fit)}."
    )
  }

  # ----------------------------------------------------------------------------
  # plot grid

  # used for rug below
  pred_df <- data.frame(pred_data)
  names(pred_df) <- pred_var

  pred_rng <- range(pred_data)
  pred_grid <- data.frame(
    x = seq(pred_rng[1], pred_rng[2], length.out = grid_size)
  )
  names(pred_grid) <- pred_var

  pred_grid$linear_predictor <-
    stats::predict(
      ph_fit,
      newdata = pred_grid,
      type = "lp",
      se.fit = FALSE # TRUE appears to trigger a bug
    )

  ggplot2::ggplot(pred_grid, ggplot2::aes(x = !!pred_sym)) +
    ggplot2::geom_line(ggplot2::aes(y = linear_predictor)) +
    ggplot2::geom_rug(data = pred_df) +
    ggplot2::labs(y = "Linear Predictor")
}
# Example usage. NOTE(review): `cat_adoption` and `event_time` are not defined
# in this file; they are presumably a data set and its survival outcome
# prepared elsewhere -- confirm before running.

# Default spline flexibility (deg_free = 5).
smooth_ph_linear_pred(
  formula = event_time ~ latitude,
  data = cat_adoption
)

# A more flexible spline for the second predictor.
smooth_ph_linear_pred(
  formula = event_time ~ longitude,
  data = cat_adoption,
  deg_free = 10
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment