Last active
October 19, 2015 12:57
-
-
Save bbengfort/9669026 to your computer and use it in GitHub Desktop.
Classifier handler for wrapping a pickled classifier object
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import os
import sys
import time
import pickle
import argparse
import functools
import operator
from string import punctuation

import unicodecsv as csv
from nltk.corpus import stopwords
from nltk import wordpunct_tokenize
def timeit(func):
    """
    Decorator that times a call to the wrapped function.

    The decorated function returns a ``(result, seconds)`` tuple instead
    of the bare result, where ``seconds`` is the wall-clock duration of
    the call as a float.
    """
    @functools.wraps(func)  # preserve the wrapped function's name/docstring
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        finit = time.time()
        delta = finit - start
        return result, delta
    return wrapper
class OberonClassifier(object):
    """
    Performs classification of products using a classifier that is loaded
    via a pickle at runtime. This classifier can be of any type, but we
    expect the Maximum Entropy classifier trained from the Oberon corpus.
    """

    def __init__(self, pickle_path):
        """
        Pass in the path of the pickled classifier object.

        NOTE(security): pickle.load can execute arbitrary code — only load
        classifier pickles from trusted sources.
        """
        with open(pickle_path, 'rb') as pkl:
            self._classifier = pickle.load(pkl)
        # frozenset gives O(1) membership tests in tokenize (the nltk
        # stopword list is all lowercase).
        self.stopwords = frozenset(stopwords.words('english'))

    def ispunct(self, token):
        """
        Determines if a token is punctuation or not.
        Returns True for the empty string (vacuously all-punctuation).
        """
        return all(char in punctuation for char in token)

    def tokenize(self, text):
        """
        Tokenizes input, yielding lowercase tokens with stopwords and
        punctuation removed.
        """
        for token in wordpunct_tokenize(text):
            # Fix: lowercase BEFORE the stopword test, otherwise
            # capitalized stopwords like "The" slip through the filter.
            token = token.lower()
            if token in self.stopwords:
                continue
            if self.ispunct(token):
                continue
            yield token

    def unigram_features(self, text):
        """
        Returns unigram (bag-of-words presence) features from the text.
        """
        return dict((token, True) for token in self.tokenize(text))

    def classify(self, text):
        """
        Classifies the text using the internal classifier. Returns a list
        of (label, probability) pairs sorted by descending probability,
        keeping only labels with probability above 1%.
        """
        features = self.unigram_features(text)
        probdist = self._classifier.prob_classify(features)
        labels = [(label, probdist.prob(label))
                  for label in probdist.samples()
                  if probdist.prob(label) > 0.01]
        return sorted(labels, key=operator.itemgetter(1), reverse=True)
class BatchClassifier(OberonClassifier): | |
""" | |
Accepts as input a CSV with a field called "name", then classifies the | |
entire CSV and outputs a new csv with the old fields plus the label | |
and probability of the label (using maximum likelihood estimates). | |
""" | |
def __init__(self, pickle_path, input_csv): | |
""" | |
Pass in the path of the pickled classifier and the input csv | |
""" | |
super(BatchClassifier, self).__init__(pickle_path) | |
self._input_file = input_csv | |
self.reader = csv.DictReader(self._input_file) | |
self._fieldnames = reader.fieldnames | |
self._fieldnames.extend(['label', 'probability']) | |
if 'name' not in self._fieldnames: | |
raise Exception("Need a name field to classify in input CSV") | |
def __iter__(self): | |
""" | |
Basically a CSV reader that also performs classification. | |
""" | |
for row in self.reader: | |
labels = self.classify(row['name']) | |
row['label'] = labels[0][0] | |
row['probability'] = labels[0][1] | |
yield row | |
def batch_classify(self, output_file, verbose=1): | |
writer = csv.DictWriter(output_file, self._fieldnames) | |
for idx, row in enumerate(self): | |
writer.writerow(row) | |
if verbose > 0: | |
if idx % 10000 == 0: | |
print "%ik rows categorized" % (int(idx/1000)) | |
return idx+1 | |
def main(*argv): | |
# Command Line arguments | |
parse_config = { | |
"description": "Classifies input text or batch classifies a CSV file.", | |
"epilog": "You must specify the path to the pickled classifier, or it looks in the cwd.", | |
} | |
parser = argparse.ArgumentParser(**parse_config) | |
# Outfile argument | |
parser.add_argument('--write', dest='outfile', type=argparse.FileType('w'), | |
nargs='?', default=sys.stdout, metavar='FILE', | |
help='write results out to a file.') | |
# Infile argument | |
parser.add_argument('--batch', type=argparse.FileType('r'), nargs='?', | |
metavar="CSV", help='csv of names to batch classify.') | |
# Verbosity argument | |
parser.add_argument('--verbosity', type=int, choices=range(0,4), default=1, | |
metavar="INT", help="How verbose is output?") | |
# Classifier Pickle Path | |
parser.add_argument('--pickle', default='maxent.classifier.pickle', metavar="PATH", | |
help="Specify the path to the pickled classifier.") | |
# Text to classify if not batch | |
parser.add_argument('text', nargs="*", help="text to classify, surrounded by quotes.") | |
# Parse arguments | |
options = parser.parse_args() | |
@timeit | |
def batch_handler(options): | |
classifier = BatchClassifier(options.pickle, options.batch) | |
rows = classifier.batch_classify(options.outfile, options.verbosity) | |
return "%i rows classified" % rows | |
@timeit | |
def text_handler(options): | |
output = [] | |
classifier = OberonClassifier(options.pickle) | |
for t in options.text: | |
output.append('"%s" is classified as:' % t) | |
for cls in classifier.classify(t): | |
output.append(" %s (%0.4f)" % cls) | |
return "\n".join(output) | |
if options.batch: | |
msg, time = batch_handler(options) | |
else: | |
msg, time = text_handler(options) | |
print msg | |
print "(%0.3f seconds)" % time | |
if __name__ == '__main__': | |
main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
PyYAML==3.10 | |
nltk==2.0.4 | |
numpy==1.8.0 | |
python-dateutil==2.2 | |
six==1.4.1 | |
unicodecsv==0.9.4 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment