Skip to content

Instantly share code, notes, and snippets.

@ixtel
Forked from bbengfort/classify.py
Created October 19, 2015 12:57
Show Gist options
  • Save ixtel/33f3ddfa76955047fc7a to your computer and use it in GitHub Desktop.
Save ixtel/33f3ddfa76955047fc7a to your computer and use it in GitHub Desktop.
Classifier handler for wrapping a pickled classifier object
#!/usr/bin/env python
import argparse
import functools
import operator
import os
import pickle
import sys
import time
from string import punctuation

import unicodecsv as csv
from nltk import wordpunct_tokenize
from nltk.corpus import stopwords
def timeit(func):
    """
    Decorator that measures the wall-clock duration of each call to func.

    The decorated function returns a ``(result, seconds)`` tuple instead of
    ``result`` alone, where ``seconds`` is the ``time.time()`` difference
    taken around the call. The wrapper keeps func's name and docstring.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        return result, time.time() - start
    return wrapper
class OberonClassifier(object):
    """
    Performs classification of products using a classifier that is loaded
    via a pickle at runtime. This classifier can be of any type, but we
    expect the Maximum Entropy classifier trained from the Oberon corpus.
    """

    def __init__(self, pickle_path):
        """
        Load the pickled classifier at pickle_path and the NLTK English
        stopword list.

        SECURITY: unpickling can execute arbitrary code; only load
        classifier files that come from a trusted source.
        """
        with open(pickle_path, 'rb') as pkl:
            self._classifier = pickle.load(pkl)
        # A set makes the per-token stopword test O(1) instead of a linear
        # scan of the list.
        self.stopwords = set(stopwords.words('english'))

    def ispunct(self, token):
        """
        Return True if every character of token is in string.punctuation.
        An empty token is vacuously True.
        """
        return all(char in punctuation for char in token)

    def _filter_tokens(self, tokens):
        """
        Lowercase tokens and drop stopwords and pure-punctuation tokens.
        """
        for token in tokens:
            # Lowercase first: the stopword set is lowercase, so checking
            # before lowering would let "The", "And", ... through.
            token = token.lower()
            if token in self.stopwords or self.ispunct(token):
                continue
            yield token

    def tokenize(self, text):
        """
        Tokenizes input removing punctuation, stopwords, and case.
        Returns a generator of lowercase tokens.
        """
        return self._filter_tokens(wordpunct_tokenize(text))

    def unigram_features(self, text):
        """
        Returns unigram features from the text: a dict mapping each kept
        token to True.
        """
        return {token: True for token in self.tokenize(text)}

    def classify(self, text, threshold=0.01):
        """
        Classify text with the wrapped classifier.

        Returns a list of (label, probability) pairs, sorted by descending
        probability. Only labels whose probability is strictly greater than
        threshold are kept, so the list may be empty.
        """
        features = self.unigram_features(text)
        probdist = self._classifier.prob_classify(features)
        # Look each probability up once rather than twice per label.
        scored = ((label, probdist.prob(label))
                  for label in probdist.samples())
        labels = [pair for pair in scored if pair[1] > threshold]
        return sorted(labels, key=operator.itemgetter(1), reverse=True)
class BatchClassifier(OberonClassifier):
    """
    Accepts as input a CSV with a field called "name", then classifies the
    entire CSV and outputs a new csv with the old fields plus the label
    and probability of the label (using maximum likelihood estimates).
    """

    def __init__(self, pickle_path, input_csv):
        """
        Pass in the path of the pickled classifier and the input csv (an
        open file object).

        Raises ValueError (a subclass of Exception) if the CSV has no
        header row or no "name" column.
        """
        super(BatchClassifier, self).__init__(pickle_path)
        self._input_file = input_csv
        self.reader = csv.DictReader(self._input_file)
        # Copy the header. Extending reader.fieldnames in place could mutate
        # the reader's own header list. The list is None for an empty file.
        fieldnames = list(self.reader.fieldnames or [])
        if 'name' not in fieldnames:
            raise ValueError("Need a name field to classify in input CSV")
        self._fieldnames = fieldnames + ['label', 'probability']

    def __iter__(self):
        """
        Basically a CSV reader that also performs classification.

        Yields each input row with 'label' and 'probability' set from the
        most probable label of the row's 'name'. When no label passes the
        classifier's threshold, the label is None and the probability 0.0.
        """
        for row in self.reader:
            # A short row can leave 'name' as None; classify an empty
            # string instead.
            labels = self.classify(row['name'] or '')
            row['label'], row['probability'] = (
                labels[0] if labels else (None, 0.0))
            yield row

    def batch_classify(self, output_file, verbose=1):
        """
        Classify every input row and write it to output_file as CSV,
        preceded by a header row. Returns the number of data rows written.

        When verbose > 0, progress is reported on stderr every 10,000 rows.
        Using stderr keeps it out of CSV written to stdout.
        """
        writer = csv.DictWriter(output_file, self._fieldnames)
        writer.writeheader()
        count = 0
        for row in self:
            writer.writerow(row)
            count += 1
            if verbose > 0 and count % 10000 == 0:
                sys.stderr.write("%ik rows categorized\n" % (count // 1000))
        return count
def main(*argv):
    """
    Command-line entry point.

    argv optionally supplies the argument strings, excluding the program
    name. When it is empty, argparse reads sys.argv[1:].

    With --batch, the CSV is classified row by row and written to --write
    (default stdout). Otherwise each positional text is classified and the
    results are written to --write. Status and timing lines go to stderr so
    they never mix with the results.
    """
    # Command Line arguments
    parse_config = {
        "description": "Classifies input text or batch classifies a CSV file.",
        "epilog": "You must specify the path to the pickled classifier, or it looks in the cwd.",
    }
    parser = argparse.ArgumentParser(**parse_config)

    # Outfile argument. No nargs='?': a bare "--write" would otherwise
    # yield None instead of a file.
    parser.add_argument('--write', dest='outfile', type=argparse.FileType('w'),
                        default=sys.stdout, metavar='FILE',
                        help='write results out to a file.')
    # Infile argument
    parser.add_argument('--batch', type=argparse.FileType('r'),
                        metavar="CSV", help='csv of names to batch classify.')
    # Verbosity argument
    parser.add_argument('--verbosity', type=int, choices=range(0, 4), default=1,
                        metavar="INT", help="How verbose is output?")
    # Classifier Pickle Path
    parser.add_argument('--pickle', default='maxent.classifier.pickle', metavar="PATH",
                        help="Specify the path to the pickled classifier.")
    # Text to classify if not batch
    parser.add_argument('text', nargs="*", help="text to classify, surrounded by quotes.")

    # Honour an explicitly supplied argv; fall back to sys.argv otherwise.
    options = parser.parse_args(list(argv) if argv else None)

    if not options.batch and not options.text:
        parser.error("supply text to classify or a CSV via --batch")

    @timeit
    def batch_handler(options):
        classifier = BatchClassifier(options.pickle, options.batch)
        rows = classifier.batch_classify(options.outfile, options.verbosity)
        return "%i rows classified" % rows

    @timeit
    def text_handler(options):
        classifier = OberonClassifier(options.pickle)
        output = []
        for t in options.text:
            output.append('"%s" is classified as:' % t)
            for cls in classifier.classify(t):
                output.append(" %s (%0.4f)" % cls)
        return "\n".join(output)

    # "elapsed" rather than "time" so the time module name isn't shadowed.
    if options.batch:
        msg, elapsed = batch_handler(options)
        # The CSV already went to outfile; the summary is status output.
        sys.stderr.write(msg + "\n")
    else:
        msg, elapsed = text_handler(options)
        options.outfile.write(msg + "\n")
    sys.stderr.write("(%0.3f seconds)\n" % elapsed)


if __name__ == '__main__':
    main()
PyYAML==3.10
nltk==2.0.4
numpy==1.8.0
python-dateutil==2.2
six==1.4.1
unicodecsv==0.9.4
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment