Created September 6, 2016 03:58
-
-
Save rnowling/659df4f92ae16aaac8b1ce20abd7b53b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Script for comparing spam classification with a bag-of-words model constructed with and without hashing. You'll need to download a copy of the dataset from http://plg.uwaterloo.ca/~gvcormac/treccorpus07/about.html . | |
Copyright 2016 Ronald J. Nowling | |
Licensed under the Apache License, Version 2.0 (the "License"); | |
you may not use this file except in compliance with the License. | |
You may obtain a copy of the License at | |
http://www.apache.org/licenses/LICENSE-2.0 | |
Unless required by applicable law or agreed to in writing, software | |
distributed under the License is distributed on an "AS IS" BASIS, | |
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
See the License for the specific language governing permissions and | |
limitations under the License. | |
""" | |
from collections import defaultdict | |
from itertools import islice | |
import matplotlib.pyplot as plt | |
from sklearn.feature_extraction.text import HashingVectorizer, TfidfVectorizer | |
from sklearn.linear_model import SGDClassifier | |
from sklearn.metrics import roc_auc_score | |
# Root directory holding the extracted trec07p corpus; the __main__ block
# passes it to stream_email, which reads "<DATA_DIR>/trec07p/full/index".
DATA_DIR = "data"
# Directory the comparison plot is written into by the __main__ block.
FIGURES_DIR = "figures"
def _parse_message(message):
    """Return the recipient, sender, and body text of a parsed email.

    For a multipart message, the body comes from the first part (in
    ``message.walk()`` order) whose content type is ``text/html`` or
    ``text/plain`` and whose Content-Disposition does not mention
    ``attachment``.  For a single-part message the whole payload is used.
    Either way the decoded body is run through BeautifulSoup's
    ``html.parser`` and only its text is kept, so HTML markup is stripped.

    Args:
        message: a parsed ``email.message.Message``.

    Returns:
        ``(to, from_, text)``: the raw ``To`` and ``From`` header values
        (``None`` when the header is absent) and the extracted body text
        (``""`` when no suitable part is found).
    """
    from bs4 import BeautifulSoup

    body = ""
    if message.is_multipart():
        for part in message.walk():
            # Skip attachments; str() turns a missing header into "None".
            if "attachment" in str(part.get("Content-Disposition")):
                continue
            # The MIME type for plain text is "text/plain" (RFC 2046).
            if part.get_content_type() in ("text/html", "text/plain"):
                body = part.get_payload(decode=True)
                break
    else:
        # Single-part message: its payload is the whole body.
        body = message.get_payload(decode=True)
    # Defensive: get_payload(decode=True) returns None for multipart
    # payloads, which BeautifulSoup cannot parse.
    text = BeautifulSoup(body or "", "html.parser").get_text()
    return message["To"], message["From"], text
def stream_email(data_dir):
    """Yield labelled, parsed messages from the trec07p corpus under *data_dir*.

    Reads the index file ``<data_dir>/trec07p/full/index``.  Each non-blank
    line holds a category (``ham`` or ``spam``) and a message path that
    starts with ``..``.  That prefix is stripped and the remainder is
    resolved under ``<data_dir>/trec07p``.  Blank lines are skipped.

    Args:
        data_dir: directory containing the extracted ``trec07p`` folder.

    Yields:
        ``(label, to, from_, body)`` where ``label`` is 0 for ham and 1 for
        spam, and ``to``/``from_``/``body`` come from ``_parse_message``.

    Raises:
        ValueError: if an index line has a category other than ham/spam,
            or does not consist of exactly two fields.
    """
    try:
        # Python 3: parse raw bytes so messages that are not valid in the
        # locale's text encoding don't raise UnicodeDecodeError on read.
        from email.parser import BytesParser as _EmailParser
    except ImportError:
        # Python 2: Parser already operates on byte strings.
        from email.parser import Parser as _EmailParser
    email_parser = _EmailParser()
    labels = {"ham": 0, "spam": 1}

    index_flname = data_dir + "/trec07p/full/index"
    with open(index_flname) as index_fl:
        for ln in index_fl:
            fields = ln.split()
            if not fields:
                continue
            category, email_fl_suffix = fields
            label = labels.get(category)
            if label is None:
                raise ValueError("unknown category %r in %s"
                                 % (category, index_flname))
            # strip .. prefix from path
            email_flname = data_dir + "/trec07p" + email_fl_suffix[2:]
            with open(email_flname, "rb") as email_fl:
                message = email_parser.parse(email_fl)
            to, from_, body = _parse_message(message)
            yield (label, to, from_, body)
def split_train_test(stream, training_size):
    """Split a stream of ``(label, to, from_, body)`` records into train/test.

    The first ``training_size`` records (in stream order) go to the training
    set and the rest to the testing set.  The running record count and the
    per-label counts are printed whenever the count reaches a power of two,
    and once more at the end.

    Returns:
        ``(training_bodies, training_labels, testing_bodies, testing_labels)``
    """
    counts = defaultdict(int)
    training_bodies, training_labels = [], []
    testing_bodies, testing_labels = [], []
    next_output = 1
    count = 0  # stays 0 for an empty stream so the final report still works
    for count, (label, _to, _from, body) in enumerate(stream, start=1):
        if count <= training_size:
            training_bodies.append(body)
            training_labels.append(label)
        else:
            testing_bodies.append(body)
            testing_labels.append(label)
        counts[label] += 1
        if count == next_output:
            print(count, dict(counts))
            next_output *= 2
    print(count, dict(counts))
    return training_bodies, training_labels, testing_bodies, testing_labels


def fit_and_score(vectorizer, training_bodies, training_labels,
                  testing_bodies, testing_labels):
    """Fit a logistic-regression SGD classifier on vectorized text and score it.

    ``vectorizer`` is fit on the training bodies (a no-op for a stateless
    vectorizer such as ``HashingVectorizer``) and then transforms the
    testing bodies.

    Returns:
        ``(auc, n_features, n_nonzero_weights)``: ROC AUC on the testing set,
        the number of columns in the training feature matrix, and the number
        of non-zero learned coefficients.
    """
    # "log_loss" is the logistic loss; the former name "log" was deprecated
    # in scikit-learn 1.1 and removed in 1.3.
    model = SGDClassifier(loss="log_loss", penalty="l2")
    training_features = vectorizer.fit_transform(training_bodies)
    model.fit(training_features, training_labels)
    testing_features = vectorizer.transform(testing_bodies)
    pred_probs = model.predict_proba(testing_features)
    auc = roc_auc_score(testing_labels, pred_probs[:, 1])
    return auc, training_features.shape[1], int((model.coef_ != 0).sum())


if __name__ == "__main__":
    import os

    training_size = int(75419. * 0.75)  # from Attenberg paper
    splits = split_train_test(stream_email(DATA_DIR), training_size)

    # Baseline: unhashed binary bag-of-words over the training vocabulary.
    tfidf_vectorizer = TfidfVectorizer(binary=True, norm=None, use_idf=False)
    tfidf_auc, n_tfidf_features, _ = fit_and_score(tfidf_vectorizer, *splits)
    print("tfidf auc", tfidf_auc, "n_features", n_tfidf_features)

    # Hashed binary bag-of-words with 2**8 .. 2**24 buckets.
    bit_range = list(range(8, 25))
    aucs = []
    nzs = []
    for n_bits in bit_range:
        hashing_vectorizer = HashingVectorizer(n_features=2 ** n_bits,
                                               binary=True, norm=None)
        auc, _, n_nonzero = fit_and_score(hashing_vectorizer, *splits)
        aucs.append(auc)
        nzs.append(n_nonzero)
        print(n_bits, auc)

    fig, ax1 = plt.subplots()
    ax1.plot(bit_range, aucs, 'c-', label="Hashing")
    ax1.plot(bit_range, [tfidf_auc] * len(bit_range), 'c--', label="Tfidf")
    ax1.set_xlabel('Hashed Features (log_2)', fontsize=16)
    # Make the y-axis label and tick labels match the line color.
    ax1.set_ylabel('AUC', color='c', fontsize=16)
    ax1.tick_params(axis='y', labelcolor='c')
    ax1.legend(loc="lower right")

    ax2 = ax1.twinx()
    ax2.plot(bit_range, nzs, 'k-')
    ax2.plot(bit_range, [n_tfidf_features] * len(bit_range), 'k--')
    ax2.set_ylabel('Non-zero Weights', color='k', fontsize=16)
    ax2.tick_params(axis='y', labelcolor='k')

    fig.subplots_adjust(right=0.8)
    os.makedirs(FIGURES_DIR, exist_ok=True)
    # savefig's resolution keyword is lowercase ``dpi``.
    fig.savefig(os.path.join(FIGURES_DIR, "hashed_features_auc_weights.png"),
                dpi=200)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment