
@hlzz
Created May 10, 2018 06:11
TensorFlow restore model
import os
import tensorflow as tf
# Notify (ANSI color tags used for logging below) is defined elsewhere in the author's codebase.


def recover_from_pretrained(sess, net, ckpt_path=None, ckpt_step=-1, np_model_path=None):
    """
    Recover from a pretrained ckpt graph or numpy net parameters.
    :param sess: the current session
    :param net: the feature tower
    :param ckpt_path: the ckpt root path (if there is any)
    :param ckpt_step: the ckpt step number (if there is any)
    :param np_model_path: the numpy model path (if there is any)
    :return: the correct step index for the current run
    """
    if ckpt_step is not None:  # Finetune from ckpt if necessary.
        ckpt_name = '-'.join([ckpt_path, str(ckpt_step)])
        if os.path.exists(ckpt_name + '.index'):
            # Restore every variable except the newly added 'GeM' variables,
            # which do not exist in the old checkpoint.
            restore_variable = []
            for i in tf.global_variables():
                if i.name.find('GeM') < 0:
                    restore_variable.append(i)
            restorer = tf.train.Saver(restore_variable)
            # restorer = tf.train.Saver(tf.global_variables())
            restorer.restore(sess, ckpt_name)
            print(Notify.INFO, 'Pre-trained model restored from', ckpt_name, Notify.ENDC)
            return ckpt_step
        else:
            print(Notify.WARNING, 'ckpt file %s does not exist, begin training from scratch' %
                  (ckpt_name + '.index'), Notify.ENDC)
    elif np_model_path != '':  # Finetune from a numpy file.
        print(Notify.INFO, 'Recover from pre-trained model', np_model_path, Notify.ENDC)
        net.load(np_model_path, sess, ignore_missing=True)
    else:
        print(Notify.WARNING, 'Train from scratch.', Notify.ENDC)
    return 0
hlzz commented May 10, 2018

When a new variable called 'GeM' is added, this restores the existing variables from a previous checkpoint while skipping the new GeM variables, which the old checkpoint does not contain.
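
A minimal usage sketch (TF1-style session API). The checkpoint prefix, step number, and the `net` object are hypothetical placeholders; `net` is assumed to expose a Caffe-style load(path, session, ignore_missing=...) method as used in the function above.

# Hypothetical example: resume training after adding a GeM layer.
import tensorflow as tf

# net = ...  # the feature tower, built beforehand; assumed to provide .load()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    start_step = recover_from_pretrained(
        sess, net,
        ckpt_path='/path/to/model.ckpt',   # looks for '/path/to/model.ckpt-120000.index'
        ckpt_step=120000,
        np_model_path='')                  # empty string -> no numpy fallback
    # Resume the training loop from start_step; the GeM variables keep their
    # fresh initialization while all other weights come from the checkpoint.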
