Skip to content

Instantly share code, notes, and snippets.

@cigrainger
Created May 29, 2014 07:03
Show Gist options
  • Select an option

  • Save cigrainger/5e0dc2f638ea72f0edbc to your computer and use it in GitHub Desktop.

Select an option

Save cigrainger/5e0dc2f638ea72f0edbc to your computer and use it in GitHub Desktop.
# Imports and housekeeping
import logging
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s',
level=logging.INFO)
from gensim import corpora, models, similarities, matutils
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
# Define KL function
def sym_kl(p, q):
    """Return the symmetric Kullback-Leibler divergence between p and q.

    The result is KL(p || q) + KL(q || p). Each term is computed with
    ``scipy.stats.entropy``, which rescales both inputs to sum to 1 and
    uses the natural logarithm, so unnormalised weights are accepted.

    Args:
        p: Non-negative weights of the first distribution.
        q: Non-negative weights of the second distribution; same length
            as ``p``.

    Returns:
        The sum of the two directed divergences. It is ``inf`` when one
        input has a zero where the other has positive weight.
    """
    return np.sum([stats.entropy(p, q), stats.entropy(q, p)])
# Generate corpus
# Stop words: whitespace-separated tokens read from stoplist.txt.
with open('stoplist.txt', 'r') as stop_file:
    stoplist = set(stop_file.read().split())

# Build the vocabulary from the lower-cased, whitespace-tokenised lines of
# abstracts.txt; each line is treated as one document.
with open('abstracts.txt', 'rb') as abstracts:
    dictionary = corpora.Dictionary(line.lower().split()
                                    for line in abstracts)

# Ids of the stop words that actually occur in the vocabulary.
stop_ids = [dictionary.token2id[stopword] for stopword in stoplist
            if stopword in dictionary.token2id]
# Ids of tokens whose document frequency is exactly 1.
# .items() (not .iteritems()) so this runs on both Python 2 and Python 3.
once_ids = [tokenid for tokenid, docfreq in dictionary.dfs.items()
            if docfreq == 1]
dictionary.filter_tokens(stop_ids + once_ids)
# NOTE(review): gensim documents no_above as a fraction of the corpus size,
# so no_above=5 never removes frequent tokens -- confirm whether a value
# such as 0.5 was intended. The default no_below still applies here.
dictionary.filter_extremes(no_above=5, keep_n=100000)
# Reassign ids so that they are contiguous after the removals above.
dictionary.compactify()
class MyCorpus(object):
    """Streamed bag-of-words corpus backed by abstracts.txt.

    Every iteration re-opens the file and yields one document per line,
    so the corpus can be traversed repeatedly without being held in
    memory.
    """

    def __iter__(self):
        """Yield the bag-of-words vector for each line of abstracts.txt.

        Lines are lower-cased and split on whitespace, then converted
        with the module-level ``dictionary``.
        """
        # The with-block closes the file when iteration finishes or the
        # generator is discarded, instead of leaking the handle.
        with open('abstracts.txt', 'rb') as abstracts:
            for line in abstracts:
                yield dictionary.doc2bow(line.lower().split())
# Streaming corpus object; abstracts.txt is only read when it is iterated.
my_corpus = MyCorpus()
# Write the bag-of-words corpus to corpus.mm via gensim's MmCorpus serializer.
# NOTE(review): nothing in the visible script reads corpus.mm back; the code
# below iterates my_corpus directly -- confirm whether this file is still needed.
corpora.MmCorpus.serialize('corpus.mm', my_corpus)
# Run models to find natural number of topics
# One symmetric KL divergence per candidate topic count, in the order of num.
kl = []
# Total token count of every document; weights the document-topic matrix.
l = np.array([sum(cnt for _, cnt in doc) for doc in my_corpus])
num = range(1, 150, 1)
for i in num:
    lda = models.ldamodel.LdaModel(corpus=my_corpus, id2word=dictionary,
                                   num_topics=i, distributed=True)
    # Topic-word matrix. Only its singular values are used, so skip U and V:
    # the full decomposition would allocate a vocabulary x vocabulary matrix.
    m1 = lda.expElogbeta
    cm1 = np.linalg.svd(m1, compute_uv=False)
    # Document-topic matrix: corpus2dense returns topics x documents, so
    # transpose to documents x topics before weighting by document length.
    lda_topics = lda[my_corpus]
    m2 = matutils.corpus2dense(lda_topics, lda.num_topics).transpose()
    cm2 = l.dot(m2)
    # NOTE: a zero entry in cm2 makes the divergence infinite; the original
    # author left additive smoothing (cm2 + 0.0001) disabled.
    cm2norm = np.linalg.norm(l)
    cm2 = cm2 / cm2norm
    div = sym_kl(cm1, cm2)
    kl.append(div)
# Plot symmetric KL divergence against the number of topics.
# Pass num explicitly so the x axis shows the topic counts that were tried
# rather than the list indices of kl, which start at 0.
plt.plot(list(num), kl)
plt.ylabel('Symmetric KL Divergence')
plt.xlabel('Number of Topics')
plt.savefig('kldiv.png', bbox_inches='tight')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment