import numpy as np
from scipy.special import digamma

from ..utils.document import Document

class bigramTopicModel(object):
    def __init__(self, K: int, docs: Document, S=10):
        self.K = K  # number of topics
        self._documents = docs.get_documents()
        self._V = docs.get_nb_vocab()
        self._D = docs.get_nb_docs()
        self.doc_lens = docs.get_doc_lengths()
        self._beta_mk = np.ones((self.K, self._V)) / self._V  # per-topic Dirichlet prior over words
        self._alphas = np.ones(self.K) * 0.01                 # Dirichlet prior over topics
        self._sum_alpha = np.sum(self._alphas)
        # count arrays; index self._V on the "previous word" axis is a
        # beginning-of-document pseudo-token
        self._nkcv = np.zeros((self.K, self._V + 1, self._V), dtype=np.int32)  # (topic, previous word, word)
        self._ndk = np.zeros((self._D, self.K), dtype=np.int32)                # (document, topic)
        self._nk = np.zeros(self.K, dtype=np.int32)                            # (topic,)
        self._z = []  # per-document topic assignments
        self._S = S   # number of Gibbs samples drawn per E-step

    def fit(self, nb_iterations=300, nb_hyper_iterations=10):
        self._initialize_topics()
        for ite in range(1, nb_iterations + 1):
            print("\r", ite, end="")
            self._m_step(self._e_step(), nb_hyper_iterations=nb_hyper_iterations)
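
    # The training loop above has a stochastic-EM-like structure: the E-step
    # draws S Gibbs samples of the topic assignments, and the M-step
    # re-estimates the alpha hyperparameters from those samples.
    # ("Stochastic EM" is my reading of the structure, not the author's label.)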

    def _initialize_topics(self):
        for doc_id, doc in enumerate(self._documents):
            doc_topic = np.random.randint(self.K, size=doc.shape[0], dtype=np.int32)
            self._z.append(doc_topic)
            pre_word = self._V  # start each document from the pseudo-token
            for word, topic in zip(doc, doc_topic):
                self._nkcv[topic, pre_word, word] += 1
                self._ndk[doc_id, topic] += 1
                self._nk[topic] += 1
                pre_word = word

    def _e_step(self):
        beta_k = np.sum(self._beta_mk, axis=1)  # row sums of the beta prior
        sampling_Z = np.zeros((self._S, self.K, self._D), dtype=np.int32)
        for s in range(self._S):
            for d in range(self._D):
                w_d = self._documents[d]
                pre_w = self._V
                for i, w in enumerate(w_d):
                    t = self._z[d][i]
                    # remove the current assignment from the counts
                    self._nkcv[t, pre_w, w] -= 1
                    self._nk[t] -= 1
                    self._ndk[d, t] -= 1
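                    # Collapsed Gibbs conditional for the bigram topic model
                    # (a sketch of the intent, inferred from the counts used):
                    #   p(z_i = k | rest) is proportional to
                    #       (n_{k, w_{i-1}, w_i} + beta_{k, w_i})
                    #     / (n_{k, w_{i-1}, .}  + sum_v beta_{k, v})
                    #     * (n_{d, k} + alpha_k)
                    # i.e. a bigram word likelihood under topic k times the
                    # document-topic term, both with the current token's
                    # counts removed above.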
                    u = np.array([(self._nkcv[k, pre_w, w] + self._beta_mk[k, w]) /
                                  (np.sum(self._nkcv[k, pre_w, :]) + beta_k[k]) *
                                  (self._ndk[d, k] + self._alphas[k])
                                  for k in range(self.K)])
                    # draw a topic from the unnormalised weights u
                    U = np.random.rand() * np.sum(u)
                    for k, u_v in enumerate(u):
                        U -= u_v
                        if U < 0.:
                            t = k
                            break
                    # add the new assignment back to the counts
                    self._nkcv[t, pre_w, w] += 1
                    self._nk[t] += 1
                    self._ndk[d, t] += 1
                    sampling_Z[s, t, d] += 1
                    self._z[d][i] = t
                    pre_w = w
        return sampling_Z

    def _m_step(self, zs: np.ndarray, nb_hyper_iterations: int):
        # update alpha
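        # This matches Minka's fixed-point update for Dirichlet
        # hyperparameters (naming it is my reading, not the author's):
        #   alpha_k <- alpha_k *
        #       sum_{s,d} [digamma(n_{d,k}^{(s)} + alpha_k) - digamma(alpha_k)]
        #     / sum_{s,d} [digamma(N_d + sum_alpha) - digamma(sum_alpha)]
        # where n_{d,k}^{(s)} = zs[s, k, d] and N_d = doc_lens[d].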
        for _ in range(nb_hyper_iterations):
            for k in range(self.K):
                numer = -digamma(self._alphas[k]) * self._S * self._D
                denom = -digamma(self._sum_alpha) * self._S * self._D
                for s in range(self._S):
                    for d in range(self._D):
                        numer += digamma(zs[s, k, d] + self._alphas[k])
                        denom += digamma(self.doc_lens[d] + self._sum_alpha)
                self._alphas[k] = max(self._alphas[k] * numer / denom, 1e-30)  # floor keeps alpha positive
            self._sum_alpha = np.sum(self._alphas)
        # TODO: UPDATE BETA

    def word_predict(self, k: int):
        # unnormalised word counts for topic k, marginalised over the previous word
        return np.sum(self._nkcv[k, :, :], axis=0)

    def topic_predict(self, doc_id: int):
        # smoothed topic proportions of one document
        p = self._ndk[doc_id, :] + self._alphas
        return p / np.sum(p)
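

# A minimal usage sketch. The Document class lives in the author's repo, not
# in this gist, so ToyDocs below is a hypothetical stand-in that implements
# only the four accessors the model actually calls, over random toy data.
# (Running this file standalone would also require replacing the relative
# `..utils.document` import above.)
if __name__ == "__main__":
    class ToyDocs(object):
        """Hypothetical stand-in for Document: random docs over `vocab` word ids."""
        def __init__(self, nb_docs=20, vocab=50, doc_len=30):
            self._docs = [np.random.randint(vocab, size=doc_len) for _ in range(nb_docs)]
            self._vocab = vocab

        def get_documents(self):
            return self._docs

        def get_nb_vocab(self):
            return self._vocab

        def get_nb_docs(self):
            return len(self._docs)

        def get_doc_lengths(self):
            return np.array([len(d) for d in self._docs])

    model = bigramTopicModel(K=3, docs=ToyDocs(), S=5)
    model.fit(nb_iterations=20)
    print()
    print("topic proportions of doc 0:", model.topic_predict(0))
    print("top word ids of topic 0:", np.argsort(model.word_predict(0))[::-1][:5])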