Created
February 6, 2019 05:40
-
-
Save HarshSingh16/6cf763befce2c65f40a1dfcb8cc35530 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Defining the decoder layers RNN
def decoder_rnn(decoder_embedded_input, decoder_embeddings_matrix, encoder_state,
                num_words, sequence_length, rnn_size, num_layers, word2int,
                keep_prob, batch_size):
    """Build the stacked-LSTM decoder and return its training and test outputs.

    Args:
        decoder_embedded_input: Embedded decoder inputs, forwarded to
            ``decode_training_set``.
        decoder_embeddings_matrix: Decoder embedding matrix, forwarded to
            ``decode_test_set``.
        encoder_state: Encoder state, forwarded to both decoding helpers.
        num_words: Number of output units of the final projection layer
            (presumably the answer-vocabulary size — confirm with caller).
        sequence_length: Forwarded as-is to ``decode_training_set``; the test
            path receives ``sequence_length - 1``.
        rnn_size: Number of units in each LSTM layer.
        num_layers: Number of stacked LSTM layers.
        word2int: Token-to-id mapping; must contain ``"<SOS>"`` and ``"<EOS>"``.
        keep_prob: Input keep probability of the dropout wrapper on each layer.
        batch_size: Forwarded to both decoding helpers.

    Returns:
        ``(training_predictions, test_predictions)`` as returned by
        ``decode_training_set`` and ``decode_test_set`` respectively.
    """
    with tf.variable_scope("decoding") as decoding_scope:
        # Build a fresh LSTM + dropout cell for every layer. Repeating a single
        # cell object ([cell] * num_layers) makes every layer the same cell
        # instance, so layers cannot hold independent weights; TF 1.x either
        # ties the weights or raises an error in that case.
        decoder_cell = tf.contrib.rnn.MultiRNNCell([
            tf.contrib.rnn.DropoutWrapper(
                tf.contrib.rnn.BasicLSTMCell(rnn_size),
                input_keep_prob=keep_prob)
            for _ in range(num_layers)
        ])
        weights = tf.truncated_normal_initializer(stddev=0.1)
        biases = tf.zeros_initializer()

        def output_function(x):
            # Linear projection (no activation) to num_words outputs, created in
            # decoding_scope so that, after reuse_variables() below, the test
            # path uses the same projection variables as the training path.
            return tf.contrib.layers.fully_connected(
                x,
                num_words,
                activation_fn=None,
                scope=decoding_scope,
                weights_initializer=weights,
                biases_initializer=biases)

        training_predictions = decode_training_set(encoder_state,
                                                   decoder_cell,
                                                   decoder_embedded_input,
                                                   sequence_length,
                                                   decoding_scope,
                                                   output_function,
                                                   keep_prob,
                                                   batch_size)
        # The test graph must reuse the variables created by the training graph
        # rather than create a second set.
        decoding_scope.reuse_variables()
        test_predictions = decode_test_set(encoder_state,
                                           decoder_cell,
                                           decoder_embeddings_matrix,
                                           word2int["<SOS>"],
                                           word2int["<EOS>"],
                                           sequence_length - 1,
                                           num_words,
                                           decoding_scope,
                                           output_function,
                                           keep_prob,
                                           batch_size)
    return training_predictions, test_predictions
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment