Skip to content

Instantly share code, notes, and snippets.

@vijayanandrp
Last active October 31, 2017 03:38
Show Gist options
  • Save vijayanandrp/8a1c6af2477d8c9d6c1889370452e74f to your computer and use it in GitHub Desktop.
Save vijayanandrp/8a1c6af2477d8c9d6c1889370452e74f to your computer and use it in GitHub Desktop.
Topic classification using naive Bayes, with a simple example.
#!/usr/bin/env python3
# encoding: utf-8
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.naive_bayes import MultinomialNB
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn import metrics
import numpy as np
from pprint import pprint
# Restrict the corpus to four of the newsgroup categories.
categories = ['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']

# 1. Load the data set.
# Both splits use the same category filter and are shuffled; only the subset
# name and the shuffle seed differ (55 for training, 42 for testing).
_shared_fetch_options = dict(categories=categories, shuffle=True)
train_set = fetch_20newsgroups(subset='train', random_state=55, **_shared_fetch_options)
test_set = fetch_20newsgroups(subset='test', random_state=42, **_shared_fetch_options)
# 2. Convert text to numbers (a document/term matrix).
# 3. Train the classifier. The three stages are chained in a single Pipeline,
#    so one fit() call runs the vectorizer, the TF-IDF transformer and the
#    naive Bayes model in sequence.
text_clf = Pipeline(steps=[
    ('vect', CountVectorizer()),
    ('tfidf', TfidfTransformer()),
    ('clf', MultinomialNB()),
])
text_clf.fit(train_set.data, train_set.target)

# Baseline evaluation: the share of test documents whose predicted label
# equals the true label, printed as a percentage.
docs_test = test_set.data
predicted = text_clf.predict(docs_test)
baseline_matches = predicted == test_set.target
accuracy = np.mean(baseline_matches)
print('\nAccuracy of MultinomialNB (naive Bayes) - {}'.format(accuracy * 100))
# 4. Auto-tune the training parameters with a grid search that covers both
#    feature extraction and the classifier.
# Each key addresses a pipeline step as '<step name>__<parameter>'; the grid
# below spans 2 x 2 x 2 x 2 = 16 parameter combinations.
parameters = {
    'vect__ngram_range': [(1, 1), (1, 2)],
    'tfidf__use_idf': (True, False),
    'clf__fit_prior': (True, False),
    'clf__alpha': (0.5, 1.0),
}
gs_clf = GridSearchCV(estimator=text_clf, param_grid=parameters, n_jobs=-1)
gs_clf.fit(train_set.data, train_set.target)

# Score the tuned search object on the same test documents as the baseline.
gs_predicted = gs_clf.predict(docs_test)
tuned_matches = gs_predicted == test_set.target
accuracy = np.mean(tuned_matches)
print('\nAccuracy (after tuning) of MultinomialNB (naive Bayes) - {}'.format(accuracy * 100))

# Report what the search selected.
print('\nGrid Search best score -', gs_clf.best_score_, sep='\n')
print('\nGrid Search best parameters -')
pprint(gs_clf.best_params_)
# Detailed evaluation on the test split: a classification report labelled
# with the newsgroup names, followed by the confusion matrix.
# NOTE(review): both are computed from `predicted` (the untuned pipeline's
# output), not `gs_predicted` -- confirm that reporting the baseline model
# here is intended.
print('\nMetrics classification report ')
report = metrics.classification_report(
    test_set.target,
    predicted,
    target_names=test_set.target_names,
)
print(report)

print('\nMetric Confusion matrix')
confusion = metrics.confusion_matrix(test_set.target, predicted)
print(confusion)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment