Created
May 10, 2018 06:11
-
-
Save hlzz/f0604e4bd2ce9cc6511048b1ba665788 to your computer and use it in GitHub Desktop.
TensorFlow: restore a model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def recover_from_pretrained(sess, net, ckpt_path=None, ckpt_step=-1, np_model_path=None,
                            exclude_patterns=('GeM',)):
    """
    Recover network weights from a pretrained TF checkpoint or a numpy parameter file.

    The checkpoint branch runs when both ``ckpt_path`` and ``ckpt_step`` are given.
    It restores ``<ckpt_path>-<ckpt_step>`` into ``sess``, skipping every global
    variable whose name contains one of ``exclude_patterns``. This lets a graph that
    gained new variables (e.g. a 'GeM' layer) be initialised from a checkpoint written
    before those variables existed. Skipped variables are not touched by this function.
    If the checkpoint index file is missing, a warning is printed and 0 is returned;
    the numpy file is NOT tried in that case.

    Otherwise, if ``np_model_path`` is non-empty, parameters are loaded with
    ``net.load(np_model_path, sess, ignore_missing=True)``.

    :param sess: the current session
    :param net: the feature tower; must provide ``load(path, sess, ignore_missing=...)``
    :param ckpt_path: checkpoint prefix without the step suffix, or None to skip
    :param ckpt_step: checkpoint step number, or None to skip the checkpoint branch
    :param np_model_path: path to a numpy model file, or None / '' to skip
    :param exclude_patterns: substring (or iterable of substrings); global variables
        whose name contains any of them are not restored from the checkpoint
    :return: ``ckpt_step`` when the checkpoint was restored, otherwise 0
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(exclude_patterns, str):
        exclude_patterns = (exclude_patterns,)
    exclude_patterns = tuple(exclude_patterns or ())

    if ckpt_path is not None and ckpt_step is not None:  # Finetune from ckpt if necessary.
        # NOTE(review): with the default ckpt_step=-1 the prefix becomes '<ckpt_path>--1';
        # confirm whether -1 was intended to mean "latest checkpoint".
        ckpt_name = '-'.join([ckpt_path, str(ckpt_step)])
        index_file = ckpt_name + '.index'
        if os.path.exists(index_file):
            # Leave out variables matching exclude_patterns: they may not exist in a
            # checkpoint saved before they were added to the graph.
            restore_variables = [
                v for v in tf.global_variables()
                if not any(pattern in v.name for pattern in exclude_patterns)
            ]
            restorer = tf.train.Saver(restore_variables)
            restorer.restore(sess, ckpt_name)
            print(Notify.INFO, 'Pre-trained model restored from', ckpt_name, Notify.ENDC)
            return ckpt_step
        print(Notify.WARNING,
              'ckpt file %s does not exist, begin training from scratch' % index_file,
              Notify.ENDC)
    elif np_model_path:  # Finetune from a numpy file.
        print(Notify.INFO, 'Recover from pre-trained model', np_model_path, Notify.ENDC)
        net.load(np_model_path, sess, ignore_missing=True)
    else:
        print(Notify.WARNING, 'Train from scratch.', Notify.ENDC)
    return 0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Use this to restore from a previously saved model after a new variable named 'GeM' has been added to the graph: the 'GeM' variables are excluded from the restore.