Skip to content

Instantly share code, notes, and snippets.

@ArionHardison
Created April 30, 2013 19:02
Show Gist options
  • Save ArionHardison/5491068 to your computer and use it in GitHub Desktop.
Save ArionHardison/5491068 to your computer and use it in GitHub Desktop.
# package for trees
library(rpart)
# package including data from Elements of Statistical Learning
library(ElemStatLearn)
data(spam)
# The response column spam$spam is left as labels ("spam" / "email");
# the 0-1 recoding below is intentionally disabled.
#spam$spam = ifelse(spam$spam=="spam",1,0)

# Row indices of each class, so the split can be drawn within each class.
spam.sub <- seq_len(nrow(spam))[spam$spam == "spam"]
nospam.sub <- seq_len(nrow(spam))[spam$spam == "email"]

# Use 2/3 of each class for training, the remaining 1/3 for testing.
train.spam <- sample(spam.sub, floor(length(spam.sub) * 2 / 3))
train.email <- sample(nospam.sub, floor(length(nospam.sub) * 2 / 3))
train <- c(train.spam, train.email)
train.set <- spam[train, ]
test.set <- spam[-train, ]

# Classification tree on all predictors, Gini splitting criterion.
rpart.spam <- rpart(
  spam ~ .,
  data = train.set,
  method = "class",
  parms = list(split = "gini")
)

# Take a look at the decision rule.
print(summary(rpart.spam))

# Visualize the tree (gets difficult for bigger trees). filename = "" makes
# post() draw on the current device, i.e. the PNG opened just before.
png("spam_tree.png", height = 600, width = 900)
post(rpart.spam, filename = "")
dev.off()

# Predict the labels for the test set: predict() returns one column of
# class probabilities per label; pick the column with the largest value
# in each row.
predict.spam <- predict(rpart.spam, test.set)
plabels.spam <- colnames(predict.spam)[apply(predict.spam, 1, which.max)]
# compute the various measures of accuracy
#
# Compare predicted labels with true labels for the two-class
# "spam" / "email" problem.
#
# Arguments:
#   plabels  predicted labels; compared against "spam" and "email"
#   tlabels  true labels, same length and order as plabels
#
# Returns a list with elements
#   A               accuracy, (TP + TN) / all cases
#   TP, FP, TN, FN  counts, with "spam" treated as the positive class
#   C               2 x 2 confusion matrix, rows = truth, columns = prediction
#   sens            sensitivity, TP / (TP + FN)
#   spec            specificity, TN / (TN + FP)
#   prec            precision,   TP / (TP + FP)
# A ratio is NaN when its denominator is zero (e.g. nothing predicted spam).
classification.summary <- function(plabels, tlabels) {
  stopifnot(length(plabels) == length(tlabels))

  # Use the arguments only; the function must not depend on any variable
  # that happens to exist in the calling workspace.
  pred.spam <- plabels == "spam"
  pred.email <- plabels == "email"
  true.spam <- tlabels == "spam"
  true.email <- tlabels == "email"

  # true positives: things we labelled spam that are spam
  TP <- sum(pred.spam & true.spam)
  # false positives: things we labelled spam that are email
  FP <- sum(pred.spam & true.email)
  # true negatives: things we labelled email that are email
  TN <- sum(pred.email & true.email)
  # false negatives: things we labelled email that are spam
  FN <- sum(pred.email & true.spam)

  # accuracy
  A <- (TP + TN) / (TP + TN + FP + FN)
  # sensitivity
  sens <- TP / (TP + FN)
  # specificity
  spec <- TN / (TN + FP)
  # precision: share of predicted spam that really is spam
  prec <- TP / (TP + FP)

  # confusion matrix; matrix() fills column by column, so the first column
  # holds everything predicted spam (TP, FP) and the second everything
  # predicted email (FN, TN)
  C <- matrix(
    c(TP, FP, FN, TN),
    nrow = 2,
    ncol = 2,
    dimnames = list(
      c("truly spam", "truly email"),
      c("predicted spam", "predicted email")
    )
  )

  list(
    A = A, TP = TP, FP = FP, TN = TN, FN = FN, C = C,
    sens = sens, spec = spec, prec = prec
  )
}
# accuracy measures of the first tree on its test set
s <- classification.summary(plabels.spam, test.set$spam)
print(s)

# you can control some aspects of the tree building process
# with rpart.control: a very small complexity parameter (cp) lets the tree
# keep splitting much further, and xval sets the number of cross-validations
rpart.spam.deeper <- rpart(
  spam ~ .,
  data = train.set,
  method = "class",
  parms = list(split = "gini"),
  control = rpart.control(cp = 0.00001, xval = 20)
)

# Draw the deeper tree that was just fitted (not the first, default tree).
# The device is opened only after the fit succeeded so that a failing fit
# does not leave an open PNG device behind.
png("spam_cptree.png", height = 1200, width = 800)
post(rpart.spam.deeper, filename = "")
dev.off()
# let's look at the stability of the tree: redraw the per-class 2/3 - 1/3
# split five times, refit the default tree on each new training set and
# save one picture per repetition (spam_repeat0.png ... spam_repeat4.png).
# train.set, test.set and rpart.spam are overwritten on every pass, so the
# objects of the last repetition are the ones left in the workspace.
for (run in 0:4) {
  train.spam <- sample(spam.sub, floor(length(spam.sub) * 2 / 3))
  train.email <- sample(nospam.sub, floor(length(nospam.sub) * 2 / 3))
  train <- c(train.spam, train.email)
  train.set <- spam[train, ]
  test.set <- spam[-train, ]

  rpart.spam <- rpart(
    spam ~ .,
    data = train.set,
    method = "class",
    parms = list(split = "gini")
  )

  # open the device only once there is a tree to draw
  png(sprintf("spam_repeat%d.png", run), height = 600, width = 600)
  post(rpart.spam, filename = "")
  dev.off()
}
# ROC curve of the current tree (rpart.spam) on the current test set.
png("spamROC.png", height = 600, width = 600)
predict.spam <- predict(rpart.spam, test.set)

# Every distinct predicted spam probability serves as a cut-off, in
# increasing order.
thresholds <- sort(unique(predict.spam[, "spam"]))
sens <- numeric(length(thresholds))
spec <- numeric(length(thresholds))

for (k in seq_along(thresholds)) {
  # label a case spam when its spam probability reaches the cut-off
  at.or.above <- predict.spam[, "spam"] >= thresholds[k]
  plabels.spam <- rep("email", nrow(predict.spam))
  plabels.spam[at.or.above] <- "spam"

  s <- classification.summary(plabels.spam, test.set$spam)
  sens[k] <- s$sens
  spec[k] <- s$spec
}

# add the end points (sens = 1, spec = 0) and (sens = 0, spec = 1)
sens <- c(1, sens, 0)
spec <- c(0, spec, 1)

# red: ROC curve; dashed blue: the diagonal with intercept 0 and slope 1
plot(1 - spec, sens, type = "l", col = "red", lwd = 2)
abline(0, 1, lwd = 2, lty = 2, col = "blue")
dev.off()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment