-
-
Save batzner/7c24802dd9c5e15870b4b56e22135c96 to your computer and use it in GitHub Desktop.
| import sys, getopt | |
| import tensorflow as tf | |
# Command-line synopsis printed for -h/--help and on invalid usage.
# Only --checkpoint_dir is required; --replace_from and --replace_to only take
# effect when both are supplied, so they are shown as a single optional group.
usage_str = ('python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir/ '
             '[--replace_from=substr --replace_to=substr] [--add_prefix=abc] [--dry_run]')
def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run):
    """Rename the variables stored in a TensorFlow 1.x checkpoint.

    Each stored variable name is transformed by (1) replacing every
    occurrence of ``replace_from`` with ``replace_to`` -- only when both are
    not None -- and then (2) prepending ``add_prefix`` when it is non-empty.
    When not a dry run, every variable (renamed or not) is re-created under
    its new name and the whole set is saved back, so no variable is dropped.

    Args:
        checkpoint_dir: Directory containing a ``checkpoint`` state file, or
            the path/prefix of a single checkpoint.
        replace_from: Substring to replace in each name, or None.
        replace_to: Replacement substring, or None.
        add_prefix: Prefix to prepend to each name, or None/empty for none.
        dry_run: If True, only print the planned renames; nothing is written.
    """
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    # Build into a private graph: with the default graph, repeated calls in the
    # same process accumulate variables, so TensorFlow uniquifies new names with
    # a "_1" suffix and the Saver also writes the stale variables.
    with tf.Graph().as_default(), tf.Session() as sess:
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
            # Load the stored value of the variable.
            value = tf.contrib.framework.load_variable(checkpoint_dir, var_name)

            # Compute the new name.
            new_name = var_name
            if replace_from is not None and replace_to is not None:
                new_name = new_name.replace(replace_from, replace_to)
            if add_prefix:
                new_name = add_prefix + new_name

            if dry_run:
                print('%s would be renamed to %s.' % (var_name, new_name))
            else:
                print('Renaming %s to %s.' % (var_name, new_name))
                # Re-create the variable under its new name in the private
                # graph; the Saver below picks up every variable created here.
                # NOTE(review): the value is embedded as a graph constant, so
                # very large checkpoints can hit the 2 GB GraphDef limit, and two
                # names mapping to the same target are silently uniquified by TF.
                tf.Variable(value, name=new_name)

        if not dry_run:
            if checkpoint is not None:
                save_path = checkpoint.model_checkpoint_path
            else:
                # get_checkpoint_state() returns None when there is no
                # "checkpoint" state file at checkpoint_dir. Because
                # list_variables() succeeded above, checkpoint_dir is presumably
                # a checkpoint prefix itself, so overwrite that checkpoint.
                save_path = checkpoint_dir
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            saver.save(sess, save_path)
def main(argv):
    """Parse command-line options and run ``rename``.

    Args:
        argv: Command-line arguments without the program name
            (i.e. ``sys.argv[1:]``).

    Prints the usage string and exits with status 0 for -h/--help, and with
    status 2 on any invalid usage.
    """
    checkpoint_dir = None
    replace_from = None
    replace_to = None
    add_prefix = None
    dry_run = False
    try:
        # 'help' takes no value, so a bare --help is accepted like -h.
        opts, args = getopt.getopt(argv, 'h', ['help', 'checkpoint_dir=', 'replace_from=',
                                               'replace_to=', 'add_prefix=', 'dry_run'])
    except getopt.GetoptError:
        print(usage_str)
        sys.exit(2)
    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print(usage_str)
            sys.exit()
        elif opt == '--checkpoint_dir':
            checkpoint_dir = arg
        elif opt == '--replace_from':
            replace_from = arg
        elif opt == '--replace_to':
            replace_to = arg
        elif opt == '--add_prefix':
            add_prefix = arg
        elif opt == '--dry_run':
            dry_run = True
    if args:
        # getopt stops at the first positional argument, so any options after
        # it (e.g. --dry_run) would be silently ignored; refuse to continue.
        print('Unexpected arguments: %s. Usage:' % ' '.join(args))
        print(usage_str)
        sys.exit(2)
    if not checkpoint_dir:
        print('Please specify a checkpoint_dir. Usage:')
        print(usage_str)
        sys.exit(2)
    if (replace_from is None) != (replace_to is None):
        # rename() only replaces when both are given; one alone would be a no-op.
        print('Please specify both replace_from and replace_to. Usage:')
        print(usage_str)
        sys.exit(2)
    rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run)
if __name__ == '__main__':
    # Hand every argument after the script name to the CLI entry point.
    main(argv=sys.argv[1:])
Thank you for this.
Thanks for this!
One thing, though: this script will not save a variable if its name does not contain "replace_from",
so I modified lines 23–24 as follows:
if new_name == var_name:
print('%s remains unchanged' % var_name)
var = tf.Variable(var, name=new_name)
continue
Cheers
P.S. Ah, I just noticed that @bver already commented on this. =)
Thanks @bver and @spacetrain, I changed it!
Thanks a lot. It's exactly what I have been looking for for a long time!
Thank you for sharing the code. It's very helpful
Thanks a lot, great script!!
I improved it a little adding a few other options to look for a specific key (variable name) and to compare the variables in two checkpoints.
In case it can be of some help: https://gist.github.com/fvisin/578089ae098424590d3f25567b6ee255
thank you for your great helpful code!!!
Thank you, this code is so coooooooooooool :)
I got an Error: ValueError: GraphDef cannot be larger than 2GB.
@batzner
Thanks for your cool code, but I got this error,
saver.save(sess, checkpoint.model_checkpoint_path) AttributeError: 'NoneType' object has no attribute 'model_checkpoint_path'
However, the conversion was actually done.
Why is that?
Thank you
I wrote a loop with the rename function:
for i in range(1,7): # expanded_conv_i
node_name = V_HEAD_EX + "_" + str(i + n_to_add) + "/"
to_node_name = V_UPPER_HEAD_EX + "_" + str(i) + "/"
rename(checkpointdir, node_name, to_node_name, dry_run=dry_run)
but the variables are duplicated with a "_1" suffix:
Renaming MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_mean to
MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_mean
Renaming MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_mean_1 to
MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_mean_1
Renaming MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_variance to
MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_variance
Renaming MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_variance_1 to
MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_variance_1
Does anyone know why this is happening?
Thank you very much! I was able to reuse my models after upgrading to the recent TF version.
However, I had to comment out lines 23 and 24, since other variables whose names did not match --replace_from were not saved back to the checkpoint file. It looks like checkpoints were modified in place in previous TF versions, but this logic has changed.
P.