Created
June 14, 2016 21:48
-
-
Save arthurwuhoo/60d0c7afb6b36b835569e4656d2a62df to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ------------------------------------------------------------------
# DAY 12 EXERCISES - DECISION TREES
# ------------------------------------------------------------------
# ------------------------------------------------------------------
# EXERCISE 1
# Complete the iris modelling exercise. This is a multiclass problem. Some
# models support multiclass problems, others don't. Decision trees do. Divide
# the data into a 60% training and 40% testing split. Create a model based on
# the training data.
# ------------------------------------------------------------------
install.packages(c("rpart","rpart.plot","rattle")) | |
library(rpart) | |
library(rpart.plot) | |
library(rattle) | |
percent_training <- 0.6 | |
n_training_obs <- floor(percent_training*nrow(iris)) | |
train_ind <- sample(1:nrow(iris), size = n_training_obs) | |
iris.train <- iris[train_ind, ] | |
iris.test <- iris[-train_ind,] | |
iris.rpart <- rpart(Species ~ ., data = iris.train) | |
fancyRpartPlot(iris.rpart) | |
# What is the accuracy of your training model on you training set? | |
# ------------------------------------------------------------------ | |
predictions <- predict(iris.rpart, iris.train, type = "class") | |
confusion.matrix <- prop.table(table(predictions, iris.train$Species)) #also another way to build the confusion matrix | |
accuracy <- confusion.matrix[1,1] + confusion.matrix[2,2] + confusion.matrix[3,3] | |
accuracy #99% accuracy | |
# What are the most important node splits in your model? Use rattle’s fancyRpartPlot() function to visualize your model. | |
# ------------------------------------------------------------------ | |
fancyRpartPlot(iris.rpart) | |
#most important - petal.length <2.6 | |
#second most important - petal.length < 4.8 | |
# What is the accuracy of your model on your testing set? | |
# ------------------------------------------------------------------ | |
predictions <- predict(iris.rpart, iris.test, type = "class") | |
confusion.matrix <- prop.table(table(predictions, iris.test$Species)) #also another way to build the confusion matrix | |
accuracy <- confusion.matrix[1,1] + confusion.matrix[2,2] + confusion.matrix[3,3] | |
accuracy #90% accuracy | |
# ------------------------------------------------------------------
# ------------------------------------------------------------------
# EXERCISE 2
# Using the Pima Indians Diabetes data
# (PimaIndiansDiabetes or PimaIndiansDiabetes2) from the mlbench package.
# ------------------------------------------------------------------
# Install and load the mlbench package.
# Use data(PimaIndiansDiabetes) to load the data.
# Divide the data set into an 80% training and 20% testing split.
# Make a model that achieves 100% accuracy on your training set.
# Apply that model to your test set.
# ------------------------------------------------------------------
install.packages("mlbench") | |
library(mlbench) | |
data(PimaIndiansDiabetes) | |
pid <- PimaIndiansDiabetes | |
str(pid) | |
#splitting data | |
index <- createDataPartition(pid$diabetes, p = 0.8)[[1]] | |
pid.train <- pid[index,] | |
pid.test <- pid[-index,] | |
#running rpart | |
pid.rpart <- rpart(diabetes ~ ., data = pid.train, | |
control = rpart.control(minsplit = 2, maxdepth = 30, cp = 0.0001)) | |
#the rpart.control parameters adjust how we're fitting the tree. | |
#the default settings have set minsplit as 20 (meaning the tree stops trying to | |
#grow a certain branch after the nodes are 20 observations or less), so | |
#change that to 2 make its find even more rules so that it can eventually | |
#categorize every observation possible. | |
fancyRpartPlot(pid.rpart) # | |
predictions <- predict(pid.rpart, pid.train, type = "class") | |
conf.matrix <- prop.table(table(predictions,pid.train$diabetes)) | |
accuracy <- sum(diag(conf.matrix)) #100% accuracy here. | |
# ------------------------------------------------------------------ | |
# What is the accuracy of the model on your testing set? | |
# Conceptually, why is this model so much worse on the testing data? | |
# What would you do differently to avoid this problem? | |
# ------------------------------------------------------------------ | |
predictions <- predict(pid.rpart, pid.test, type = "class") | |
conf.matrix <- prop.table(table(predictions,pid.test$diabetes)) | |
accuracy <- sum(diag(conf.matrix)) #only 68% accuracy here. we could definitely | |
# see the effect of OVERFITTING - the test data has some noise in it that | |
# our trained model took too seriously. remember, the super accurate model | |
# trained itself on both the underlying signal (which is good), but also on | |
# its noise (i.e. random variation) as well, which is definitely bad. | |
#you could overcome this problem by *not* trying to achieve maximum accuracy | |
#on the training set. for example, let's even just use the default rpart settings. | |
pid.rpart_default <- rpart(diabetes ~ ., data = pid.train) | |
train_predictions <- predict(pid.rpart_default, pid.train, type = "class") | |
conf.matrix <- prop.table(table(train_predictions,pid.train$diabetes)) | |
train_accuracy <- sum(diag(conf.matrix)) #83% accuracy, not 100%. | |
test_predictions <- predict(pid.rpart_default, pid.test, type = "class") | |
conf.matrix <- prop.table(table(test_predictions,pid.test$diabetes)) | |
test_accuracy <- sum(diag(conf.matrix)) | |
#77%, which is a ton better than the 68% achieved earlier with the 100% accurate training model. | |
# ------------------------------------------------------------------
# ------------------------------------------------------------------
# EXERCISE 3
# Use a decision tree to make a regression for crime (crim) in the Boston data
# from the MASS package. Compare its RMSE (use rmse() from the Metrics package)
# to the RMSE from a simple linear regression model. Which one is better? Why?
# ------------------------------------------------------------------
# ------------------------------------------------------------------
library("MASS") | |
head(cars) | |
# Build a model which will classify names according to gender. Potential features might be: length of name, number of vowels in name, whether name begins/ends with a vowel. | |
library(rattle) | |
library(rpart) | |
library(rpart.plot) | |
library(MASS) | |
#devise training set | |
percent_training <- 0.6 | |
n_training_obs <- floor(percent_training*nrow(Boston)) | |
train_ind <- sample(1:nrow(Boston), size = n_training_obs) | |
Boston.train <- Boston[train_ind,] | |
Boston.test <- Boston[-train_ind,] | |
#run the decision tree | |
head(Boston) | |
fit.dtree_regress <- rpart(crim ~., data = Boston.train) | |
fancyRpartPlot(fit.dtree_regress) #cool, ok. | |
dtree_predictions <- predict(fit.dtree_regress, Boston.test) | |
hist(dtree_predictions) #notice here that all of the predictions are | |
#essentially one of five numbers, per the decision tree that we just made. | |
#this probably isn't the best idea to use a predictive tool. | |
#let's find the MSE of this model using the rmse function of the Metrics package | |
library(Metrics) | |
rmse_dtree <- rmse(dtree_predictions, Boston.test$crim) | |
## now lets try a linear regression using GLM | |
fit.lm1 <- lm(crim~., Boston.train) | |
summary(fit.lm1) # a lot of insignificant varaibles here. | |
fit.lm2 <- lm(crim ~ . - chas - nox - rm - age - tax - ptratio - lstat, Boston.train) | |
summary(fit.lm2) #ok, fit still not the best. could feature engineer | |
# or test out polynomials/interactions to get to glory, but let's leave it as here | |
lm_predictions <- predict(fit.lm2, Boston.test) | |
rmse_lm <- rmse(lm_predictions, Boston.test$crim) | |
rmse_lm | |
rmse_dtree |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment