Last active: August 24, 2016 07:57
-
-
Save MasazI/fd3418c8866813d70835f574e1e60a35 to your computer and use it in GitHub Desktop.
TensorFlow `saver.restore`
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train():
    """Restore the pretrained scale1and2 parameters from a checkpoint.

    Builds a fresh default graph, initializes all variables, and selects the
    subset of variables belonging to the pretrained scale1and2 network. It then
    restores that subset from the latest checkpoint listed in
    ``FLAGS.train_s12_dir``. If no checkpoint is found there, nothing is
    restored.

    NOTE(review): the code that builds the model (and thus creates variables)
    is not shown in this snippet. ``tf.initialize_all_variables()`` only
    covers variables that already exist in the graph when it is called, so the
    model must be built before that line. Confirm against the full script.
    """
    import posixpath  # gs:// URIs always use '/' separators, on any host OS.

    with tf.Graph().as_default():
        # Initialization op.
        init_op = tf.initialize_all_variables()
        # Session
        sess = tf.Session(config=tf.ConfigProto(log_device_placement=LOG_DEVICE_PLACEMENT))
        sess.run(init_op)

        all_variables = tf.all_variables()
        # Filter variables by name, keeping only the parameters provided by
        # the pretrained scale1and2 checkpoint.
        scale1and2_params = []
        for variable in all_variables:
            # ... (name-based filter elided in the original snippet)
            scale1and2_params.append(variable)
        saver_scale1and2 = tf.train.Saver(scale1and2_params)

        # Load the checkpoint from Cloud Storage.
        # FLAGS.train_s12_dir is "gs://<projectid>/cnn_depth_labels/<pretrained jobid>/train12"
        scale1and2_ckpt = tf.train.get_checkpoint_state(FLAGS.train_s12_dir)
        if scale1and2_ckpt and scale1and2_ckpt.model_checkpoint_path:
            # model_checkpoint_path is the path recorded when the checkpoint was
            # *written*, here "gs://<projectid>-ml/cnn_depth_labels/<pretrained jobid>/train12/model.ckpt-90".
            # That does not match the directory we are reading from (see
            # above), and restoring from the recorded path raised an exception.
            # Resolve the checkpoint file name against the directory we
            # actually searched instead.
            ckpt_name = posixpath.basename(scale1and2_ckpt.model_checkpoint_path)
            ckpt_path = posixpath.join(FLAGS.train_s12_dir, ckpt_name)
            print("Pretrained scale1and2 Model Loading. %s" % (ckpt_path))
            saver_scale1and2.restore(sess, ckpt_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.