-
-
Save ixtel/b39781cf3a4700ac9f7c to your computer and use it in GitHub Desktop.
A quick category classifier with NLTK
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import nltk.classify.util | |
from itertools import chain, imap | |
from string import punctuation | |
from nltk.corpus import stopwords | |
from nltk import wordpunct_tokenize | |
from nltk.classify import NaiveBayesClassifier | |
# Lazily-populated cache for _english_stopwords(); None means "not loaded yet".
_STOPWORDS = None


def _english_stopwords():
    """Return NLTK's English stopwords as a frozenset, loading them only once."""
    global _STOPWORDS
    if _STOPWORDS is None:
        _STOPWORDS = frozenset(stopwords.words('english'))
    return _STOPWORDS


def tokenize(s):
    """
    Yield the content words of *s*.

    *s* is split with NLTK's ``wordpunct_tokenize``. A token is dropped when
    it is an English stopword (compared case-insensitively, so "The" is
    dropped as well as "the") or when it consists entirely of ASCII
    punctuation characters. Surviving tokens are yielded unchanged.
    """
    # Build the lookup set once instead of calling stopwords.words() and
    # scanning a list for every single token.
    stop = _english_stopwords()
    for word in wordpunct_tokenize(s):
        if word.lower() in stop:
            continue
        # `word in punctuation` would be a *substring* test against the
        # punctuation string (dropping "()" but keeping "))"); stripping all
        # punctuation characters and checking for emptiness drops any run of
        # punctuation regardless of order.
        if not word.strip(punctuation):
            continue
        yield word
def wordfeature(words):
    """
    Build a bag-of-words feature dict in which every word of *words* maps to True.

    Repeated words collapse into a single key, so the result records
    presence only, not frequency.
    """
    return dict.fromkeys(words, True)
class CategoryModel(object):
    """
    Naive Bayes product-category classifier trained from a tab-separated file.

    Expects a dataset of the following type (one record per line, category
    and product name separated by a tab):

        category<TAB>product name
        category<TAB>product name

    and will create a bag-of-words feature set based on these categories.
    Blank lines are skipped; any further tabs are kept as part of the name.
    """

    def __init__(self, dataset=None):
        # Path to the tab-separated training file.
        self.dataset = dataset
        # Lazily computed caches; None means "not computed yet".
        self._features = None
        self._classifier = None

    @property
    def features(self):
        """
        Return the list of ``(bag_of_words, category)`` pairs read from
        ``self.dataset``, parsing the file only on first access.

        Raises ValueError (naming the file and line number) when a non-blank
        line contains no tab separator.
        """
        # `is None` rather than truthiness: an empty dataset is a valid,
        # cacheable result and must not trigger a re-read on every access.
        if self._features is None:
            features = []
            with open(self.dataset, 'r') as data:
                for lineno, line in enumerate(data, 1):
                    line = line.strip()
                    if not line:
                        # Tolerate blank lines, e.g. a trailing newline at EOF.
                        continue
                    try:
                        # maxsplit=1 so a stray tab inside the product name
                        # does not break the two-way unpack.
                        cat, name = line.split('\t', 1)
                    except ValueError:
                        raise ValueError(
                            "%s:%d: expected 'category<TAB>product name', got %r"
                            % (self.dataset, lineno, line))
                    features.append((wordfeature(tokenize(name)), cat))
            # Cache only after a complete, successful parse.
            self._features = features
        return self._features

    @property
    def classifier(self):
        """Return a NaiveBayesClassifier trained on all features (built once)."""
        if self._classifier is None:
            self._classifier = NaiveBayesClassifier.train(self.features)
        return self._classifier

    def _split(self):
        """
        Return ``(training, testing)``: the first three quarters of the
        features for training, the remainder for testing.

        NOTE: the split is positional, not shuffled, so a dataset sorted by
        category will give a skewed evaluation.
        """
        features = self.features
        # Floor division keeps the cutoff an int on both Python 2 and 3
        # (`* 3/4` is a float on Python 3 and cannot be used as a slice index).
        cutoff = len(features) * 3 // 4
        return features[:cutoff], features[cutoff:]

    @property
    def accuracy(self):
        """
        Train a separate classifier on the training split and return its
        accuracy on the held-out split. Recomputed on every access.
        """
        training, testing = self._split()
        classifier = NaiveBayesClassifier.train(training)
        return nltk.classify.util.accuracy(classifier, testing)

    def show_most_informative_features(self):
        """Delegate to the trained classifier's show_most_informative_features()."""
        return self.classifier.show_most_informative_features()

    def analyze(self, name):
        """Return the category predicted for the product *name*."""
        return self.classifier.classify(wordfeature(tokenize(name)))

    def __str__(self):
        # Report the sizes of the split actually used by `accuracy`, so the
        # two counts always add up to the dataset size.
        training, testing = self._split()
        return ("Trained on %d instances, tested on %d instances\n"
                "Classifier Accuracy: %0.3f\n"
                % (len(training), len(testing), self.accuracy))
if __name__ == '__main__':
    import sys

    # An optional first command-line argument overrides the default dataset.
    model = CategoryModel(sys.argv[1] if len(sys.argv) > 1 else "prods.tsv")
    # print(...) with a single argument behaves the same on Python 2 and 3.
    print(model)
    model.show_most_informative_features()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment