Skip to content

Instantly share code, notes, and snippets.

@ixtel
Forked from bbengfort/classifier.py
Created October 19, 2015 12:57
Show Gist options
  • Save ixtel/b39781cf3a4700ac9f7c to your computer and use it in GitHub Desktop.
Save ixtel/b39781cf3a4700ac9f7c to your computer and use it in GitHub Desktop.
A quick category classifier with NLTK
#!/usr/bin/env python
import nltk.classify.util
from itertools import chain, imap
from string import punctuation
from nltk.corpus import stopwords
from nltk import wordpunct_tokenize
from nltk.classify import NaiveBayesClassifier
def tokenize(s):
    """
    Yield the content-bearing tokens of ``s``.

    ``s`` is split with NLTK's ``wordpunct_tokenize``. Tokens that are English
    stopwords (compared case-insensitively) or that consist solely of
    punctuation characters are dropped. Surviving tokens are yielded
    unchanged, preserving their original case.
    """
    # Fetch the stopword list once per call and turn it into a set:
    # ``stopwords.words`` returns a list, so re-fetching and scanning it for
    # every token would cost O(len(stopwords)) per word.
    stop = set(stopwords.words('english'))
    for word in wordpunct_tokenize(s):
        # The stopword list is lowercase, so compare lowercased tokens
        # ("The" is as much a stopword as "the").
        if word.lower() in stop:
            continue
        # ``word in punctuation`` would be a *substring* test against the
        # punctuation string: "..." or "!!" would pass while "()" would not.
        # Drop any token made up entirely of punctuation characters instead.
        if all(ch in punctuation for ch in word):
            continue
        yield word
def wordfeature(words):
    """
    Turn an iterable of tokens into a bag-of-words feature dict.

    Each distinct token becomes a key mapped to ``True``. Repeated tokens
    collapse into a single entry, so only presence is recorded, not
    frequency.
    """
    return dict.fromkeys(words, True)
class CategoryModel(object):
    """
    Naive Bayes category classifier built from a tab-separated dataset.

    Expects a dataset of the following type, one record per line, with the
    category and the product name separated by a tab:

        category<TAB>product name
        category<TAB>product name

    It creates a bag-of-words feature set based on these categories. Both
    the feature set and the full-dataset classifier are computed lazily on
    first access and cached on the instance.
    """

    def __init__(self, dataset=None):
        # Path of the TSV file read by ``features``.
        self.dataset = dataset
        # Cached list of (feature dict, category) pairs; None until loaded.
        self._features = None
        # Cached classifier trained on all features; None until trained.
        self._classifier = None

    @staticmethod
    def _parse_line(line):
        """
        Split one dataset line into a ``(category, name)`` tuple.

        Returns None for blank lines so the caller can skip them. Only the
        first tab separates the fields, so the name keeps any later tabs.
        Raises ValueError when a non-blank line contains no tab at all.
        """
        line = line.strip()
        if not line:
            return None
        cat, name = line.split('\t', 1)
        return cat, name

    @property
    def features(self):
        """
        List of ``(bag_of_words, category)`` pairs read from ``self.dataset``.

        The file is parsed once and the result cached. The list is built
        locally and stored only after the whole file parses, so a malformed
        line cannot leave a partial list behind that later accesses would
        silently reuse.
        """
        if self._features is None:
            features = []
            with open(self.dataset, 'r') as data:
                for line in data:
                    parsed = self._parse_line(line)
                    if parsed is None:
                        # Tolerate blank lines, e.g. a trailing newline at EOF.
                        continue
                    cat, name = parsed
                    features.append((wordfeature(tokenize(name)), cat))
            self._features = features
        return self._features

    @property
    def classifier(self):
        """NaiveBayesClassifier trained on every instance in the dataset (cached)."""
        if self._classifier is None:
            self._classifier = NaiveBayesClassifier.train(self.features)
        return self._classifier

    def _split_point(self):
        """Return the index separating training (first 3/4) from testing instances."""
        # Floor division keeps the index an int. On Python 3, ``/`` would
        # produce a float and make the slices in ``accuracy`` raise TypeError.
        return len(self.features) * 3 // 4

    @property
    def accuracy(self):
        """
        Hold-out accuracy of a classifier trained on the first 3/4 of the
        instances and scored on the remaining 1/4.

        The split follows file order without shuffling, so the estimate may
        be skewed if the dataset is sorted by category. A fresh classifier is
        trained on every access.
        """
        cutoff = self._split_point()
        training = self.features[:cutoff]
        testing = self.features[cutoff:]
        classifier = NaiveBayesClassifier.train(training)
        return nltk.classify.util.accuracy(classifier, testing)

    def show_most_informative_features(self, n=10):
        """Delegate to NLTK's ``show_most_informative_features`` for the top ``n`` features."""
        return self.classifier.show_most_informative_features(n)

    def analyze(self, name):
        """Return the category predicted for the product ``name``."""
        return self.classifier.classify(wordfeature(tokenize(name)))

    def __str__(self):
        """Summarize the train/test split sizes and the hold-out accuracy."""
        cutoff = self._split_point()
        total = len(self.features)
        output = [
            # Report the actual slice sizes used by ``accuracy``. ``len * 1/4``
            # undercounts the held-out set whenever len is not a multiple of 4.
            "Trained on %d instances, tested on %d instances" % (cutoff, total - cutoff),
            "Classifier Accuracy: %0.3f" % self.accuracy,
            "",
        ]
        return "\n".join(output)
if __name__ == '__main__':
    # Build the model from prods.tsv. Printing it trains a hold-out
    # classifier to report the split sizes and accuracy. The last call then
    # shows the most informative features of the classifier trained on the
    # full dataset.
    model = CategoryModel("prods.tsv")
    # print(...) with a single argument behaves the same on Python 2 and 3,
    # unlike the Python 2-only ``print model`` statement.
    print(model)
    model.show_most_informative_features()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment