Last active
February 7, 2024 16:11
-
-
Save ivopbernardo/0fb8a90b892892c488e86df07d211033 to your computer and use it in GitHub Desktop.
Caret Library Example - R Language
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# caret library example used in blogpost: | |
# https://towardsdatascience.com/a-guide-to-using-caret-in-r-71dec0bda208 | |
library(caTools) | |
library(caret) | |
# Train/test split helper used for both iris and mtcars.
#
# Args:
#   df: a data.frame to split.
# Returns:
#   A named list with elements `train` and `test` (still accessible
#   positionally as [[1]] / [[2]], as existing callers do).
#
# NOTE(review): sample.split() expects a vector of outcome labels, not a
# whole data.frame -- passing `df` relies on caTools coercing it and splits
# on a recycled mask. Kept as-is to preserve the original split; consider
# sample.split(df[[1]], ...) or caret::createDataPartition() instead.
train_test_split <- function(df) {
  set.seed(42)  # fixed seed so repeated calls produce the identical split
  mask <- sample.split(df, SplitRatio = 0.8)
  train <- subset(df, mask)
  test <- subset(df, !mask)
  list(train = train, test = test)
}
# To simplify, recode iris as a binary classification problem:
# 1 = setosa, 0 = everything else.
iris$target <- ifelse(iris$Species == 'setosa', 1, 0)

# Split each dataset ONCE and reuse the result. The original called
# train_test_split() twice per dataset, redoing identical work (the fixed
# seed inside the helper made that correct, but wasteful).
mtcars_split <- train_test_split(mtcars)
mtcars_train <- mtcars_split[[1]]
mtcars_test <- mtcars_split[[2]]

iris_split <- train_test_split(iris)
iris_train <- iris_split[[1]]
iris_test <- iris_split[[2]]
# Linear regression on mtcars: predict mpg from four predictors.
lm_model <- train(
  mpg ~ hp + wt + gear + disp,
  method = "lm",
  data = mtcars_train
)
summary(lm_model)

# Logistic regression on iris: model the binary `target` column.
glm_model <- train(
  target ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width,
  method = "glm",
  family = "binomial",
  data = iris_train
)
summary(glm_model)
# Decision tree (CART via rpart) on mtcars.
d.tree <- train(
  mpg ~ hp + wt + gear + disp,
  method = "rpart",
  data = mtcars_train
)
library(rpart.plot)
rpart.plot(d.tree$finalModel)  # visualize the fitted tree
d.tree  # print resampling results for the tried cp values

# Random forest via the ranger backend.
r.forest <- train(
  mpg ~ hp + wt + gear + disp,
  method = "ranger",
  data = mtcars_train
)
r.forest
# Gradient-boosted trees (xgboost).
xg.boost <- train(
  mpg ~ hp + wt + gear + disp,
  method = "xgbTree",
  data = mtcars_train
)

# k-nearest neighbours.
knn <- train(
  mpg ~ hp + wt + gear + disp,
  method = "knn",
  data = mtcars_train
)

# Neural network (regression on mtcars).
neural.network <- train(
  mpg ~ hp + wt + gear + disp,
  method = "neuralnet",
  data = mtcars_train
)

# Neural network on iris. Note `target` is still numeric at this point,
# so caret fits this as a regression problem -- the factor-based
# classification version comes next in the script.
neural.network.class <- train(
  target ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width,
  method = "nnet",
  data = iris_train
)
neural.network.class
# Convert the target to a factor so caret treats the problem as
# classification rather than regression. The original converted only the
# training set; converting the test set as well keeps the two consistent
# for any downstream predict()/confusionMatrix() usage.
iris_train$target <- factor(iris_train$target)
iris_test$target <- factor(iris_test$target)

# Retrain the neural network, now as a genuine classifier.
neural.network.class.2 <- train(
  target ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width,
  data = iris_train,
  method = "nnet"
)
# K-fold cross-validation ----
ctrl <- trainControl(method = "cv", number = 10)

# Decision tree evaluated with 10-fold CV (default selection metric: RMSE).
d.tree.kfold <- train(
  mpg ~ hp + wt + gear + disp,
  data = mtcars_train,
  method = "rpart",
  trControl = ctrl
)
d.tree.kfold
rpart.plot(d.tree.kfold$finalModel)

# Same model, but pick the best cp by R-squared instead of RMSE.
d.tree.kfold.rsquared <- train(
  mpg ~ hp + wt + gear + disp,
  data = mtcars_train,
  method = "rpart",
  trControl = ctrl,
  metric = "Rsquared"
)
d.tree.kfold.rsquared
# Hyperparameter tuning ----
# Explicit cp grid for the decision tree (rpart's complexity parameter).
d.tree.hyperparam <- train(
  mpg ~ hp + wt + gear + disp,
  data = mtcars_train,
  method = "rpart",
  trControl = ctrl,
  metric = "Rsquared",
  tuneGrid = data.frame(cp = c(0.0001, 0.001, 0.35, 0.65))
)
d.tree.hyperparam

# Three hand-picked (mtry, min.node.size) rows for ranger. data.frame()
# builds a row-wise grid of 3 combinations -- not a full cross-product
# (use expand.grid() for that).
r.forest.hyperparam <- train(
  mpg ~ hp + wt + gear + disp,
  data = mtcars_train,
  method = "ranger",
  trControl = ctrl,
  metric = "Rsquared",
  tuneGrid = data.frame(
    mtry = c(2, 3, 4),
    min.node.size = c(2, 4, 10),
    splitrule = "variance"
  )
)
r.forest.hyperparam
# Test-set predictions ----
predict(lm_model, mtcars_test)  # linear regression
predict(d.tree, mtcars_test)    # decision tree
predict(r.forest, mtcars_test)  # random forest
predict(xg.boost, mtcars_test)  # xgboost

# Compare random-forest vs xgboost predictions in a scatter plot;
# points near the diagonal indicate the two models agree.
predictions <- data.frame(
  rf = predict(r.forest, mtcars_test),
  xgboost = predict(xg.boost, mtcars_test)
)
library(ggplot2)
ggplot(predictions, aes(x = rf, y = xgboost)) +
  geom_point()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment