-
-
Save Gforky/fba1ff298d37536abc54a5ae5195e167 to your computer and use it in GitHub Desktop.
Small Python script to rename variables in a TensorFlow checkpoint.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys, getopt | |
import tensorflow as tf | |
# Command-line synopsis; main() prints it on usage errors and for -h/--help.
usage_str = (
    'python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir/ '
    '--replace_from=substr --replace_to=substr --add_prefix=abc --dry_run'
)
def _new_variable_name(var_name, replace_from, replace_to, add_prefix):
    """Return the name *var_name* should be saved under (replace, then prefix)."""
    new_name = var_name
    # An empty replace_from would make str.replace insert replace_to between
    # every character, so it is treated as "no replacement". An empty
    # replace_to is allowed: it deletes the substring.
    if replace_from and replace_to is not None:
        new_name = new_name.replace(replace_from, replace_to)
    if add_prefix:
        new_name = add_prefix + new_name
    return new_name


def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run):
    """Rename every variable stored in the latest checkpoint of *checkpoint_dir*.

    Each variable name is rewritten by first replacing *replace_from* with
    *replace_to*, then prepending *add_prefix*. Unless *dry_run* is set, all
    variables are loaded and re-saved under their new names, overwriting the
    latest checkpoint (``model_checkpoint_path``) in place.

    Args:
        checkpoint_dir: Directory holding the checkpoint state file read by
            ``tf.train.get_checkpoint_state``.
        replace_from: Substring to replace in each name, or None. Replacement
            happens only when this is non-empty and *replace_to* is not None.
        replace_to: Replacement for *replace_from*, or None.
        add_prefix: String prepended to every name; falsy means no prefix.
        dry_run: If true, only print the planned renames and write nothing.

    Raises:
        ValueError: If no checkpoint is found in *checkpoint_dir*, or if two
            variables would be renamed to the same name.
    """
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    if checkpoint is None:
        # Fail fast instead of loading every variable and only then hitting an
        # AttributeError on checkpoint.model_checkpoint_path.
        raise ValueError('No checkpoint found in %s.' % checkpoint_dir)

    renames = [
        (var_name, _new_variable_name(var_name, replace_from, replace_to, add_prefix))
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir)
    ]

    # tf.Variable silently uniquifies duplicate names in a graph (e.g. "x" ->
    # "x_1"), so a collision would save names nobody asked for. Refuse instead.
    targets = {}
    for old_name, new_name in renames:
        if new_name in targets:
            raise ValueError('Both %s and %s would be renamed to %s.'
                             % (targets[new_name], old_name, new_name))
        targets[new_name] = old_name

    if dry_run:
        # A dry run only reports; no tensor values are loaded.
        for old_name, new_name in renames:
            print('%s would be renamed to %s.' % (old_name, new_name))
        return

    if not renames:
        # tf.train.Saver() raises when there are no variables to save.
        print('No variables found in %s.' % checkpoint_dir)
        return

    # Build in a fresh graph so that variables left in the default graph by
    # earlier code in this process are not saved into the checkpoint.
    with tf.Graph().as_default():
        for old_name, new_name in renames:
            print('Renaming %s to %s.' % (old_name, new_name))
            value = tf.contrib.framework.load_variable(checkpoint_dir, old_name)
            tf.Variable(value, name=new_name)
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # Overwrites the latest checkpoint in place, under the new names.
            saver.save(sess, checkpoint.model_checkpoint_path)
def main(argv):
    """Parse command-line options and run rename().

    Args:
        argv: Command-line arguments without the program name, i.e.
            ``sys.argv[1:]``.

    Exits with status 0 after printing usage for -h/--help, and with status 2
    on any usage error.
    """
    checkpoint_dir = None
    replace_from = None
    replace_to = None
    add_prefix = None
    dry_run = False
    try:
        # 'help' has no trailing '=': a bare --help must not require a value.
        opts, args = getopt.getopt(argv, 'h', ['help', 'checkpoint_dir=', 'replace_from=',
                                               'replace_to=', 'add_prefix=', 'dry_run'])
    except getopt.GetoptError as err:
        print(err)
        print(usage_str)
        sys.exit(2)
    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print(usage_str)
            sys.exit()
        elif opt == '--checkpoint_dir':
            checkpoint_dir = arg
        elif opt == '--replace_from':
            replace_from = arg
        elif opt == '--replace_to':
            replace_to = arg
        elif opt == '--add_prefix':
            add_prefix = arg
        elif opt == '--dry_run':
            dry_run = True
    if args:
        # getopt stops at the first positional argument, so any option after
        # it (e.g. a trailing --dry_run) would otherwise be silently dropped
        # and the checkpoint rewritten for real.
        print('Unexpected arguments: %s. Usage:' % ' '.join(args))
        print(usage_str)
        sys.exit(2)
    if not checkpoint_dir:
        print('Please specify a checkpoint_dir. Usage:')
        print(usage_str)
        sys.exit(2)
    if (replace_from is None) != (replace_to is None):
        # Only one side of the replacement was given; rename() would ignore it.
        print('--replace_from and --replace_to must be given together. Usage:')
        print(usage_str)
        sys.exit(2)
    rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run)
if __name__ == '__main__':
    # Script entry point: hand main() the command line minus the program name.
    main(argv=sys.argv[1:])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.