Last active
July 8, 2020 15:07
-
-
Save andreasvc/e363fb134e5a38add68acb3558517ba0 to your computer and use it in GitHub Desktop.
A baseline Bag-of-Words text classification
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""A baseline Bag-of-Words text classification. | |
Usage: python3 classify.py <train.txt> <test.txt> [--svm] [--tfidf] [--bigrams] | |
train.txt and test.txt should contain one "document" per line, | |
first token should be the label. | |
The default is to use regularized Logistic Regression and relative frequencies. | |
Pass --svm to use Linear SVM instead. | |
Pass --tfidf to use tf-idf instead of relative frequencies. | |
Pass --bigrams to use bigrams instead of unigrams. | |
""" | |
import sys | |
import getopt | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
from sklearn.linear_model import LogisticRegressionCV | |
from sklearn.svm import LinearSVC | |
from sklearn.pipeline import Pipeline | |
from sklearn.metrics import classification_report, confusion_matrix | |
def readcorpus(corpusfile):
    """Read a corpus file with one labeled document per line.

    The first whitespace-delimited token of each line is the label; the rest
    of the line (leading/trailing whitespace removed) is the document text.

    Blank (whitespace-only) lines, e.g. a trailing empty line at the end of
    the file, are skipped rather than crashing the tuple unpacking. A line
    consisting of only a label yields an empty document.

    :param corpusfile: path to a UTF-8 encoded text file.
    :returns: a pair ``(documents, labels)`` of equal-length lists of str,
        in file order.
    """
    documents = []
    labels = []
    with open(corpusfile, encoding='utf8') as inp:
        for line in inp:
            # split(None, 1) skips leading whitespace, splits off the first
            # token, and drops the whitespace run that follows it.
            fields = line.split(None, 1)
            if not fields:
                continue  # blank line: nothing to read
            label = fields[0]
            # strip() removes the trailing newline / whitespace of the text.
            doc = fields[1].strip() if len(fields) > 1 else ''
            documents.append(doc)
            labels.append(label)
    return documents, labels
def _usage_error(message):
    """Print *message* and the usage text to stderr; return exit status 2."""
    print(message, file=sys.stderr)
    print(__doc__, file=sys.stderr)
    return 2


def _build_pipeline(use_svm, use_tfidf, use_bigrams):
    """Return an unfitted vectorizer + classifier Pipeline for the options."""
    # Bag-of-Words extraction. Rows are L2-normalized (TfidfVectorizer's
    # default norm, made explicit here): without --tfidf (use_idf=False) the
    # features are length-normalized term counts, which are proportional to,
    # but not equal to, relative frequencies.
    vec = TfidfVectorizer(
        use_idf=use_tfidf,
        ngram_range=(2, 2) if use_bigrams else (1, 1),
        lowercase=True,
        max_features=100000,
        binary=False,
        norm='l2')
    if use_svm:
        # With LinearSVC you have to specify the regularization parameter C.
        clf = LinearSVC(C=1.0)
    else:
        # LogisticRegressionCV automatically picks the best regularization
        # parameter using cross validation.
        # NOTE(review): only this branch uses class_weight='balanced'; the
        # SVM branch does not -- confirm this asymmetry is intended.
        clf = LogisticRegressionCV(
            cv=3,
            class_weight='balanced',
            max_iter=100)
    # combine the vectorizer with a classifier
    return Pipeline([
        ('vec', vec),
        ('clf', clf)])


def main(argv=None):
    """Train a Bag-of-Words classifier and evaluate it on a test set.

    Parses the command line (see the module docstring), reads the train and
    test corpora with ``readcorpus``, fits a vectorizer + classifier pipeline
    on the training data and prints a confusion matrix and a per-class
    classification report for the test data.

    :param argv: command line arguments without the program name; defaults
        to ``sys.argv[1:]``.
    :returns: 0 on success; 2 on a command line usage error, in which case
        the error and the usage text are printed to stderr.
    """
    if argv is None:
        argv = sys.argv[1:]
    # Command line interface
    try:
        opts, args = getopt.gnu_getopt(argv, '', ['svm', 'tfidf', 'bigrams'])
    except getopt.GetoptError as err:
        return _usage_error(err)
    if len(args) != 2:
        return _usage_error(
            'expected exactly two arguments <train.txt> <test.txt>, got %d'
            % len(args))
    opts = dict(opts)
    train, test = args
    # read train and test corpus
    Xtrain, Ytrain = readcorpus(train)
    Xtest, Ytest = readcorpus(test)
    classifier = _build_pipeline(
        use_svm='--svm' in opts,
        use_tfidf='--tfidf' in opts,
        use_bigrams='--bigrams' in opts)
    # train the classifier
    classifier.fit(Xtrain, Ytrain)
    # make predictions on test set
    Yguess = classifier.predict(Xtest)
    # evaluate; print the label order so the matrix rows/columns are readable
    # (this is the same sorted order confusion_matrix uses by default).
    labels = sorted(set(Ytest) | set(Yguess))
    print('confusion matrix (rows: true label, columns: predicted label):')
    print('labels:', ' '.join(labels))
    print(confusion_matrix(Ytest, Yguess, labels=labels))
    print(classification_report(Ytest, Yguess))
    return 0


if __name__ == '__main__':
    # Propagate main()'s status so usage errors yield a nonzero exit code.
    sys.exit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment