from keras.models import Model
from keras.layers import Input, Embedding, BatchNormalization, GRU, Dense


def build_seq2seq_model(word_emb_dim,
                        hidden_state_dim,
                        encoder_seq_len,
                        num_encoder_tokens,
                        num_decoder_tokens):
    """
    Builds the architecture for a sequence-to-sequence model.

    The encoder and decoder each consist of a single GRU layer. The user
    can specify the dimensionality of the word embeddings and the hidden state.

    Parameters
    ----------
    word_emb_dim : int
        dimensionality of the word embeddings
    hidden_state_dim : int
        dimensionality of the hidden state in the encoder and decoder
    encoder_seq_len : int
        length of the sequences that are input into the encoder; the
        sequences are expected to all be padded to the same length
    num_encoder_tokens : int
        vocabulary size of the corpus relevant to the encoder
    num_decoder_tokens : int
        vocabulary size of the corpus relevant to the decoder

    Returns
    -------
    keras.models.Model
    """
    #### Encoder Model ####
    encoder_inputs = Input(shape=(encoder_seq_len,), name='Encoder-Input')

    # Word embedding for the encoder (ex: Issue Titles, Code)
    x = Embedding(num_encoder_tokens, word_emb_dim, name='Body-Word-Embedding', mask_zero=False)(encoder_inputs)
    x = BatchNormalization(name='Encoder-Batchnorm-1')(x)

    # We do not need the encoder outputs, just the final hidden state.
    _, state_h = GRU(hidden_state_dim, return_state=True, name='Encoder-Last-GRU', dropout=.5)(x)

    # Encapsulate the encoder as a separate model so we can
    # encode without decoding if we want to.
    encoder_model = Model(inputs=encoder_inputs, outputs=state_h, name='Encoder-Model')
    seq2seq_encoder_out = encoder_model(encoder_inputs)

    #### Decoder Model ####
    decoder_inputs = Input(shape=(None,), name='Decoder-Input')  # for teacher forcing

    # Word embedding for the decoder (ex: Issue Titles, Docstrings)
    dec_emb = Embedding(num_decoder_tokens, word_emb_dim, name='Decoder-Word-Embedding', mask_zero=False)(decoder_inputs)
    dec_bn = BatchNormalization(name='Decoder-Batchnorm-1')(dec_emb)

    # Set up the decoder, using the encoder's final hidden state as its initial state.
    decoder_gru = GRU(hidden_state_dim, return_state=True, return_sequences=True, name='Decoder-GRU', dropout=.5)
    decoder_gru_output, _ = decoder_gru(dec_bn, initial_state=seq2seq_encoder_out)
    x = BatchNormalization(name='Decoder-Batchnorm-2')(decoder_gru_output)

    # Dense layer predicting a distribution over the decoder vocabulary at each timestep
    decoder_dense = Dense(num_decoder_tokens, activation='softmax', name='Final-Output-Dense')
    decoder_outputs = decoder_dense(x)

    #### Seq2Seq Model ####
    seq2seq_Model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

    return seq2seq_Model
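

# Usage sketch (illustrative only, not part of the original gist): shows how the
# builder might be called and compiled. The hyperparameter values, optimizer, and
# loss below are hypothetical placeholders, not values taken from the original project.
if __name__ == '__main__':
    seq2seq_Model = build_seq2seq_model(word_emb_dim=300,
                                        hidden_state_dim=768,
                                        encoder_seq_len=70,
                                        num_encoder_tokens=20000,
                                        num_decoder_tokens=14000)
    # With a softmax output and integer-encoded targets, sparse categorical
    # cross-entropy is a reasonable training loss for this architecture.
    seq2seq_Model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
    seq2seq_Model.summary()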