-
-
Save shashankg7/f5bfe5f0753aef7fb03e3cad2f45ab33 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def symdirichlet(alpha, n):
    """Draw one sample from a symmetric Dirichlet distribution.

    alpha -- concentration value shared by all components.
    n     -- number of components.

    Returns the array produced by ``np.random.dirichlet``.
    """
    concentration = np.ones(n) * alpha
    return np.random.dirichlet(concentration)
def exp_digamma(x):
    """Cheap approximation of ``exp(digamma(x))``.

    For ``x >= 0.1`` the truncated expansion ``x - 0.5 + 1/(24*x)`` is
    used.  Below that threshold the expansion is not used and the function
    falls back to the ad-hoc value ``x / 100``.

    After the rebinding below, the public ``exp_digamma`` applies this
    element-wise and always returns a float ndarray (0-d for scalar input).
    """
    if x < 0.1:
        # Float literal so integer input never triggers integer division.
        return x / 100.
    return x - 0.5 + 1. / (24 * x)


# Element-wise wrapper.  ``otypes`` pins the output dtype: without it
# np.vectorize infers the dtype from the first element only and cannot be
# called on empty input at all.
exp_digamma = np.vectorize(exp_digamma, otypes=[float])
import re | |
# Matches one maximal run of word characters (letters, digits, underscore).
wre = re.compile(r"(\w)+")


def get_words(text, stop=True, stoplist=()):
    """A simple tokenizer.

    Yields, in order of appearance, every maximal run of word characters
    found in ``text``, lower-cased.

    text     -- the string to tokenize.
    stop     -- when true (the default) every token is yielded unfiltered;
                when false, tokens contained in ``stoplist`` are dropped.
                (The flag name reads inverted, but this is the behaviour
                existing callers rely on.)
    stoplist -- container of lower-case tokens to drop when ``stop`` is
                false.  Defaults to empty.
    """
    # finditer walks the non-overlapping matches left to right, so no manual
    # offset bookkeeping or exception-driven loop exit is needed.
    for match in wre.finditer(text):
        w = match.group(0).lower()
        if stop or w not in stoplist:
            yield w
class EmLda(object):
    """Topic model fitted by alternating soft-assignment (E) and
    re-normalisation (M) steps.

    Attributes set by the constructor:
      Nt, Nd, V    -- number of topics, documents and distinct kept words.
      docs         -- one list of word ids per input document.
      all_words    -- word id -> word.
      reverse_map  -- word -> word id.
      ctopics      -- (Nt, V) accumulated per-topic word weights.
      doccounts    -- (Nd, Nt) accumulated per-document topic weights.
      ptopics      -- (Nt, V) smoothed, row-normalised ``ctopics``.
      docprobs     -- (Nd, Nt) smoothed, row-normalised ``doccounts``.
      alpha, beta  -- smoothing constants added to ``ctopics`` and
                      ``doccounts`` respectively in the M-step; ``beta`` is
                      also the concentration of the random initialisation.
    """

    def __init__(self, docs, nt):
        """Tokenize ``docs``, build the vocabulary and initialise randomly.

        docs -- iterable of document strings.
        nt   -- number of topics.
        """
        self.Nt = nt
        self.docs = []
        self.all_words = []
        self.reverse_map = {}
        for text in docs:
            doc = []
            for w in get_words(text):
                # Words shorter than five characters are ignored.
                if len(w) < 5:
                    continue
                if w not in self.reverse_map:
                    self.reverse_map[w] = len(self.all_words)
                    self.all_words.append(w)
                doc.append(self.reverse_map[w])
            self.docs.append(doc)
        self.Nd = len(self.docs)
        self.V = len(self.all_words)
        self.ctopics = np.zeros((self.Nt, self.V))
        self.doccounts = np.zeros((self.Nd, self.Nt))
        self.ptopics = np.zeros((self.Nt, self.V))
        self.docprobs = np.zeros((self.Nd, self.Nt))
        self.beta = 1.
        self.alpha = 1.
        # Give every token a random soft topic assignment, then normalise.
        for d, doc in enumerate(self.docs):
            for w in doc:
                zw = symdirichlet(self.beta, self.Nt)
                self.ctopics[:, w] += zw
                self.doccounts[d] += zw
        self.m()

    def e(self):
        """E-step: recompute ``ctopics`` and ``doccounts``.

        Each token contributes a topic distribution proportional to
        ``ptopics[:, word] * docprobs[doc]``.
        """
        self.ctopics.fill(0)
        self.doccounts.fill(0)
        for d, doc in enumerate(self.docs):
            for w in doc:
                zw = self.ptopics[:, w] * self.docprobs[d]
                zw /= np.sum(zw)
                self.ctopics[:, w] += zw
                self.doccounts[d] += zw

    def m(self):
        """M-step: smooth the accumulated weights and normalise each row.

        Both tables are divided by the sum of their *smoothed* rows, so
        every row of ``ptopics`` and of ``docprobs`` sums to one.
        """
        self.ptopics[:] = self.ctopics + self.alpha
        self.ptopics /= np.sum(self.ptopics, axis=1).reshape((-1, 1))
        self.docprobs[:] = self.doccounts + self.beta
        self.docprobs /= np.sum(self.docprobs, axis=1).reshape((-1, 1))

    def iterate(self):
        """Run one E-step followed by one M-step."""
        self.e()
        self.m()

    def run(self, n):
        """Run ``n`` iterations, then print a summary of the model.

        The summary lists the normalised topic weights of (at most) the
        first 100 vocabulary words and the top 40 words of every topic.
        """
        import sys
        for i in range(n):
            print("iter %s" % i)
            sys.stdout.flush()
            self.iterate()
        # Cap at V so a vocabulary smaller than 100 words does not overrun.
        for w in range(min(100, self.V)):
            weights = self.ptopics[:, w]
            print("word %s topics %s"
                  % (self.all_words[w], weights / np.sum(weights)))
        for i in range(self.Nt):
            print("Topic %s" % i)
            print_topic(self, self.ptopics[i], 40)
class PrLda(EmLda):
    """EmLda variant with a per-word penalty vector in the E-step.

    For every vocabulary word a non-negative vector ``lambda`` (one entry
    per topic, entries summing to at most ``sigma``) is optimised, and each
    token's topic distribution is reweighted by ``exp(-lambda)`` before it
    is accumulated.
    """

    def __init__(self, *args):
        """Build the model as EmLda does, then set ``Nw`` and ``sigma``.

        Nw    -- total number of kept tokens across all documents.
        sigma -- bound on the sum of each word's lambda vector.
        """
        super(PrLda, self).__init__(*args)
        self.Nw = sum(len(d) for d in self.docs)
        self.sigma = 15.

    def e(self):
        """E-step: fit the per-word lambdas, then accumulate the counts."""
        self.do_lambda()
        self.do_z()

    def do_lambda(self):
        """Collect per-token topic distributions and optimise each lambda.

        Sets ``self.ptheta`` (one ``(occurrences, Nt)`` array per word, rows
        in corpus order) and ``self.lbda`` (one lambda vector per word), and
        zeroes ``ctopics`` / ``doccounts`` ready for ``do_z``.
        """
        self.ctopics.fill(0)
        self.doccounts.fill(0)
        per_word = [[] for _ in range(self.V)]
        for d, doc in enumerate(self.docs):
            for w in doc:
                zw = self.ptopics[:, w] * self.docprobs[d]
                zw /= np.sum(zw)
                per_word[w].append(zw)
        # Real lists (not a lazy ``map``) because do_z indexes into them.
        self.ptheta = [np.array(rows) for rows in per_word]
        self.lbda = [self.optimize_lambda(p) for p in self.ptheta]

    def optimize_lambda(self, ptheta, steps=50, lrate=1.):
        """Projected gradient descent for one word's lambda vector.

        Minimises ``sum_k mass[k] * exp(-lambda[k])``, where ``mass`` is the
        column sum of ``ptheta``, subject to ``lambda >= 0`` and
        ``sum(lambda) <= self.sigma``.

        ptheta -- ``(occurrences, Nt)`` array of topic distributions.
        steps  -- maximum number of gradient steps.
        lrate  -- gradient step size.

        Returns the lambda vector (length ``Nt``).
        """
        mass = np.sum(ptheta, axis=0)
        # Start from the uniform point whose entries sum to sigma.
        lbda = np.ones(ptheta.shape[1])
        lbda /= np.sum(lbda)
        lbda *= self.sigma
        prevobj = np.inf
        for n in range(steps):
            obj = np.sum(mass * np.exp(-lbda))
            # After a few warm-up steps, stop once the relative improvement
            # drops below 1%.
            if n > 5 and prevobj - obj < 0.01 * obj:
                break
            prevobj = obj
            # The gradient of the objective is -mass * exp(-lambda).
            lbda = lbda + lrate * mass * np.exp(-lbda)
            lbda = self._project_l1(lbda, self.sigma)
        return lbda

    @staticmethod
    def _project_l1(v, radius):
        """Project ``v`` onto ``{x >= 0, sum(x) <= radius}`` (Euclidean)."""
        v = np.maximum(v, 0.)
        if radius <= 0:
            return np.zeros_like(v)
        if v.sum() <= radius:
            return v
        # Sort-based simplex projection: find the shift ``theta`` such that
        # the clipped, shifted vector sums exactly to ``radius``.
        u = np.sort(v)[::-1]
        excess = np.cumsum(u) - radius
        ranks = np.arange(1, len(u) + 1.)
        rho = np.nonzero(u - excess / ranks > 0)[0][-1]
        theta = excess[rho] / (rho + 1.)
        return np.maximum(v - theta, 0.)

    def do_z(self):
        """Accumulate ``ctopics`` / ``doccounts`` from reweighted tokens.

        Tokens are visited in the same order as in ``do_lambda`` so the
        n-th occurrence of a word matches the n-th row of its ``ptheta``.
        """
        seen = [0] * self.V
        for d, doc in enumerate(self.docs):
            for w in doc:
                # Build a new array so the cached ptheta rows stay intact.
                zw = self.ptheta[w][seen[w]] * np.exp(-self.lbda[w])
                zw /= np.sum(zw)
                seen[w] += 1
                self.ctopics[:, w] += zw
                self.doccounts[d] += zw
class VarLda(EmLda):
    """EmLda whose E-step passes the topic weights through exp_digamma."""

    def e(self):
        """E-step: rebuild ``ctopics`` and ``doccounts`` from scratch.

        For each token the weights ``ptopics[:, word] * docprobs[doc]`` are
        mapped through ``exp_digamma`` and divided by ``exp_digamma`` of
        their sum before being accumulated.
        """
        self.ctopics.fill(0)
        self.doccounts.fill(0)
        for doc_id in range(self.Nd):
            for word_id in self.docs[doc_id]:
                weights = self.ptopics[:, word_id] * self.docprobs[doc_id]
                total = np.sum(weights)
                weights = exp_digamma(weights) / exp_digamma(total)
                self.ctopics[:, word_id] += weights
                self.doccounts[doc_id] += weights
def print_topic(model, t, n):
    """Print the ``n`` highest-weighted words of a topic, one per line.

    model -- object whose ``all_words`` maps a word id to its word.
    t     -- array of weights indexed by word id.
    n     -- how many words to print.
    """
    ranked = np.argsort(-t)
    for word_id in ranked[:n]:
        # Joining on " " reproduces the separator that a two-item print
        # puts between its items.
        print(" ".join((" ", model.all_words[word_id])))
if __name__ == '__main__':
    # Usage: <script> NUM_TOPICS DOCS_DIR
    # Reads every non-hidden regular file in DOCS_DIR as one document and
    # runs ten iterations of each model variant on the resulting corpus.
    import os
    import sys

    if len(sys.argv) < 3:
        sys.exit("usage: %s NUM_TOPICS DOCS_DIR" % sys.argv[0])
    nt = int(sys.argv[1])
    docs_dir = sys.argv[2]

    docs = []
    for fname in os.listdir(docs_dir):
        path = os.path.join(docs_dir, fname)
        # Skip dot-files and anything that is not a regular file.
        if fname.startswith(".") or not os.path.isfile(path):
            continue
        # The context manager closes each handle as soon as it is read.
        with open(path) as handle:
            docs.append(handle.read())

    el = VarLda(docs, nt)
    el.run(10)
    el = EmLda(docs, nt)
    el.run(10)
    el = PrLda(docs, nt)
    el.run(10)
    # Same model again with a much looser bound on the lambdas.
    el = PrLda(docs, nt)
    el.sigma = 150
    el.run(10)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment