Created
July 30, 2015 14:39
-
-
Save takuti/e6975eb6f755b3fbc188 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8
""" Usage
    $ ./IBM_Model1.py ${english_filename} ${japanese_filename}

NOTE: the prefix 'kftt-alignments/data/' is automatically prepended to both filenames.
"""
# Reference: http://www.statmt.org/book/slides/04-word-based-models.pdf
import io
import sys
def _load_sentences(filename):
    """Return the corpus file *filename* as a list of token lists.

    The file is looked up under 'kftt-alignments/data/' and decoded as UTF-8.
    Each line is stripped of trailing whitespace and split on single spaces.
    """
    # io.open decodes identically on Python 2 and 3, so tokens are always text.
    with io.open('kftt-alignments/data/%s' % filename, encoding='utf-8') as f:
        return [line.rstrip().split(' ') for line in f]


def _train_model1(e_sentences, j_sentences, e_tokens, j_tokens, eps=1e-3):
    """Estimate the IBM Model 1 table t[(e, j)] with the EM algorithm.

    Sentences are paired by position; if one side has more sentences than the
    other, the unpaired tail is ignored.  Training stops once no entry of the
    table moves by more than *eps* between two consecutive iterations.

    Returns a dict mapping every (English token, Japanese token) pair to its
    estimated probability, normalized per Japanese token.
    """
    # Materialize the pairing once; it is re-read on every EM iteration.
    pairs = list(zip(e_sentences, j_sentences))

    # Uniform starting point for every vocabulary pair.
    t = {(e, j): .25 for e in e_tokens for j in j_tokens}

    while True:
        count = dict.fromkeys(t, 0.)
        total = dict.fromkeys(j_tokens, 0.)

        # E-step: collect fractional counts under the current table.
        for e_sentence, j_sentence in pairs:
            # Per-English-word normalizer within this sentence pair.
            s_total = {e: sum(t[(e, j)] for j in j_sentence)
                       for e in e_sentence}
            for e in e_sentence:
                for j in j_sentence:
                    share = t[(e, j)] / s_total[e]
                    count[(e, j)] += share
                    total[j] += share

        # M-step: re-estimate the table from the collected counts.
        changed = 0
        for (e, j), c in count.items():
            # A Japanese token that never occurs in a paired sentence has
            # total == 0; give it probability 0 instead of dividing by zero.
            new_t = c / total[j] if total[j] else 0.
            if abs(new_t - t[(e, j)]) > eps:
                changed += 1
            t[(e, j)] = new_t

        if changed == 0:
            return t
        print(changed)  # for debug


def main():
    """Train IBM Model 1 on a parallel corpus and print confident alignments.

    Reads the English and Japanese filenames from sys.argv[1] and sys.argv[2]
    (both resolved under 'kftt-alignments/data/'), runs EM until convergence,
    then prints the vocabulary sizes followed by every (English, Japanese)
    pair whose probability exceeds 0.9, in sorted order.

    Exits with a usage message when fewer than two filenames are given.
    """
    if len(sys.argv) < 3:
        sys.exit('[Usage] $ ./IBM_Model1.py ${english_filename} ${japanese_filename}')

    # Load data
    e_sentences = _load_sentences(sys.argv[1])
    j_sentences = _load_sentences(sys.argv[2])
    e_tokens = {tok for sentence in e_sentences for tok in sentence}
    j_tokens = {tok for sentence in j_sentences for tok in sentence}

    # Model learning using the EM algorithm
    t = _train_model1(e_sentences, j_sentences, e_tokens, j_tokens)

    print('total: %d English tokens <-> %d Japanese tokens' % (len(e_tokens), len(j_tokens)))
    # Sorted so that the report is reproducible from run to run.
    for (e, j), prob in sorted(t.items()):
        if prob > .9:
            print('%s %s %s' % (e, j, prob))
# Run the aligner only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment