import numpy as np
from scipy.special import digamma

from ..utils.document import Document


class bigramTopicModel(object):
    """Bigram topic model fitted with collapsed Gibbs sampling.

    Each token's word is conditioned on its topic *and* on the previous word
    of the same document; the first token of a document uses the sentinel
    previous-word index ``V``. The structure appears to follow Wallach (2006),
    "Topic Modeling: Beyond Bag-of-Words" -- NOTE(review): confirm.

    The document-topic Dirichlet hyperparameters ``alpha`` are re-estimated
    after every Gibbs phase with a digamma fixed-point update; the word prior
    ``beta`` stays at its uniform ``1 / V`` initialisation (see the TODO in
    ``_m_step``).

    Parameters
    ----------
    K : int
        Number of topics.
    docs : Document
        Corpus wrapper. Only ``get_documents``, ``get_nb_vocab``,
        ``get_nb_docs`` and ``get_doc_lengths`` are called. Documents are
        assumed to be 1-D integer arrays of word ids in ``[0, V)`` -- TODO
        confirm against ``Document``.
    S : int, default 10
        Number of Gibbs sweeps per E-step.

    Notes
    -----
    ``_nkcv`` holds ``K * (V + 1) * V`` int32 counters, so memory grows
    quadratically with the vocabulary size.
    """

    # Lower bound applied to each alpha after a fixed-point step so it stays
    # strictly positive. NOTE(review): 10e-30 == 1e-29; 1e-30 may have been
    # intended -- value kept unchanged.
    _ALPHA_FLOOR = 10e-30

    def __init__(self, K: int, docs: "Document", S=10):
        self.K = K
        self._documents = docs.get_documents()
        self._V = docs.get_nb_vocab()
        self._D = docs.get_nb_docs()
        self.doc_lens = docs.get_doc_lengths()
        self._beta_mk = np.ones((self.K, self._V)) / self._V
        self._alphas = np.ones(self.K) * 0.01
        self._sum_alpha = np.sum(self._alphas)
        # (topic, previous word, word); previous-word index V is the
        # start-of-document context. Allocated directly as int32 to avoid a
        # float64 temporary twice the size of the table.
        self._nkcv = np.zeros((self.K, self._V + 1, self._V), dtype=np.int32)
        # (topic, previous word): _nkcv summed over the current word. Kept in
        # sync so the sampler's denominator costs O(K) instead of O(K * V).
        self._nkc = np.zeros((self.K, self._V + 1), dtype=np.int32)
        self._ndk = np.zeros((self._D, self.K), dtype=np.int32)  # (document, topic)
        self._nk = np.zeros(self.K, dtype=np.int32)  # tokens per topic
        self._z = []  # one array of topic assignments per document
        self._S = S

    def fit(self, nb_iterations=300, nb_hyper_iterations=10, verbose=True):
        """Fit the model starting from a fresh random topic assignment.

        Topic assignments and counts are re-initialised on every call; the
        current ``alpha`` values are kept as the starting point.

        Parameters
        ----------
        nb_iterations : int
            Number of E-step / M-step rounds.
        nb_hyper_iterations : int
            Fixed-point iterations for ``alpha`` in each M-step.
        verbose : bool, default True
            Print an in-place iteration counter to stdout.
        """
        self._initialize_topics()
        for ite in range(1, nb_iterations + 1):
            if verbose:
                print("\r", ite, end="")
            self._m_step(self._e_step(), nb_hyper_iterations=nb_hyper_iterations)
        if verbose and nb_iterations > 0:
            print()  # terminate the "\r" progress line

    def _update_counts(self, d, topic, pre_word, word, delta):
        """Add ``delta`` to every count table for one token occurrence."""
        self._nkcv[topic, pre_word, word] += delta
        self._nkc[topic, pre_word] += delta
        self._ndk[d, topic] += delta
        self._nk[topic] += delta

    def _initialize_topics(self):
        """Assign a uniformly random topic to every token and rebuild counts.

        All count tables and assignments are cleared first, so calling
        ``fit`` repeatedly does not accumulate stale counts.
        """
        for table in (self._nkcv, self._nkc, self._ndk, self._nk):
            table.fill(0)
        self._z = []
        for doc_id, doc in enumerate(self._documents):
            doc_topic = np.random.randint(self.K, size=doc.shape[0], dtype=np.int32)
            self._z.append(doc_topic)
            pre_word = self._V  # start-of-document context
            for word, topic in zip(doc, doc_topic):
                self._update_counts(doc_id, topic, pre_word, word, 1)
                pre_word = word

    def _e_step(self):
        """Run ``S`` collapsed Gibbs sweeps over every token.

        For each token, its current assignment is removed from the counts, a
        new topic is drawn from the conditional

            (n_{k,prev,w} + beta_{k,w}) / (n_{k,prev} + sum_v beta_{k,v})
                * (n_{d,k} + alpha_k)

        and the counts are restored with the new topic.

        Returns
        -------
        np.ndarray
            int32 array of shape (S, K, D); entry ``[s, k, d]`` is the number
            of tokens of document ``d`` assigned to topic ``k`` in sweep ``s``.
        """
        beta_k = self._beta_mk.sum(axis=1)  # per-topic prior mass, shape (K,)
        sampling_z = np.zeros((self._S, self.K, self._D), dtype=np.int32)
        last_topic = self.K - 1
        for s in range(self._S):
            for d in range(self._D):
                z_d = self._z[d]
                pre_w = self._V  # start-of-document context
                for i, w in enumerate(self._documents[d]):
                    self._update_counts(d, z_d[i], pre_w, w, -1)

                    weights = ((self._nkcv[:, pre_w, w] + self._beta_mk[:, w])
                               / (self._nkc[:, pre_w] + beta_k)
                               * (self._ndk[d] + self._alphas))
                    # Inverse-CDF draw; the clamp guards against float rounding
                    # pushing the target onto/past the last cumulative weight.
                    cdf = np.cumsum(weights)
                    target = np.random.rand() * cdf[-1]
                    t = min(int(np.searchsorted(cdf, target, side="right")), last_topic)

                    self._update_counts(d, t, pre_w, w, 1)
                    sampling_z[s, t, d] += 1
                    z_d[i] = t
                    pre_w = w
        return sampling_z

    def _m_step(self, zs: np.ndarray, nb_hyper_iterations: int):
        """Re-estimate the document-topic hyperparameters ``alpha``.

        Applies ``nb_hyper_iterations`` rounds of the digamma fixed-point
        update to each ``alpha_k`` in turn, treating every sweep in ``zs`` as
        an observed table of document-topic counts::

            alpha_k <- alpha_k * sum_{s,d}[psi(n_sdk + alpha_k) - psi(alpha_k)]
                               / sum_{s,d}[psi(n_d + A) - psi(A)],  A = sum(alpha)

        ``A`` is refreshed after each ``alpha_k`` update.

        Parameters
        ----------
        zs : np.ndarray
            Topic counts of shape (S, K, D), as returned by ``_e_step``.
        nb_hyper_iterations : int
            Number of fixed-point rounds.
        """
        n_samples, _, n_docs = zs.shape
        doc_lens = np.asarray(self.doc_lens, dtype=np.float64)
        for _ in range(nb_hyper_iterations):
            for k in range(self.K):
                alpha_k = self._alphas[k]
                numer = (np.sum(digamma(zs[:, k, :] + alpha_k))
                         - n_samples * n_docs * digamma(alpha_k))
                denom = n_samples * (np.sum(digamma(doc_lens + self._sum_alpha))
                                     - n_docs * digamma(self._sum_alpha))
                if not denom > 0:
                    # psi is increasing, so this only happens when every
                    # document is empty: no evidence, keep alpha unchanged
                    # instead of producing NaN.
                    continue
                self._alphas[k] = max(alpha_k * numer / denom, self._ALPHA_FLOOR)
                self._sum_alpha = np.sum(self._alphas)

        # TODO: UPDATE BETA

    def word_predict(self, k: int):
        """Return unnormalised word counts for topic ``k``.

        Counts are summed over every previous-word context (including the
        start-of-document context), giving an array of length ``V``.
        """
        return np.sum(self._nkcv[k, :, :], axis=0)

    def topic_predict(self, doc_id: int):
        """Return the smoothed topic proportions of document ``doc_id``.

        Computed as ``(n_dk + alpha_k) / sum_k (n_dk + alpha_k)`` from the
        current Gibbs state; the result sums to 1.
        """
        p = self._ndk[doc_id, :] + self._alphas
        return p / np.sum(p)