Skip to content

Instantly share code, notes, and snippets.

@MasazI
Last active August 24, 2016 07:57
Show Gist options
  • Save MasazI/fd3418c8866813d70835f574e1e60a35 to your computer and use it in GitHub Desktop.
TensorFlow `saver.restore` example: restoring a subset of pretrained variables from a checkpoint on Google Cloud Storage (the final restore call raises an exception).
def train():
    """Build a graph and restore the pretrained "scale1and2" variable
    subset from a checkpoint stored on Google Cloud Storage.

    Reads module-level config:
      - LOG_DEVICE_PLACEMENT: passed to tf.ConfigProto.
      - FLAGS.train_s12_dir: checkpoint directory, e.g.
        "gs://<projectid>/cnn_depth_labels/<pretrained jobid>/train12".

    NOTE(review): the original gist reports that the final
    ``saver.restore`` call raises an exception when the checkpoint path
    is a ``gs://`` URL — verify GCS access from the runtime environment.
    """
    with tf.Graph().as_default():
        # Initialization operation for all graph variables.
        # (tf.initialize_all_variables is the TF 0.x API spelling.)
        init_op = tf.initialize_all_variables()

        # Use the session as a context manager so it is closed even if
        # restore() raises (the original leaked the session on error).
        with tf.Session(config=tf.ConfigProto(
                log_device_placement=LOG_DEVICE_PLACEMENT)) as sess:
            sess.run(init_op)

            all_variables = tf.all_variables()
            # Filter variables by name down to the parameter subset that
            # should be restored from the pretrained checkpoint.
            scale1and2_params = []
            for variable in all_variables:
                # ... (name-based filtering elided in the original gist)
                scale1and2_params.append(variable)
            saver_scale1and2 = tf.train.Saver(scale1and2_params)

            # Load checkpoint state from Cloud Storage.
            # FLAGS.train_s12_dir is
            # "gs://<projectid>/cnn_depth_labels/<pretrained jobid>/train12"
            scale1and2_ckpt = tf.train.get_checkpoint_state(FLAGS.train_s12_dir)
            if scale1and2_ckpt and scale1and2_ckpt.model_checkpoint_path:
                # model_checkpoint_path is e.g.
                # "gs://<projectid>-ml/cnn_depth_labels/<pretrained jobid>/train12/model.ckpt-90"
                print("Pretrained scale1and2 Model Loading. %s" % (scale1and2_ckpt.model_checkpoint_path))
                saver_scale1and2.restore(sess, scale1and2_ckpt.model_checkpoint_path)
                # ->>> exception raised here per the original gist
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment