Created
April 11, 2016 19:06
-
-
Save metasyn/5e954e45ea6e4d5d2e2dbd46cbba5f6d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python | |
| import sys | |
| from sklearn.feature_extraction.text import CountVectorizer | |
| from sklearn.decomposition import LatentDirichletAllocation | |
| import pandas as pd | |
| import numpy as np | |
| from command_util import convert_params | |
| from base import * | |
class LDA(BaseMixin):
    """Latent Dirichlet Allocation topic model over a single text field.

    Documents from the chosen field are converted to term counts with
    scikit-learn's CountVectorizer and modelled with
    LatentDirichletAllocation. ``predict`` returns one
    ``<field>_LDA_<i>`` column per topic; ``summary`` lists the
    highest-weighted words of each topic.
    """

    # Text types accepted by validation: ``type(u'')`` is ``unicode`` on
    # Python 2 and ``str`` on Python 3, so either kind of string passes.
    _STRING_TYPES = (str, type(u''))

    def __init__(self, options):
        """Configure the wrapper and the underlying estimator.

        Args:
            options: dict with a ``'variables'`` list whose first entry names
                the text field, and an optional ``'params'`` dict.
                ``n_top_words`` (default 30) and ``n_features`` (default 100)
                are consumed here; all remaining params are passed to
                LatentDirichletAllocation.
        """
        self.handle_options(options)

        out_params = convert_params(
            options.get('params', {}),
            ints=[
                'n_top_words',
                'n_features',
                'n_topics',
                'n_components',
                'random_state',
            ],
            floats=[
                'doc_topic_prior',
                'topic_word_prior',
                'learning_decay',
                'learning_offset',
            ],
            aliases={},
        )

        # n_top_words / n_features configure this wrapper (summary size and
        # vectorizer vocabulary size), not the estimator, so they must not be
        # passed to LatentDirichletAllocation.
        self.n_top_words = out_params.pop('n_top_words', 30)
        self.n_features = out_params.pop('n_features', 100)  # same default as TFIDF

        # scikit-learn 0.19 renamed n_topics to n_components and 0.21 removed
        # n_topics; translate it only when the installed estimator lacks it.
        if ('n_topics' in out_params
                and 'n_topics' not in LatentDirichletAllocation().get_params()):
            n_topics = out_params.pop('n_topics')
            out_params.setdefault('n_components', n_topics)

        self.estimator = LatentDirichletAllocation(**out_params)

        # The first variable named in the command is the text field to model.
        self.variable = options['variables'][0]

    def _prepare_text(self, X):
        """Validate ``X`` in place and drop rows with missing values.

        Keeps only ``self.variable`` and removes NaN rows via the BaseMixin
        helpers (which this code relies on to mutate ``X`` in place), then
        checks that the remaining values are strings.

        Returns:
            The value returned by ``drop_na_rows``; ``predict`` treats it as
            the mask of dropped rows.

        Raises:
            RuntimeError: if the field's values are not strings.
        """
        self.assert_field_present(X, self.variable)
        self.drop_unused_fields(X, [self.variable])
        nans = self.drop_na_rows(X)
        self.assert_any_fields(X)
        self.assert_any_rows(X)

        # Look up by position: index label 0 may no longer exist once rows
        # have been dropped.
        first_value = X[self.variable].iloc[0]
        if not isinstance(first_value, self._STRING_TYPES):
            raise RuntimeError('Invalid type: "%s" is of type %s. String expected.' % (
                self.variable, X[self.variable].dtype))
        return nans

    def _make_vectorizer(self, vocabulary=None):
        """Return a CountVectorizer; learn a vocabulary unless one is given."""
        if vocabulary is None:
            return CountVectorizer(max_df=0.95, min_df=2,
                                   max_features=self.n_features,
                                   stop_words='english')
        # With a fixed vocabulary, term i maps to column i; document-frequency
        # pruning and max_features only matter while learning a vocabulary.
        return CountVectorizer(vocabulary=vocabulary, stop_words='english')

    def fit(self, X):
        """Learn the term vocabulary and fit the LDA model on ``X[self.variable]``.

        The learned vocabulary is stored in ``self.feature_names``; ``predict``
        and ``summary`` reuse it so term columns stay consistent.
        """
        self._prepare_text(X)

        # TODO: allow user to specify vectorizer settings as in TFIDF.py
        tf_vectorizer = self._make_vectorizer()
        tf_matrix = tf_vectorizer.fit_transform(X[self.variable])

        # get_feature_names() was superseded by get_feature_names_out() and
        # removed in scikit-learn 1.2; support both.
        get_names = getattr(tf_vectorizer, 'get_feature_names_out', None)
        if get_names is None:
            get_names = tf_vectorizer.get_feature_names
        self.feature_names = list(get_names())

        self.estimator.fit(tf_matrix)

    def predict(self, X):
        """Return per-document topic weights for ``X[self.variable]``.

        Documents are counted against the vocabulary learned in ``fit`` (not a
        vocabulary refit on ``X``), so the count columns match what the
        estimator was trained on. The result has one ``<variable>_LDA_<i>``
        column per topic; rows flagged by the NaN mask are filled with NaN so
        the output is meant to stay row-aligned with the input.
        """
        # NOTE(review): assumes drop_na_rows returns a boolean mask over the
        # rows *before* dropping (True = dropped) -- confirm against base.py.
        nans = np.asarray(self._prepare_text(X), dtype=bool).ravel()

        vectorizer = self._make_vectorizer(vocabulary=self.feature_names)
        tf_matrix = vectorizer.transform(X[self.variable])
        y_hat = self.estimator.transform(tf_matrix)

        n_topics = y_hat.shape[1]
        columns = [self.variable + '_LDA_' + str(index) for index in range(n_topics)]

        values = np.full((len(nans), n_topics), np.nan)
        values[~nans] = y_hat
        return pd.DataFrame(values, columns=columns)

    def summary(self):
        """Return each topic's ``n_top_words`` highest-weighted terms.

        Returns:
            DataFrame with columns ``topic`` (1-based topic number) and
            ``topic_top_words`` (space-separated terms, heaviest first).
        """
        rows = []
        for topic_idx, topic in enumerate(self.estimator.components_):
            # argsort is ascending; step backwards to take the heaviest terms.
            top_indices = topic.argsort()[:-self.n_top_words - 1:-1]
            top_words = " ".join([self.feature_names[i] for i in top_indices])
            rows.append([topic_idx + 1, top_words])
        return pd.DataFrame(rows, columns=['topic', 'topic_top_words'])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment