Skip to content

Instantly share code, notes, and snippets.

@vijayanandrp
Last active October 31, 2017 03:18
Show Gist options
  • Save vijayanandrp/38f2c23186e13babde4d912f74392c8c to your computer and use it in GitHub Desktop.
Save vijayanandrp/38f2c23186e13babde4d912f74392c8c to your computer and use it in GitHub Desktop.
Topic classification with linear classifiers (SVM, logistic regression, and others) trained by stochastic gradient descent (SGD) — a simple example.
#!/usr/bin/env python3
# encoding: utf-8
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn import metrics
import numpy as np
from pprint import pprint
# Newsgroup categories to load; only these four are used for training and testing.
CATEGORIES = ['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']

# Hyper-parameter grid searched in step 4. Keys follow the Pipeline's
# "<step name>__<parameter>" convention to reach into each named step.
PARAM_GRID = {'vect__ngram_range': [(1, 1), (1, 2)],
              'tfidf__use_idf': (True, False),
              'clf__alpha': (1e-2, 1e-3)}


def accuracy_percent(predicted, expected):
    """Return the percentage of positions where *predicted* equals *expected*.

    Args:
        predicted: Sequence or array of predicted labels.
        expected: Sequence or array of true labels, same shape as *predicted*.

    Returns:
        Accuracy as a float in [0, 100].

    Raises:
        ValueError: If the inputs are empty or differ in shape; a bare
            ``np.mean`` would otherwise yield ``nan`` or fail obscurely.
    """
    predicted = np.asarray(predicted)
    expected = np.asarray(expected)
    if predicted.shape != expected.shape or predicted.size == 0:
        raise ValueError('predicted and expected must be non-empty and the same shape')
    return float(np.mean(predicted == expected)) * 100


def build_pipeline():
    """Build the text-classification pipeline.

    Steps: ``CountVectorizer`` (token counts) -> ``TfidfTransformer``
    (TF-IDF re-weighting) -> ``SGDClassifier`` with hinge loss and L2
    penalty, i.e. a linear SVM trained by stochastic gradient descent.
    """
    return Pipeline([('vect', CountVectorizer()),
                     ('tfidf', TfidfTransformer()),
                     ('clf', SGDClassifier(loss='hinge', penalty='l2',
                                           alpha=1e-3, random_state=42,
                                           max_iter=5, tol=None))])


def print_evaluation(predicted, test_set):
    """Print the per-class classification report and the confusion matrix.

    Args:
        predicted: Labels predicted for ``test_set.data``.
        test_set: Dataset bunch exposing ``target`` and ``target_names``.
    """
    print('Metrics classification report ')
    print(metrics.classification_report(test_set.target, predicted,
                                        target_names=test_set.target_names))
    print('Metric Confusion matrix')
    print(metrics.confusion_matrix(test_set.target, predicted))


def main():
    """Train, tune and evaluate the classifier on the selected newsgroups."""
    # 1. Load the train/test splits of the selected categories.
    train_set = fetch_20newsgroups(categories=CATEGORIES, shuffle=True,
                                   random_state=42, subset='train')
    test_set = fetch_20newsgroups(categories=CATEGORIES, shuffle=True,
                                  random_state=42, subset='test')

    # 2-3. Convert text to a numeric (TF-IDF) matrix and train, all in one Pipeline.
    text_clf = build_pipeline()
    text_clf.fit(train_set.data, train_set.target)
    predicted = text_clf.predict(test_set.data)
    baseline_accuracy = accuracy_percent(predicted, test_set.target)
    print('Accuracy of SGDClassifier (support vector machine - SVM) - {}'
          .format(baseline_accuracy))

    # 4. Tune feature-extraction and classifier parameters with a grid search.
    # n_jobs=-1 runs the search on all available cores; GridSearchCV fits
    # clones of text_clf, so the fitted baseline above is not modified.
    gs_clf = GridSearchCV(text_clf, PARAM_GRID, n_jobs=-1)
    gs_clf.fit(train_set.data, train_set.target)
    gs_predicted = gs_clf.predict(test_set.data)
    tuned_accuracy = accuracy_percent(gs_predicted, test_set.target)
    print('Accuracy (after tuning) of SGDClassifier (support vector machine - SVM) - {}'
          .format(tuned_accuracy))
    print('Grid Search best score -')
    # Mean cross-validated score on the training data, not test-set accuracy.
    print(gs_clf.best_score_)
    print('Grid Search best parameters -')
    pprint(gs_clf.best_params_)

    # Evaluate the tuned model: the report is printed after tuning, so it
    # must use gs_predicted rather than the untuned baseline's predictions.
    print_evaluation(gs_predicted, test_set)


if __name__ == "__main__":
    main()
@vijayanandrp
Copy link
Author

```text
Accuracy of SGDClassifier (support vector machine - SVM) - 91.27829560585884
Accuracy (after tuning) of SGDClassifier (support vector machine - SVM) - 91.27829560585884
Grid Search best score -
0.965440850687
Grid Search best parameters -
{'clf__alpha': 0.001, 'tfidf__use_idf': True, 'vect__ngram_range': (1, 1)}
Metrics classification report
                        precision    recall  f1-score   support

           alt.atheism       0.95      0.81      0.87       319
         comp.graphics       0.88      0.97      0.92       389
               sci.med       0.94      0.90      0.92       396
soc.religion.christian       0.90      0.95      0.93       398

           avg / total       0.92      0.91      0.91      1502

Metric Confusion matrix
[[258  11  15  35]
 [  4 379   3   3]
 [  5  33 355   3]
 [  5  10   4 379]]
```

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment