# Source: GitHub gist ArionHardison/5491068 (created April 30, 2013 19:02).
# Note carried over from the gist page: this file may contain hidden or
# bidirectional Unicode text; review it in an editor that reveals hidden
# Unicode characters.
# Fit a classification tree to the spam data, plot it, and predict labels
# for a held-out test set.

# package for trees
library(rpart)
# package including data from Elements of Statistical Learning
library(ElemStatLearn)
data(spam)
# (disabled) make response a 0-1 outcome; the 'spam' / 'email' labels are
# used as-is below
#spam$spam = ifelse(spam$spam=="spam",1,0)

# row indices of each class, so that the split below is stratified by class
spam.sub = which(spam$spam == 'spam')
nospam.sub = which(spam$spam == 'email')
# use 2/3 for training, 1/3 for test
train.spam = sample(spam.sub, floor(length(spam.sub)*2/3))
train.email = sample(nospam.sub, floor(length(nospam.sub)*2/3))
train = c(train.spam, train.email)
train.set = spam[train,]
test.set = spam[-train,]
rpart.spam = rpart(spam ~ ., data=train.set, method="class",
                   parms=list(split="gini"))
# take a look at the decision rule
print(summary(rpart.spam))
# visualize it (gets difficult for bigger trees);
# filename='' makes post() draw on the currently open device, i.e. the PNG
png("spam_tree.png", height=600, width=900)
post(rpart.spam, filename='')
dev.off()
# predict the labels for the test set:
# one column of class probabilities per label, then pick the most probable
predict.spam = predict(rpart.spam, test.set, type="prob")
plabels.spam = colnames(predict.spam)[apply(predict.spam, 1, which.max)]
# compute the various measures of accuracy
# Summarise binary classification performance from predicted and true labels.
#
# Arguments:
#   plabels   predicted labels, one per observation (character or factor)
#   tlabels   true labels, same length as plabels
#   positive  label treated as the positive class (default 'spam')
#   negative  label treated as the negative class (default 'email')
#
# Returns a list with
#   A               accuracy, (TP+TN) / (TP+TN+FP+FN)
#   TP, FP, TN, FN  counts of true/false positives/negatives
#   C               2x2 confusion matrix; rows are the true class, columns
#                   the predicted class, positive class first
#   sens            sensitivity, TP / (TP+FN)
#   spec            specificity, TN / (TN+FP)
#   prec            precision, TP / (TP+FP)
# A ratio whose denominator is zero comes out as NaN rather than an error.
classification.summary = function(plabels, tlabels,
                                  positive='spam', negative='email') {
  # guard against silent recycling of mismatched label vectors
  stopifnot(length(plabels) == length(tlabels))

  pred.pos = plabels == positive
  pred.neg = plabels == negative
  true.pos = tlabels == positive
  true.neg = tlabels == negative

  # true positives: things we labelled positive that are positive
  TP = sum(pred.pos & true.pos)
  # false positives: things we labelled positive that are negative
  FP = sum(pred.pos & true.neg)
  # true negatives: things we labelled negative that are negative
  TN = sum(pred.neg & true.neg)
  # false negatives: things we labelled negative that are positive
  FN = sum(pred.neg & true.pos)

  # accuracy
  A = (TP+TN) / (TP+TN+FP+FN)
  # sensitivity
  sens = TP / (TP+FN)
  # specificity
  spec = TN / (TN+FP)
  # precision: share of positive predictions that are correct
  prec = TP / (TP+FP)

  # confusion matrix; matrix() fills column by column, so the first column
  # (predicted positive) is TP, FP and the second (predicted negative) FN, TN
  C = matrix(c(TP,FP,FN,TN), 2, 2)
  colnames(C) = paste('predicted', c(positive, negative))
  rownames(C) = paste('truly', c(positive, negative))

  # prec is appended last so existing positional access is unaffected
  list(A=A, TP=TP, FP=FP, TN=TN, FN=FN, C=C, sens=sens, spec=spec, prec=prec)
}
# evaluate the predicted test-set labels against the true ones and show
# every measure returned
s = classification.summary(plabels.spam, test.set$spam)
print(s)
# you can control some aspects of the tree building process
# with rpart.control: a very small complexity parameter (cp) lets the tree
# keep splitting, and xval sets the number of cross-validations
rpart.spam.deeper = rpart(spam ~ ., data=train.set, method="class",
                          parms=list(split="gini"),
                          control=rpart.control(cp=0.00001, xval=20))
# plot the deeper tree that was just fitted (not the earlier rpart.spam);
# the device is opened only once the fit has succeeded
png("spam_cptree.png", height=1200, width=800)
post(rpart.spam.deeper, filename='')
dev.off()
# let's look at the stability of the tree:
# refit on five fresh random 2/3 training splits and plot each resulting
# tree to spam_repeat0.png ... spam_repeat4.png
# (top-level loop on purpose: train.set, test.set and rpart.spam from the
# last repeat stay defined for the code that follows)
for (rep.id in 0:4) {
  train.spam = sample(spam.sub, floor(length(spam.sub)*2/3))
  train.email = sample(nospam.sub, floor(length(nospam.sub)*2/3))
  train = c(train.spam, train.email)
  train.set = spam[train,]
  test.set = spam[-train,]
  rpart.spam = rpart(spam ~ ., data=train.set, method="class",
                     parms=list(split="gini"))
  png(sprintf("spam_repeat%d.png", rep.id), height=600, width=600)
  post(rpart.spam, filename='')
  dev.off()
}
# ROC curve for the current tree on the current test set: sweep the decision
# threshold over every distinct predicted spam probability
predict.spam = predict(rpart.spam, test.set)
l = sort(unique(predict.spam[,'spam']))
# one sensitivity / specificity value per threshold, allocated up front
sens = numeric(length(l))
spec = numeric(length(l))
for (i in seq_along(l)) {
  ll = l[i]
  # label as spam everything whose predicted spam probability reaches ll
  plabels.spam = rep('email', nrow(predict.spam))
  plabels.spam[(predict.spam[,'spam'] >= ll)] = 'spam'
  s = classification.summary(plabels.spam, test.set$spam)
  sens[i] = s$sens
  spec[i] = s$spec
}
# add the two end points of the curve
sens = c(1,sens,0)
spec = c(0,spec,1)
# open the device only once the curve is ready to be drawn
png("spamROC.png", height=600, width=600)
plot(1-spec, sens, type='l', col='red', lwd=2,
     xlab='1 - specificity', ylab='sensitivity')
# diagonal reference line (intercept 0, slope 1)
abline(0,1,lwd=2, lty=2, col='blue')
dev.off()
# End of gist (GitHub page footer with sign-up / sign-in prompts omitted).