Skip to content

Instantly share code, notes, and snippets.

@TheRayTracer
Last active August 14, 2017 09:44
Show Gist options
  • Save TheRayTracer/f4424668ae486c5df678 to your computer and use it in GitHub Desktop.
Save TheRayTracer/f4424668ae486c5df678 to your computer and use it in GitHub Desktop.
This R script demonstrates Machine Learning and classification using the Caret package. First, the Iris set is divided into a training set and a test set. Secondly, the model is trained. Then, predictions are queried and accuracy is calculated. Lastly, ROC curves are plotted for evaluation. Plots are generated for each step.
# Attach the built-in Iris data set.
data(iris)

# Packages for plotting, modelling, and arranging plots in a grid.
library(ggplot2)
library(caret)
library(gridExtra)

# Fix the RNG state so the results below are reproducible.
set.seed(100)

# Quick inspection: column names and the first few rows.
names(iris)
head(iris)
# Scatter plot of the complete data set, coloured by true species.
qp0 <- qplot(Petal.Width, Petal.Length, color = Species, data = iris, main = "Full Iris data set (100%)")

# Unsupervised k-means clustering (k = 3) on the two petal measurements,
# restarted 20 times to avoid a poor local optimum.
iris_k <- kmeans(iris[, c("Petal.Length", "Petal.Width")], 3, nstart = 20)
print(iris_k)

# Cross-tabulate the cluster assignments against the true species labels.
table(iris_k$cluster, iris$Species)

# Attach the cluster labels temporarily so the plot can be coloured by them.
iris$cluster <- as.factor(iris_k$cluster)
qp1 <- qplot(Petal.Width, Petal.Length, color = cluster, data = iris, main = "K means clustered, k = 3")

# Show the true labels and the clustering side by side, then drop the
# temporary column so the data set is back to its original shape.
grid.arrange(qp0, qp1, ncol = 2, nrow = 1)
iris$cluster <- NULL
# Reset the seed so this partition is reproducible independently of the
# k-means step above.
set.seed(100)

# Stratified 70/30 split into training and test sets, preserving the
# class proportions of Species in both partitions.
training_split <- createDataPartition(y = iris$Species, p = 0.70, list = FALSE)
training_set <- iris[training_split, ]
testing_set <- iris[-training_split, ]

# Visualise the two partitions.
qp2 <- qplot(Petal.Width, Petal.Length, color = Species, data = training_set, main = "Training data set (70%)")
qp3 <- qplot(Petal.Width, Petal.Length, color = Species, data = testing_set, main = "Testing data set (30%)")
# Fit a CART decision tree (rpart) on all features, using 10-fold cross
# validation and retaining class probabilities for the ROC analysis later.
# To train on selected features only, use a formula such as
#   Species ~ Petal.Length + Petal.Width
# in place of Species ~ . below.
model_fit <- train(Species ~ ., method = "rpart", data = training_set,
                   trControl = trainControl(method = "cv", number = 10, classProbs = TRUE))
print(model_fit)
# Classify the reserved test set. Column 5 (Species) is dropped from
# newdata to prove the model never sees the true labels.
testing_set_predict <- predict(model_fit, newdata = testing_set[, -5])

# Confusion matrix: predicted species (rows) vs. actual species (columns).
table(testing_set_predict, testing_set$Species)

# Flag each test row as correctly or incorrectly classified.
testing_set$Correct <- (testing_set_predict == testing_set$Species)

# Accuracy is the proportion of correct classifications; the mean of a
# logical vector is exactly that proportion.
accuracy <- mean(testing_set$Correct)
# This is accuracy on the held-out test set (the model was trained on
# the training set), so label it accordingly.
paste("Test accuracy:", accuracy)

# Plot which test points were classified correctly.
qp4 <- qplot(Petal.Width, Petal.Length, color = Correct, data = testing_set, main = "Test set classification results")
grid.arrange(qp0, qp2, qp3, qp4, ncol = 2, nrow = 2)
# Principal component analysis "by hand" on the four numeric features
# (the Species column is excluded), with each variable scaled to unit
# variance first.
analysis <- prcomp(training_set[, -5], scale. = TRUE)
print(analysis$sdev)

# The squared standard deviations are the variances of the components.
variance_vector <- analysis$sdev ^ 2
print(variance_vector)

# Cumulative proportion of total variance explained per component.
# Only 3 principal components are needed to capture 99% of the variance.
relative_variance <- variance_vector / sum(variance_vector)
cumsum(relative_variance)

# Cross-check with caret, asking for enough components to reach the
# same 99% variance threshold.
preProcess(training_set[, -5], method = "pca", thresh = 0.99)
# Generate ROC curves, one per class (one-vs-rest).
# Ask the model for class probabilities rather than hard labels.
testing_set_predict <- predict(model_fit, newdata = testing_set[, -5], "prob")

# Build a (score, label) data frame for each species, where the label is
# TRUE when the test row really belongs to that species.
roc_setosa <- data.frame(testing_set_predict$setosa, testing_set$Species == "setosa")
colnames(roc_setosa) <- c("predict", "label")
roc_versicolor <- data.frame(testing_set_predict$versicolor, testing_set$Species == "versicolor")
colnames(roc_versicolor) <- c("predict", "label")
roc_virginica <- data.frame(testing_set_predict$virginica, testing_set$Species == "virginica")
colnames(roc_virginica) <- c("predict", "label")

library(ROCR)

# True-positive rate vs. false-positive rate for each one-vs-rest problem.
pred_setosa <- prediction(roc_setosa$predict, roc_setosa$label)
perf_setosa <- performance(pred_setosa, "tpr", "fpr")
pred_versicolor <- prediction(roc_versicolor$predict, roc_versicolor$label)
perf_versicolor <- performance(pred_versicolor, "tpr", "fpr")
pred_virginica <- prediction(roc_virginica$predict, roc_virginica$label)
perf_virginica <- performance(pred_virginica, "tpr", "fpr")

# A ROCR performance object stores its coordinates in the x.values and
# y.values S4 slots (lists with one element per prediction run), so the
# curve for the single run is element [[1]] of each.
qproc1 <- qplot(perf_setosa@x.values[[1]], perf_setosa@y.values[[1]], xlab = "FP Rate", ylab = "TP Rate", main = "Setosa", geom = "path")
qproc2 <- qplot(perf_versicolor@x.values[[1]], perf_versicolor@y.values[[1]], xlab = "FP Rate", ylab = "TP Rate", main = "Versicolor", geom = "path")
qproc3 <- qplot(perf_virginica@x.values[[1]], perf_virginica@y.values[[1]], xlab = "FP Rate", ylab = "TP Rate", main = "Virginica", geom = "path")
grid.arrange(qproc1, qproc2, qproc3, ncol = 2, nrow = 2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment