@alexhallam
Created May 30, 2018 12:48
Fit many models (one per continent) on the training years and get predictions for the held-out test years
library(modelr)
library(tidyverse)
library(gapminder)
# nest data by continent and label each row as train (before 1992) or test (1992 onward)
nested_gap <- gapminder %>%
  mutate(test_train = ifelse(year < 1992, "train", "test")) %>%
  group_by(continent) %>%
  nest()
# make a linear model function that only trains on the training rows
cont_model <- function(df) {
  lm(lifeExp ~ year, data = df %>% filter(test_train == "train"))
}
# fit one model per continent and add its predictions to every row (train and test)
fitted_gap <- nested_gap %>%
  mutate(model = map(data, cont_model)) %>%
  mutate(pred = map2(data, model, add_predictions))
# unnest the predictions (dropping the data/model list-columns) and keep only the test rows
fitted_gap %>%
  select(continent, pred) %>%
  unnest(cols = pred) %>%
  filter(test_train == "test")
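A possible follow-up, offered as a sketch rather than part of the original gist: once each continent has a model trained on the pre-1992 rows, the held-out rows can be scored with modelr's rmse(model, data). The block rebuilds the same split so it runs on its own; the column names train_rmse and test_rmse are illustrative choices, not anything defined above.

```r
library(modelr)
library(tidyverse)
library(gapminder)

# Same split as above: years before 1992 are training data, the rest are test data
fitted_gap <- gapminder %>%
  mutate(test_train = if_else(year < 1992, "train", "test")) %>%
  group_by(continent) %>%
  nest() %>%
  ungroup() %>%
  mutate(
    # one linear model per continent, fit only on the training rows
    model = map(data, ~ lm(lifeExp ~ year, data = filter(.x, test_train == "train"))),
    # root-mean-squared error of each model on its own training and test rows
    train_rmse = map2_dbl(model, data, ~ rmse(.x, filter(.y, test_train == "train"))),
    test_rmse  = map2_dbl(model, data, ~ rmse(.x, filter(.y, test_train == "test")))
  )

# one row per continent: compare in-sample and held-out error
fitted_gap %>% select(continent, train_rmse, test_rmse)
```

Comparing the two columns gives a quick sense of how well a straight-line trend in year extrapolates past 1992 for each continent.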