-
-
Save protoget/9b45881f23c96e201a90581c8f4b692d to your computer and use it in GitHub Desktop.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

# Input tensor shape; presumably [max_time, batch_size, input_size] since the
# RNN below is built with time_major=True -- TODO confirm.
shape = [2, 2, 2]
# Number of hidden units per LSTM direction.
n_cell_dim = 2
def init_vars(sess):
    """Run the global-variable initializer in the given session.

    Args:
      sess: an open tf.Session whose graph contains the variables to init.
    """
    sess.run(tf.global_variables_initializer())
def train_graph():
    """Build a bidirectional CudnnLSTM on GPU and save a checkpoint.

    Writes the checkpoint to /tmp/model. CudnnLSTM keeps its weights in a
    single opaque GPU buffer; tf.train.Saver uses the saveable registered by
    the layer to write canonical (per-gate) variables into the checkpoint.
    """
    with tf.Graph().as_default(), tf.device('/gpu:0'):
        with tf.Session() as sess:
            is_training = True
            inputs = tf.random_uniform(shape, dtype=tf.float32)
            lstm = tf.contrib.cudnn_rnn.CudnnLSTM(
                num_layers=1,
                num_units=n_cell_dim,
                direction='bidirectional',
                dtype=tf.float32)
            # Build the layer explicitly so its variables exist before the
            # Saver is constructed below.
            lstm.build(inputs.get_shape())
            outputs, output_states = lstm(inputs, training=is_training)
            # NOTE(review): the original paste lost its indentation; it is
            # assumed only the Saver construction sits under /cpu:0 -- confirm
            # against the original gist.
            with tf.device('/cpu:0'):
                saver = tf.train.Saver()
            init_vars(sess)
            saver.save(sess, '/tmp/model')
def inf_graph():
    """Rebuild the LSTM with CudnnCompatible cells on CPU and restore it.

    FIX: the checkpoint written by CudnnLSTM stores its canonical variables
    under the 'cudnn_lstm' name scope. The CudnnCompatibleLSTMCell graph must
    therefore be constructed inside tf.variable_scope('cudnn_lstm') so the
    variable names match the checkpoint keys; without it Saver.restore fails
    with "Key stack_bidirectional_rnn/... not found in checkpoint".
    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        with tf.Session() as sess:
            # Scope must match the name under which CudnnLSTM saved its
            # canonical weights.
            with tf.variable_scope('cudnn_lstm'):
                single_cell = lambda: tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(
                    n_cell_dim, reuse=tf.get_variable_scope().reuse)
                inputs = tf.random_uniform(shape, dtype=tf.float32)
                lstm_fw_cell = [single_cell() for _ in range(1)]
                lstm_bw_cell = [single_cell() for _ in range(1)]
                (outputs, output_state_fw,
                 output_state_bw) = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
                     lstm_fw_cell,
                     lstm_bw_cell,
                     inputs,
                     dtype=tf.float32,
                     time_major=True)
            saver = tf.train.Saver()
            saver.restore(sess, '/tmp/model')
            print(sess.run(outputs))
def main(unused_argv):
    """Train and checkpoint on GPU, then rebuild and restore on CPU."""
    train_graph()
    inf_graph()


if __name__ == '__main__':
    tf.app.run(main)
I get the following error in the inf_graph() part when I try to run the above
NotFoundError (see above for traceback): Key stack_bidirectional_rnn/cell_0/bidirectional_rnn/bw/lstm_cell/kernel not found in checkpoint
Anything I might be missing?
@dapurv5 maybe it's a scope problem? I suggest you add a `cudnn_lstm` variable scope before you restore the checkpoint.
Hello,
How can I restore a cudnn_LSTM to a GPU device? I can't find any solution on Google.
Besides, I need to restore the meta graph to get my placeholder for inference, but once I run this statement:
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint)),
I get this error:
KeyError: "The name 'cudnn_lstm/opaque_kernel_saveable' refers to an Operation not in the graph."
Can you help me ? THANKS
Hello @SysuJayce, have you solved this problem? I also get this error:
KeyError: "The name 'cudnn_lstm/opaque_kernel_saveable' refers to an Operation not in the graph."
Can you help me? Thanks.
Any update on this issue? Really stuck at this point.
https://github.com/yjchoe/TFCudnnLSTM — in this repo you can use the training saver and the checkpoint from training, but you still can't restore the .meta graph.
I get the following error in the inf_graph() part when I try to run the above
NotFoundError (see above for traceback): Key stack_bidirectional_rnn/cell_0/bidirectional_rnn/bw/lstm_cell/kernel not found in checkpoint
Anything I might be missing?