import marisa_trie
from sklearn.feature_extraction.text import CountVectorizer, _make_int_array
import numpy as np
import scipy.sparse as sp
from itertools import chain


class MarisaCountVectorizer(CountVectorizer):
""" | |
Extension of Scikit-learn CountVectorizer class using the | |
MARISA-trie python wrapper from https://github.com/kmike/marisa-trie | |
Inspired by http://blog.scrapinghub.com/2014/03/26/optimizing-memory-usage-of-scikit-learn-models-using-succinct-tries/ | |
Difference from the implentation linked above is that here we build the | |
vocabulary directly instead of calling the CountVectorizer method, which | |
still requires an expensive dictionary to be built. We also use generators | |
for returning the ngrams to further save on memory usage. | |
Using the test script from https://gist.github.com/kmike/7813450: | |
for unigrams, the memory usage is actually worse than CountVectorizer, but | |
for larger n-grams it's significantly better. For 8-grams the max memory consumption | |
is lower by factor of ~4, though the runtime is longer due to slower marisa-trie | |
key lookup: | |
[public-docking-rz-0720:~/homegrown] rokstar% python memusage_fit.py marisa_count8 | |
fit time: 38.0s | |
fit memusage: 610.3MB | |
dump time: 0.2s | |
dump memusage: 390.5MB | |
[public-docking-rz-0720:~/homegrown] rokstar% python memusage_fit.py count8 | |
fit time: 31.9s | |
fit memusage: 1299.7MB | |
dump time: 125.5s | |
dump memusage: 1839.1MB | |
""" | |
    def _word_ngrams(self, tokens, stop_words=None):
        """Turn tokens into a sequence of n-grams after stop words filtering"""
        # handle stop words
        if stop_words is not None:
            tokens = [w for w in tokens if w not in stop_words]

        # handle token n-grams
        min_n, max_n = self.ngram_range
        n_tokens = len(tokens)
        for n in xrange(min_n, min(max_n + 1, n_tokens + 1)):
            for i in xrange(n_tokens - n + 1):
                yield " ".join(tokens[i: i + n])
    def _generate_vocab(self, raw_documents, verbose=False):
        import time

        # analyzer for extracting features
        analyze = self.build_analyzer()

        # get all the features and make the vocabulary trie
        itime = time.time()

        # make a list of generators, one per document, and feed the chained
        # stream of features straight into the trie
        feature_generator = map(analyze, raw_documents)
        vocabulary = marisa_trie.Trie(chain.from_iterable(feature_generator))

        if verbose:
            print 'building the trie took %f s' % (time.time() - itime)

        return vocabulary
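    # Note (illustrative, not from the original gist): marisa_trie.Trie assigns
    # each stored key a stable integer id, so the trie can stand in for the
    # usual {feature: index} vocabulary dict -- e.g. vocabulary_[u"spam eggs"]
    # returns that n-gram's column index and u"spam eggs" in vocabulary_ tests
    # membership, which is exactly how _count_vocab below uses it.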
    def _count_vocab(self, raw_documents, fixed_vocab, verbose=False):
        """Override the CountVectorizer._count_vocab method to avoid
        building a dictionary and instead use the Trie directly
        """
        import time

        # analyzer for extracting features
        analyze = self.build_analyzer()

        j_indices = _make_int_array()
        indptr = _make_int_array()
        indptr.append(0)

        lookup_times = 0.0
        start_time = time.time()

        for doc in raw_documents:
            for feature in analyze(doc):
                itime = time.time()
                if feature in self.vocabulary_:
                    j_indices.append(self.vocabulary_[feature])
                lookup_times += (time.time() - itime)
            indptr.append(len(j_indices))

        if verbose:
            print 'lookup times took %f s' % (lookup_times)
            print 'total count time %f s' % (time.time() - start_time)

        # some Python/Scipy versions won't accept an array.array:
        if j_indices:
            j_indices = np.frombuffer(j_indices, dtype=np.intc)
        else:
            j_indices = np.array([], dtype=np.int32)
        indptr = np.frombuffer(indptr, dtype=np.intc)
        values = np.ones(len(j_indices))

        X = sp.csr_matrix((values, j_indices, indptr),
                          shape=(len(indptr) - 1, len(self.vocabulary_)),
                          dtype=self.dtype)
        X.sum_duplicates()

        self.fixed_vocabulary = True

        return self.vocabulary_, X
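    # Illustration (not part of the original gist) of the CSR construction
    # above: for two documents where doc 0 contains features with column ids
    # [3, 7, 3] and doc 1 contains [1], j_indices == [3, 7, 3, 1] and
    # indptr == [0, 3, 4]; every value is 1, and X.sum_duplicates() collapses
    # the repeated column 3 in row 0 into a single count of 2.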
    def fit_transform(self, raw_documents, y=None):
        # build the trie vocabulary first, then let CountVectorizer do the
        # counting via the overridden _count_vocab
        self.vocabulary_ = self._generate_vocab(raw_documents)
        return super(MarisaCountVectorizer, self).fit_transform(raw_documents)
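

# Minimal usage sketch (not part of the original gist; the documents and
# ngram_range below are made up for illustration). It assumes a scikit-learn
# version of the same vintage as the class above, where _make_int_array is
# importable and CountVectorizer.fit_transform takes the raw documents.
if __name__ == '__main__':
    docs = [u"the quick brown fox",
            u"the lazy dog",
            u"the quick brown dog jumps over the lazy fox"]

    vec = MarisaCountVectorizer(ngram_range=(1, 2))
    X = vec.fit_transform(docs)

    # X is a documents x features sparse count matrix; the trie replaces the
    # usual vocabulary_ dict, so column ids come from the marisa-trie key ids.
    print X.shape
    print sorted(vec.vocabulary_.keys())[:5]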