# -*- coding: utf-8 -*-
import numpy as np
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split


def load_dataset(file_path):
    """Read one sentence per line, wrap it in <s>/</s> markers, and return
    the integer-encoded sequences together with the fitted Tokenizer."""
    tokenizer = Tokenizer(filters="")
    texts = []
    with open(file_path, 'r') as f:
        for line in f:
            texts.append("<s> " + line.strip() + " </s>")
    tokenizer.fit_on_texts(texts)
    return tokenizer.texts_to_sequences(texts), tokenizer


def decode_sequence(input_seq):
    """Greedy decoding: encode the input sentence, then feed the decoder its
    own previous prediction one token at a time until </s> (or a length limit)."""
    states_value = encoder_model.predict(input_seq)
    bos_eos = tokenizer_j.texts_to_sequences(["<s>", "</s>"])
    target_seq = np.array(bos_eos[0])  # start decoding from <s>
    output_seq = bos_eos[0][:]
    while True:
        output_tokens, h, c = decoder_model.predict(
            [target_seq] + states_value
        )
        # Pick the most probable token at the current step.
        sampled_token_index = [np.argmax(output_tokens[0, -1, :])]
        output_seq += sampled_token_index
        if sampled_token_index == bos_eos[1] or len(output_seq) > 1000:
            break
        # The sampled token and the updated LSTM states become the next input.
        target_seq = np.array(sampled_token_index)
        states_value = [h, c]
    return output_seq


# Load the English (source) and Japanese (target) sides of the Tanaka corpus.
train_X, tokenizer_e = load_dataset('tanaka_corpus_e.txt')
train_Y, tokenizer_j = load_dataset('tanaka_corpus_j.txt')

# Hold out 2% of the sentence pairs for testing, then zero-pad the training
# sequences at the end so every batch has a fixed length.
train_X, test_X, train_Y, test_Y = train_test_split(train_X, train_Y, test_size=0.02, random_state=42)
train_X = pad_sequences(train_X, padding='post')
train_Y = pad_sequences(train_Y, padding='post')

seqX_len = len(train_X[0])
seqY_len = len(train_Y[0])
# +1 because index 0 is reserved for padding.
word_num_e = len(tokenizer_e.word_index) + 1
word_num_j = len(tokenizer_j.word_index) + 1
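
# (Optional sanity check, not in the original gist.) Print the padded shapes and
# vocabulary sizes before building the model; nothing below depends on these lines.
print("train_X:", train_X.shape, "train_Y:", train_Y.shape)
print("source vocab size:", word_num_e, "target vocab size:", word_num_j)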

from keras.models import Model
from keras.layers import Input, Embedding, Dense, LSTM

emb_dim = 256
hid_dim = 256

# Encoder: embed the source sentence and keep only the final LSTM states.
encoder_inputs = Input(shape=(seqX_len,))
encoder_embedded = Embedding(word_num_e, emb_dim, mask_zero=True)(encoder_inputs)
encoder = LSTM(hid_dim, return_state=True)
_, state_h, state_c = encoder(encoder_embedded)
encoder_states = [state_h, state_c]

# Decoder: initialised with the encoder states, trained with teacher forcing.
decoder_inputs = Input(shape=(seqY_len,))
decoder_embedding = Embedding(word_num_j, emb_dim)
decoder_embedded = decoder_embedding(decoder_inputs)
decoder = LSTM(hid_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder(decoder_embedded, initial_state=encoder_states)
decoder_dense = Dense(word_num_j, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.compile(optimizer='rmsprop', loss='sparse_categorical_crossentropy')

# Targets are the decoder inputs shifted left by one step, padded with zeros.
decoder_target_data = np.hstack((train_Y[:, 1:], np.zeros((len(train_Y), 1), dtype=np.int32)))
model.fit([train_X, train_Y], np.expand_dims(decoder_target_data, -1), batch_size=128, epochs=1, verbose=2, validation_split=0.2)
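
# (Optional, not part of the original gist.) Persist the trained model with
# Keras's standard save API so it can be reloaded later without retraining.
# The file name 'seq2seq_tanaka.h5' is only an illustrative choice.
model.save('seq2seq_tanaka.h5')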

# Inference-time encoder: maps a source sequence to its final LSTM states.
encoder_model = Model(encoder_inputs, encoder_states)

# Inference-time decoder: consumes one token plus the previous states and
# returns the next-token distribution together with the updated states.
decoder_state_input_h = Input(shape=(hid_dim,))  # receives encoder/previous state_h
decoder_state_input_c = Input(shape=(hid_dim,))  # receives encoder/previous state_c
decoder_state_inputs = [decoder_state_input_h, decoder_state_input_c]

decoder_inputs = Input(shape=(1,))  # a single target-side token per step
decoder_embedded = decoder_embedding(decoder_inputs)
decoder_outputs, state_h, state_c = decoder(
    decoder_embedded, initial_state=decoder_state_inputs
)
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_outputs)

decoder_model = Model(
    [decoder_inputs] + decoder_state_inputs,
    [decoder_outputs] + decoder_states
)

# Reverse lookup tables: integer id -> word.
detokenizer_e = dict(map(reversed, tokenizer_e.word_index.items()))
detokenizer_j = dict(map(reversed, tokenizer_j.word_index.items()))
detokenizer_j[0] = ''  # hack: the padding index is sometimes generated, so map it to an empty string

# Translate one held-out sentence and compare it with the reference.
input_seq = pad_sequences([test_X[0]], seqX_len, padding='post')
print(' '.join([detokenizer_e[i] for i in test_X[0]]))                   # source (English)
print(' '.join([detokenizer_j[i] for i in decode_sequence(input_seq)]))  # model output (Japanese)
print(' '.join([detokenizer_j[i] for i in test_Y[0]]))                   # reference (Japanese)
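
# (Illustrative extension, not in the original gist.) The same decoding loop can
# be applied to a few more held-out sentences to eyeball translation quality;
# this only reuses objects defined above.
for idx in range(1, 4):
    input_seq = pad_sequences([test_X[idx]], seqX_len, padding='post')
    print('src:', ' '.join([detokenizer_e[i] for i in test_X[idx]]))
    print('hyp:', ' '.join([detokenizer_j[i] for i in decode_sequence(input_seq)]))
    print('ref:', ' '.join([detokenizer_j[i] for i in test_Y[idx]]))
    print()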