Created
March 5, 2016 17:51
-
-
Save panggi/50b184e079d8fc5d7ad3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Imports | |
| import pandas as pd | |
| import os, sys, getopt, cPickle, csv, sklearn | |
| from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer | |
| from sklearn.naive_bayes import MultinomialNB | |
| from sklearn.svm import SVC, LinearSVC | |
| from sklearn.metrics import classification_report, f1_score, accuracy_score, confusion_matrix | |
| from sklearn.pipeline import Pipeline | |
| from sklearn.grid_search import GridSearchCV | |
| from sklearn.cross_validation import StratifiedKFold, cross_val_score, train_test_split | |
| from textblob import TextBlob | |
| # Dataset | |
| MESSAGES = pd.read_csv('./data/SMSSpamCollection', sep='\t', quoting=csv.QUOTE_NONE, names=["label", "message"]) | |
| # Preprocessing | |
| def tokens(message): | |
| message = unicode(message, 'utf8') | |
| return TextBlob(message).words | |
| def lemmas(message): | |
| message = unicode(message, 'utf8').lower() | |
| words = TextBlob(message).words | |
| return [word.lemma for word in words] | |
| # Training | |
| def train_multinomial_nb(messages): | |
| # split dataset for cross validation | |
| msg_train, msg_test, label_train, label_test = train_test_split(messages['message'], messages['label'], test_size=0.2) | |
| # create pipeline | |
| pipeline = Pipeline([('bow', CountVectorizer(analyzer=lemmas)),('tfidf', TfidfTransformer()),('classifier', MultinomialNB())]) | |
| # pipeline parameters to automatically explore and tune | |
| params = { | |
| 'tfidf__use_idf': (True, False), | |
| 'bow__analyzer': (lemmas, tokens), | |
| } | |
| grid = GridSearchCV( | |
| pipeline, | |
| params, # parameters to tune via cross validation | |
| refit=True, # fit using all data, on the best detected classifier | |
| n_jobs=-1, | |
| scoring='accuracy', | |
| cv=StratifiedKFold(label_train, n_folds=5), | |
| ) | |
| # train | |
| nb_detector = grid.fit(msg_train, label_train) | |
| print "" | |
| predictions = nb_detector.predict(msg_test) | |
| print ":: Confusion Matrix" | |
| print "" | |
| print confusion_matrix(label_test, predictions) | |
| print "" | |
| print ":: Classification Report" | |
| print "" | |
| print classification_report(label_test, predictions) | |
| # save model to pickle file | |
| file_name = 'models/sms_spam_nb_model.pkl' | |
| with open(file_name, 'wb') as fout: | |
| cPickle.dump(nb_detector, fout) | |
| print 'model written to: ' + file_name | |
| def train_svm(messages): | |
| # split dataset for cross validation | |
| msg_train, msg_test, label_train, label_test = train_test_split(messages['message'], messages['label'], test_size=0.2) | |
| # create pipeline | |
| pipeline = Pipeline([('bow', CountVectorizer(analyzer=lemmas)),('tfidf', TfidfTransformer()),('classifier', SVC())]) | |
| # pipeline parameters to automatically explore and tune | |
| params = [ | |
| {'classifier__C': [1, 10, 100, 1000], 'classifier__kernel': ['linear']}, | |
| {'classifier__C': [1, 10, 100, 1000], 'classifier__gamma': [0.001, 0.0001], 'classifier__kernel': ['rbf']}, | |
| ] | |
| grid = GridSearchCV( | |
| pipeline, | |
| param_grid=params, # parameters to tune via cross validation | |
| refit=True, # fit using all data, on the best detected classifier | |
| n_jobs=-1, | |
| scoring='accuracy', | |
| cv=StratifiedKFold(label_train, n_folds=5), | |
| ) | |
| # train | |
| svm_detector = grid.fit(msg_train, label_train) | |
| print "" | |
| print ":: Confusion Matrix" | |
| print "" | |
| print confusion_matrix(label_test, svm_detector.predict(msg_test)) | |
| print "" | |
| print ":: Classification Report" | |
| print "" | |
| print classification_report(label_test, svm_detector.predict(msg_test)) | |
| # save model to pickle file | |
| file_name = 'models/sms_spam_svm_model.pkl' | |
| with open(file_name, 'wb') as fout: | |
| cPickle.dump(svm_detector, fout) | |
| print 'model written to: ' + file_name | |
| def main(argv): | |
| # check if models exist, if not run training | |
| if(os.path.isfile('models/sms_spam_nb_model.pkl') == False): | |
| print "" | |
| print "Creating Naive Bayes Model....." | |
| train_multinomial_nb(MESSAGES) | |
| if(os.path.isfile('models/sms_spam_svm_model.pkl') == False): | |
| print "" | |
| print "Creating SVM Model....." | |
| train_svm(MESSAGES) | |
| inputmessage = '' | |
| try: | |
| opts, args = getopt.getopt(argv,"hm:",["message="]) | |
| except getopt.GetoptError: | |
| print 'sms-spam-detector.py -m <message string>' | |
| sys.exit(2) | |
| for opt, arg in opts: | |
| if opt == '-h': | |
| print 'sms-spam-detector.py -m <message string>' | |
| sys.exit() | |
| elif opt in ("-m", "--message"): | |
| prediction = predict(arg) | |
| print 'This message is predicted by', prediction | |
| def predict(message): | |
| nb_detector = cPickle.load(open('models/sms_spam_nb_model.pkl')) | |
| svm_detector = cPickle.load(open('models/sms_spam_svm_model.pkl')) | |
| nb_predict = nb_detector.predict([message])[0] | |
| svm_predict = svm_detector.predict([message])[0] | |
| return 'SVM as ' + svm_predict + ' and Naive Bayes as ' + nb_predict | |
| if __name__ == "__main__": | |
| main(sys.argv[1:]) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment