Last active
October 18, 2016 18:02
-
-
Save talegari/6ae581755a263b86a5613778a27f2f4e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
############################################################################### | |
# | |
# optimal_poly ---- | |
# | |
############################################################################### | |
# | |
# author : Srikanth KS (talegari) | |
# license : GNU AGPLv3 (http://choosealicense.com/licenses/agpl-3.0/) | |
# | |
############################################################################### | |
# | |
# Description ---- | |
# | |
# Find the optimal degree and r-squared for a univariate polynomial fit by
# fitting a polynomial of each degree and averaging the r-squared over CV folds.
# The optimal degree of the polynomial is decided using either: | |
# a. heuristic in `optimal_scree` function | |
# b. `dim_select` based on profile likelihood from igraph package. | |
# A scree-plot is provided so that the user can choose a
# better optimal value. Choose numFolds so that the number of predictors
# remains smaller than the number of observations for the linear model.
# | |
# Intended for interactive (REPL) use, hence verbose is TRUE by default.
# | |
# Arguments ---- | |
# | |
# response : numeric vector. | |
# predictor : numeric vector. | |
# maxDegree : positive integer indicating max degree of dependence. | |
# Defaults to 5. | |
# numFolds : Positive integer at least 3, indicating the number of folds to be | |
# used for cross-validation error. | |
# Defaults to 10. | |
# method : string indicating either "lm"(default) or "rlm". | |
# detectBy : "heuristic" for `optimal_scree`(default) and "profLik" for a | |
# method based on profile likelihood from `igraph` package. | |
# verbose : a boolean, TRUE by default. | |
# ... : additional arguments to be passed to `optimal_scree`. | |
# | |
# Value ---- | |
# | |
# A list with these elements is returned: | |
# optimal : optimal degree of the polynomial fit. | |
# optimal_r2 : r-squared value corresponding to 'optimal'. | |
# average_r2 : r-squared value for degree in 1 to maxDegree averaged over CV | |
# folds. | |
# sd_r2 : standard deviation, across CV folds, of the r-squared values
# for each degree in 1 to maxDegree.
# scree_plot : "ggplot" plot object plotting r-squared versus the degree. | |
# | |
# depends ---- | |
# | |
# R : >= 3 | |
# Packages: `assertthat`, `magrittr`, `MASS`, `ggplot2`, `caret` and `igraph`.
# | |
# | |
# function ---- | |
# Cross-validated choice of the polynomial degree for a one-predictor fit.
#
# For every degree in 1..maxDegree a polynomial model (`lm` or `rlm`) is
# fitted on each training fold and scored on the held-out observations. The
# per-degree scores are averaged over the folds and the optimal degree is read
# off that curve, either with `optimal_scree` (detectBy = "heuristic") or with
# `igraph::dim_select` (detectBy = "profLik").
#
# Arguments:
#   response, predictor : numeric vectors of equal length; positions where
#                         either one is NA are dropped.
#   maxDegree           : positive integer, highest degree tried.
#   numFolds            : integer >= 3, number of CV folds.
#   method              : "lm" or "rlm", name of the fitting function.
#   detectBy            : "heuristic" or "profLik".
#   verbose             : if TRUE, report the result and print the scree plot.
#   ...                 : forwarded to `optimal_scree` (heuristic mode only).
#
# Value: list with `optimal`, `optimal_r2`, `average_r2`, `sd_r2` and
#        `scree_plot` (a ggplot object).
#
# Side effects: attaches assertthat, magrittr, MASS and ggplot2; reseeds the
# global RNG with a randomly drawn seed (reported when `verbose` is TRUE).
optimal_poly <- function(response
                         , predictor
                         , maxDegree = 5
                         , numFolds  = 10
                         , method    = "lm"
                         , detectBy  = "heuristic"
                         , verbose   = TRUE
                         , ...){

  extras <- list(...)

  # dependencies ----
  # These packages are attached (not merely loaded): `do.call(method, ...)`
  # resolves "rlm" by name on the search path, and attaching has always been
  # a visible effect of calling this function.
  for(pkg in c("assertthat", "magrittr", "MASS", "ggplot2")){
    attached <- require(pkg
                        , quietly        = TRUE
                        , warn.conflicts = FALSE
                        , character.only = TRUE
                        )
    if(!attached){
      stop("package '", pkg, "' is required but could not be attached"
           , call. = FALSE)
    }
  }
  # caret is only needed through `caret::createFolds`; fail early and clearly.
  if(!requireNamespace("caret", quietly = TRUE)){
    stop("package 'caret' is required to create the CV folds", call. = FALSE)
  }

  # assertions ----
  assert_that(length(response) == length(predictor))
  assert_that(is.numeric(response) && is.numeric(predictor))

  # keep only the positions where both vectors are observed
  observed  <- !(is.na(response) | is.na(predictor))
  predictor <- predictor[observed]
  response  <- response[observed]

  assert_that(is.count(maxDegree))
  assert_that(is.count(numFolds) && numFolds >= 3)
  assert_that(length(response) >= numFolds)
  assert_that(is.string(method) && method %in% c("lm", "rlm"))
  assert_that(is.string(detectBy) && detectBy %in% c("heuristic", "profLik"))

  # CV folds ----
  # The seed is drawn first and then set, so that it can be reported and the
  # same folds regenerated later.
  seed <- sample(1:1000, 1)
  set.seed(seed)
  cvFolds <- caret::createFolds(response
                                , k           = numFolds
                                , returnTrain = TRUE
                                )

  # Score one (degree, training fold) pair on the held-out observations.
  # NOTE(review): the total sum of squares is not centred on the mean, and
  # both held-out sums are scaled by degrees of freedom based on the
  # *training* fold size. Kept exactly as in the original implementation --
  # confirm this is the intended statistic.
  score_fold <- function(degree, trainIdx){
    trainData <- data.frame(y = response[trainIdx], x = predictor[trainIdx])
    fit       <- do.call(method, list(y ~ poly(x, degree), data = trainData))
    testY     <- response[-trainIdx]
    testPred  <- predict(fit, data.frame(x = predictor[-trainIdx]))
    nTrain    <- length(trainIdx)
    tss       <- sum(testY^2, na.rm = TRUE) / (nTrain - 1)
    rss       <- sum((testY - testPred)^2, na.rm = TRUE) / (nTrain - degree - 1)
    1 - (rss / tss)
  }

  # rows: folds, columns: degrees 1..maxDegree
  r2_by_degree <- lapply(seq_len(maxDegree)
                         , function(degree){
                             vapply(cvFolds
                                    , function(trainIdx){
                                        score_fold(degree, trainIdx)
                                      }
                                    , numeric(1))
                           }
                         )
  te_mat <- do.call(cbind, r2_by_degree)
  cm     <- colMeans(te_mat)

  # select the degree ----
  # The branch is driven by `detectBy`; `method` only names the fitter.
  if(detectBy == "profLik"){
    if(!requireNamespace("igraph", quietly = TRUE)){
      stop("package 'igraph' is required when detectBy = \"profLik\""
           , call. = FALSE)
    }
    # `dim_select` returns the selected dimension, i.e. an index into `cm`
    opt_deg <- igraph::dim_select(cm)
  } else{
    # with no extra arguments this is simply optimal_scree(vec = cm)
    opt_deg <- do.call(optimal_scree, c(list(vec = cm), extras))
  }

  # scree plot ----
  scree_df <- data.frame(degree = seq_len(maxDegree), r2 = cm)
  pl <- ggplot(scree_df, aes(x = degree, y = r2)) +
    geom_line() +
    geom_point() +
    labs(x = "Degree of the Polynomial", y = "r-squared") +
    geom_vline(xintercept = opt_deg, col = "blue") +
    scale_x_continuous(breaks = seq_len(maxDegree))

  if(verbose){
    message("\n****\nOptimal degree: ", opt_deg)
    message("Optimal r-squared: ", round(cm[opt_deg], 4))
    message("Seed chosen to generate CV folds: ", seed, "\n****\n")
    print(pl)
  }

  list(optimal      = opt_deg
       , optimal_r2 = cm[opt_deg]
       , average_r2 = cm
       , sd_r2      = apply(te_mat, 2, sd, na.rm = TRUE)
       , scree_plot = pl
       )
}
# example ----
# Noisy sine curve on [-pi, pi]; the fixed seed makes the noise reproducible.
pre <- seq(from = -pi, to = pi, by = 0.001)
res <- sin(pre)
set.seed(100)
res2 <- res + rnorm(length(res), mean = 0, sd = 0.5)

# Robust fits for each degree; `thres` travels through `...` to `optimal_scree`.
opt <- optimal_poly(res2, pre, method = "rlm", thres = 0.7)
opt

# Overlay a cubic robust fit on the raw data. `rlm` is namespace-qualified so
# that this line does not depend on `optimal_poly` having attached MASS.
plot(pre, res2)
points(pre, fitted(MASS::rlm(res2 ~ poly(pre, 3))), col = "green")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment