Created September 29, 2015 01:54
-
-
Save mikelove/5aca819e49149bcd4e77 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(purrr) | |
library(dplyr) | |
# some functions | |
# Convenience helper: assign n items to the groups named in `probs` and
# return the labels in random order. Group sizes are fixed (approximately
# n * probs / sum(probs) each); only the order is random. This differs from
# sample(names(probs), size = n, replace = TRUE, prob = probs), where group
# sizes are themselves random.
#
# @param n     number of labels to produce (non-negative integer).
# @param probs named, non-negative numeric vector of relative group sizes;
#              normalised internally, so it need not sum to 1.
# @return character vector of length n whose values are drawn from names(probs).
random_group <- function(n, probs) {
  stopifnot(
    !is.null(names(probs)),
    all(probs >= 0),
    sum(probs) > 0
  )
  probs <- probs / sum(probs)
  breaks <- c(0, cumsum(probs))
  # Evenly spaced points over [0, 1] land in each [breaks[i], breaks[i + 1])
  # interval roughly in proportion to probs[i]; rightmost.closed puts the
  # point 1 into the last interval.
  g <- findInterval(seq(0, 1, length.out = n), breaks, rightmost.closed = TRUE)
  # Rounding can leave the final break a hair below 1, which would push the
  # point 1 past the last group (an NA label); clamp it back into range.
  g <- pmin(g, length(probs))
  # Shuffle by index: sample(g) misbehaves when length(g) == 1 and g > 1,
  # because sample(k) then means sample(1:k).
  names(probs)[g[sample.int(length(g))]]
}
# Draw n independent random partitions of the rows of `df` into the groups
# named in `probs` (group sizes as produced by random_group()).
#
# @param df    data frame whose rows are partitioned.
# @param n     number of random partitions to draw.
# @param probs named numeric vector of relative group sizes,
#              e.g. c(training = 0.8, test = 0.2).
# @return a tibble with n rows and one list-column per name in `probs`, in
#         the order of names(probs); row i holds the pieces (data frames) of
#         the i-th partition.
partition <- function(df, n, probs) {
  groups <- names(probs)
  # n partitions, each a named list of data frames (one per group). Splitting
  # on a factor with every group as a level keeps groups that received no
  # rows (as zero-row data frames), so every partition has the same names.
  splits <- replicate(
    n,
    split(df, factor(random_group(nrow(df), probs), levels = groups)),
    simplify = FALSE
  )
  # Transpose to one list per group (all training pieces, all test pieces,
  # ...). Done in base R because purrr::zip_n() has been removed from purrr.
  by_group <- lapply(groups, function(grp) lapply(splits, `[[`, grp))
  names(by_group) <- groups
  # as_tibble() supersedes the deprecated as_data_frame().
  as_tibble(by_group)
}
# Root-mean-square difference between x and y (the RMSE when one argument
# holds predictions and the other the observed values). Despite the name,
# the square root is taken, so the result is on the same scale as x and y.
msd <- function(x, y) {
  err <- x - y
  sqrt(mean(err^2))
}
# the code: 100 random train/test splits of mtcars (repeated random
# sub-sampling, despite the name `boot`). For each split, fit mpg ~ wt on the
# training rows and score the RMSE of its predictions on the test rows.
set.seed(2015)  # make the random splits, and hence the results, reproducible
# 80% of rows for training, 20% held out for testing.
boot <- partition(mtcars, 100, c(test = 0.2, training = 0.8))
boot <- boot %>% mutate(
  models = map(training, function(d) lm(mpg ~ wt, data = d)),
  preds = map2(models, test, function(m, d) predict(m, newdata = d)),
  # map2_dbl() stores one RMSE per split as a plain numeric column.
  diffs = map2_dbl(preds, test %>% map("mpg"), msd)
)
mean(boot$diffs)
# how would you debug this? Redo the first split by hand and check that it
# reproduces the value the pipeline stored for it.
first_train <- boot[[1, "training"]]
first_test <- boot[[1, "test"]]
fit <- lm(mpg ~ wt, data = first_train)
pred <- predict(fit, newdata = first_test)
manual <- msd(pred, first_test$mpg)
manual
# this should give the same as:
boot[[1, "diffs"]]
# Assert it rather than comparing the printed values by eye.
stopifnot(isTRUE(all.equal(manual, boot[[1, "diffs"]])))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
https://github.com/hadley/purrr