Created
December 8, 2015 07:53
-
-
Save nzw0301/fd8572d1d50f24dad68e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8
"""Mixture of unigrams (one topic per document) fitted by collapsed Gibbs sampling.

Usage::

    python <script> CORPUS_FILE [K] [ITERATIONS]

Each line of CORPUS_FILE is one document. K (number of topics) defaults to 2
and ITERATIONS (number of Gibbs sweeps) defaults to 100. The symmetric
Dirichlet hyperparameters alpha (topic proportions) and beta (topic-word
distributions) start at 1 and are re-estimated after every sweep.
"""
import sys

import numpy as np
import scipy.special


def load_documents(fname):
    """Read *fname* (one document per line) into a document-word count matrix.

    Returns the sparse matrix produced by ``CountVectorizer.fit_transform``,
    of shape (n_documents, vocabulary_size).
    """
    # scikit-learn is only needed to tokenise the corpus, so it is imported
    # here; the sampler below works on any count matrix without it.
    from sklearn.feature_extraction.text import CountVectorizer

    # NOTE(review): the pattern `\b\w` keeps only the first word character of
    # each run of word characters; confirm this is intended for the corpus.
    cv = CountVectorizer(token_pattern='(?u)\\b\\w')
    with open(fname, encoding="utf-8") as f:
        docs = [line.strip() for line in f]
    return cv.fit_transform(docs)


def _document_rows(docs):
    """Return one ``(word_ids, counts)`` array pair per document.

    *docs* may be a SciPy sparse matrix or anything ``np.asarray`` turns into
    a 2-D count array. Word ids within a single document are unique.
    """
    if hasattr(docs, "tocsr"):
        csr = docs.tocsr(copy=True)
        csr.sum_duplicates()  # merge repeated entries so each word id appears once
        return [
            (csr.indices[start:end], csr.data[start:end])
            for start, end in zip(csr.indptr[:-1], csr.indptr[1:])
        ]
    rows = []
    for row in np.asarray(docs):
        word_ids = np.flatnonzero(row)
        rows.append((word_ids, row[word_ids]))
    return rows


def _topic_log_weights(word_ids, counts, n_words, D_k, N_k, N_kv, alpha, beta):
    """Return the unnormalised log conditional of every topic for one document.

    The document must already be removed from D_k, N_k and N_kv::

        log(D_k + alpha)
        + lgamma(N_k + beta*V) - lgamma(N_k + N_d + beta*V)
        + sum_v [lgamma(N_kv + N_dv + beta) - lgamma(N_kv + beta)]

    Working with log-gamma avoids the overflow of ``math.gamma`` (arguments
    above ~171) and the underflow of long products of gamma ratios.
    """
    gammaln = scipy.special.gammaln
    beta_v = beta * N_kv.shape[1]
    log_w = (np.log(D_k + alpha)
             + gammaln(N_k + beta_v) - gammaln(N_k + n_words + beta_v))
    topic_word = N_kv[:, word_ids]  # (K, distinct words in this document)
    log_w += (gammaln(topic_word + counts + beta)
              - gammaln(topic_word + beta)).sum(axis=1)
    return log_w


def update_alpha(D_k, alpha):
    """Fixed-point update of the symmetric Dirichlet prior on topic proportions.

    alpha' = alpha * (sum_k psi(D_k + alpha) - K psi(alpha))
                   / (K psi(D + K alpha) - K psi(K alpha)),   D = sum_k D_k

    Returns *alpha* unchanged when no documents are assigned (the ratio would
    be 0/0).
    """
    psi = scipy.special.psi
    D_k = np.asarray(D_k, dtype=float)
    if D_k.sum() <= 0:
        return alpha
    K = D_k.size
    numerator = psi(D_k + alpha).sum() - K * psi(alpha)
    denominator = K * psi(D_k.sum() + K * alpha) - K * psi(K * alpha)
    return alpha * numerator / denominator


def update_beta(N_k, N_kv, beta):
    """Fixed-point update of the symmetric Dirichlet prior on topic-word distributions.

    beta' = beta * (sum_{k,v} psi(N_kv + beta) - K V psi(beta))
                 / (V sum_k psi(N_k + beta V) - K V psi(beta V))

    Returns *beta* unchanged when no words are assigned (the ratio would be
    0/0).
    """
    psi = scipy.special.psi
    N_k = np.asarray(N_k, dtype=float)
    N_kv = np.asarray(N_kv, dtype=float)
    if not np.any(N_k > 0):
        return beta
    K, V = N_kv.shape
    numerator = psi(N_kv + beta).sum() - K * V * psi(beta)
    # The digamma term of the denominator is summed once per topic, not per word.
    denominator = V * psi(N_k + beta * V).sum() - K * V * psi(beta * V)
    return beta * numerator / denominator


def gibbs_sampling(docs, K=2, n_iter=100, alpha=1, beta=1, seed=None, verbose=True):
    """Fit a K-topic mixture of unigrams to *docs* by collapsed Gibbs sampling.

    Every document carries a single topic. Each sweep removes one document at
    a time from the counts, samples its topic from the collapsed conditional
    and adds it back; after the sweep alpha and beta are re-estimated.

    Parameters
    ----------
    docs : sparse matrix or array-like, shape (n_documents, vocabulary_size)
        Document-word count matrix (e.g. the result of ``load_documents``).
    K : int
        Number of topics.
    n_iter : int
        Number of Gibbs sweeps.
    alpha, beta : float
        Initial symmetric Dirichlet hyperparameters; must be positive.
    seed : int or None
        Seed for ``numpy.random.default_rng`` so runs can be reproduced.
    verbose : bool
        Print per-sweep and per-document progress.

    Returns
    -------
    (z, alpha, beta)
        Final topic id of every document (list of ints) and the re-estimated
        hyperparameters.

    Raises
    ------
    ValueError
        If K < 1 or alpha/beta are not positive.
    """
    if K < 1:
        raise ValueError("K must be at least 1")
    if alpha <= 0 or beta <= 0:
        raise ValueError("alpha and beta must be positive")

    rows = _document_rows(docs)
    n_docs, V = np.shape(docs)
    rng = np.random.default_rng(seed)

    z = [-1] * n_docs          # topic id per document; -1 = not assigned yet
    D_k = np.zeros(K)          # number of documents assigned to topic k
    N_k = np.zeros(K)          # total word count of the documents in topic k
    N_kv = np.zeros((K, V))    # frequency of word v in the documents of topic k

    for i in range(n_iter):
        if verbose:
            print("------------\n")
            print("ite:", i, " α= ", alpha, " β= ", beta)
            print("トピックの割り当て", z)
        for doc_id, (word_ids, counts) in enumerate(rows):
            n_words = counts.sum()
            z_d = z[doc_id]
            if z_d >= 0:
                # Remove this document's contribution before resampling it.
                D_k[z_d] -= 1
                N_k[z_d] -= n_words
                N_kv[z_d, word_ids] -= counts
            log_w = _topic_log_weights(word_ids, counts, n_words,
                                       D_k, N_k, N_kv, alpha, beta)
            # Exp-normalise in log space: shifting by the max keeps exp() finite.
            pro = np.exp(log_w - log_w.max())
            pro /= pro.sum()
            z_d = int(rng.choice(K, p=pro))
            z[doc_id] = z_d
            if verbose:
                print("\ndoc id = ", doc_id)
                for k in range(K):
                    print("topic", k, "'s pro =", pro[k])
            D_k[z_d] += 1
            N_k[z_d] += n_words
            N_kv[z_d, word_ids] += counts
        alpha = update_alpha(D_k, alpha)
        beta = update_beta(N_k, N_kv, beta)
    return z, alpha, beta


def main(argv=None):
    """Command-line entry point: ``CORPUS_FILE [K] [ITERATIONS]``."""
    args = sys.argv[1:] if argv is None else list(argv)
    if not args:
        sys.exit("usage: python <script> CORPUS_FILE [K] [ITERATIONS]")
    fname = args[0]
    K = int(args[1]) if len(args) > 1 else 2
    n_iter = int(args[2]) if len(args) > 2 else 100
    gibbs_sampling(load_documents(fname), K=K, n_iter=n_iter)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment