Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save timcdlucas/cc879284cd5e353be4a3f22b6fdbe06e to your computer and use it in GitHub Desktop.
Save timcdlucas/cc879284cd5e353be4a3f22b6fdbe06e to your computer and use it in GitHub Desktop.
library(caret)
library(dplyr)
data(diamonds)
# Fix the RNG seed so the random train/test split and the CV folds below are
# the same on every run (the seed value itself is arbitrary).
set.seed(1)

# Draw one logical flag per row: TRUE with probability 0.8 (training row),
# FALSE with probability 0.2 (held-out test row).
train_i <- sample(c(FALSE, TRUE), nrow(diamonds),
                  replace = TRUE, prob = c(0.2, 0.8))
diamonds_train <- diamonds[train_i, ]
diamonds_test <- diamonds[!train_i, ]

# Build the 5 folds once, as training-row indices, so that every model that
# uses `trcntrl` is tuned on exactly the same resamples.
folds <- createFolds(diamonds_train$price, k = 5, returnTrain = TRUE)

# `index` supplies the resamples; `method`/`number` are set to match them so
# caret reports 5-fold cross-validation rather than its default bootstrap.
trcntrl <- trainControl(method = "cv", number = 5, index = folds,
                        savePredictions = TRUE, search = "grid")
# Elastic net regression of price on carat, cut, depth and table, tuned on the
# shared resamples in `trcntrl`; `tuneLength` sets the granularity of caret's
# tuning grid.
m1 <- train(
  price ~ carat + cut + depth + table,
  data = diamonds_train,
  trControl = trcntrl,
  method = "enet",
  tuneLength = 10
)

# Print the lowest resampled RMSE found across the tuning grid.
min(m1$results$RMSE)

# Root-mean-squared error of the tuned model on the held-out test rows.
m1_error <- sqrt(mean(
  (diamonds_test$price - predict(m1, newdata = diamonds_test))^2
))
# Random forest (ranger) regression of price on carat, cut, depth and table,
# tuned on the shared resamples in `trcntrl`; `tuneLength` sets the
# granularity of caret's tuning grid.
m2 <- train(
  price ~ carat + cut + depth + table,
  data = diamonds_train,
  trControl = trcntrl,
  method = "ranger",
  tuneLength = 4
)

# Print the lowest resampled RMSE found across the tuning grid.
min(m2$results$RMSE)

# Root-mean-squared error of the tuned model on the held-out test rows.
m2_error <- sqrt(mean(
  (diamonds_test$price - predict(m2, newdata = diamonds_test))^2
))
# Neural network (nnet) regression of price on carat, cut, depth and table,
# tuned on the shared resamples in `trcntrl`.
m3 <- train(price ~ carat + cut + depth + table,
            data = diamonds_train,
            method = "nnet",
            # Standardise the predictors: on their raw scales (e.g. depth and
            # table are around 60) the sigmoid hidden units tend to saturate
            # and the network fits poorly.
            preProcess = c("center", "scale"),
            tuneLength = 10,
            trControl = trcntrl,
            # Forwarded to nnet(): linear output unit, needed for regression
            # on an unbounded response such as price.
            linout = TRUE,
            # Forwarded to nnet(): silence the per-iteration optimiser output
            # that would otherwise be printed for every fit.
            trace = FALSE)

# Print the lowest resampled RMSE across the tuning grid. `na.rm = TRUE` skips
# tuning combinations whose fits failed and therefore have a missing RMSE.
m3$results$RMSE %>% min(na.rm = TRUE)

# Root-mean-squared error of the tuned model on the held-out test rows.
m3_error <- sqrt(mean((predict(m3, diamonds_test) - diamonds_test$price) ^ 2))
# Best (lowest) resampled RMSE for each fitted model. `na.rm = TRUE` ignores
# tuning combinations whose fits failed and have a missing RMSE, so one failed
# combination cannot turn the whole summary into NA.
validation_rmse <- vapply(
  list(m1, m2, m3),
  function(fit) min(fit$results$RMSE, na.rm = TRUE),
  numeric(1)
)

# One row per model: resampled (validation) RMSE next to held-out test RMSE.
df <- data.frame(method = c("enet", "rf", "nnet"),
                 validation = validation_rmse,
                 test = c(m1_error, m2_error, m3_error))

# Scatter of validation vs. test error; points on the y = x reference line
# have identical validation and test RMSE.
ggplot(df, aes(validation, test, colour = method)) +
  geom_point() +
  geom_abline(slope = 1, intercept = 0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment