-
-
Save batzner/7c24802dd9c5e15870b4b56e22135c96 to your computer and use it in GitHub Desktop.
import sys, getopt | |
import tensorflow as tf | |
# One-line usage summary printed by main() on bad or missing options.
usage_str = ('python tensorflow_rename_variables.py '
             '--checkpoint_dir=path/to/dir/ '
             '--replace_from=substr --replace_to=substr '
             '--add_prefix=abc --dry_run')
def _renamed(var_name, replace_from, replace_to, add_prefix):
    """Return ``var_name`` after the optional substring replacement and prefix.

    The replacement is applied only when both ``replace_from`` and
    ``replace_to`` are given; the prefix is prepended only when truthy.
    """
    new_name = var_name
    if replace_from is not None and replace_to is not None:
        new_name = new_name.replace(replace_from, replace_to)
    if add_prefix:
        new_name = add_prefix + new_name
    return new_name


def rename(checkpoint_dir, replace_from=None, replace_to=None, add_prefix=None,
           dry_run=False):
    """Rename every variable stored in a TensorFlow 1.x checkpoint.

    Each variable listed in the checkpoint gets a new name. The name is built
    by replacing ``replace_from`` with ``replace_to`` and then prepending
    ``add_prefix``. Unless ``dry_run`` is set, all renamed variables are saved
    back to the ``model_checkpoint_path`` reported by
    ``tf.train.get_checkpoint_state``.

    Args:
        checkpoint_dir: Path passed to ``tf.train.get_checkpoint_state`` and
            to ``tf.contrib.framework.list_variables`` / ``load_variable``.
        replace_from: Substring to replace in each name. Ignored unless
            ``replace_to`` is also given.
        replace_to: Replacement for ``replace_from``.
        add_prefix: Optional string prepended to each (replaced) name.
        dry_run: If true, only print the planned renames. Nothing is loaded
            into a graph and nothing is saved.

    Raises:
        ValueError: If not ``dry_run`` and no checkpoint state is found for
            ``checkpoint_dir``, so there is no path to save to.
    """
    import numpy as np  # TensorFlow already depends on numpy.

    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    if not dry_run and checkpoint is None:
        # Fail before doing any work instead of crashing on
        # `None.model_checkpoint_path` at save time.
        raise ValueError('No checkpoint state found for %r; cannot save the '
                         'renamed variables.' % (checkpoint_dir,))

    # Use a fresh graph per call. Otherwise repeated calls in one process pile
    # up variables from earlier calls: TF uniquifies their names with "_1",
    # and the Saver writes those stale copies out too.
    with tf.Graph().as_default(), tf.Session() as sess:
        feed_dict = {}
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
            new_name = _renamed(var_name, replace_from, replace_to, add_prefix)
            if dry_run:
                print('%s would be renamed to %s.' % (var_name, new_name))
                continue
            print('Renaming %s to %s.' % (var_name, new_name))
            value = np.asarray(
                tf.contrib.framework.load_variable(checkpoint_dir, var_name))
            # Feed the value through a placeholder rather than embedding it as
            # a graph constant, so large checkpoints do not exceed the 2GB
            # GraphDef limit.
            placeholder = tf.placeholder(value.dtype, shape=value.shape)
            tf.Variable(placeholder, name=new_name)
            feed_dict[placeholder] = value

        if not dry_run:
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer(), feed_dict=feed_dict)
            saver.save(sess, checkpoint.model_checkpoint_path)
def main(argv):
    """Parse command-line options and call ``rename`` with them.

    Args:
        argv: Command-line arguments without the program name
            (for example ``sys.argv[1:]``).

    Prints ``usage_str`` and exits with status 0 for ``-h``/``--help``. Exits
    with status 2 on an unrecognised option or a missing ``--checkpoint_dir``.
    """
    checkpoint_dir = None
    replace_from = None
    replace_to = None
    add_prefix = None
    dry_run = False
    try:
        # '--help' is a flag: with a trailing '=' getopt would demand a value,
        # turning a bare '--help' into a GetoptError (exit status 2).
        opts, _ = getopt.getopt(argv, 'h', ['help', 'checkpoint_dir=',
                                            'replace_from=', 'replace_to=',
                                            'add_prefix=', 'dry_run'])
    except getopt.GetoptError:
        print(usage_str)
        sys.exit(2)
    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print(usage_str)
            sys.exit()
        elif opt == '--checkpoint_dir':
            checkpoint_dir = arg
        elif opt == '--replace_from':
            replace_from = arg
        elif opt == '--replace_to':
            replace_to = arg
        elif opt == '--add_prefix':
            add_prefix = arg
        elif opt == '--dry_run':
            dry_run = True
    if not checkpoint_dir:
        print('Please specify a checkpoint_dir. Usage:')
        print(usage_str)
        sys.exit(2)
    rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run)
if __name__ == '__main__':
    # Drop argv[0] (the script path); main() parses only the options.
    main(argv=sys.argv[1:])
Thanks a lot. It's exactly what I have been looking for for a long time!
Thank you for sharing the code. It's very helpful
Thanks a lot, great script!!
I improved it a little by adding a few other options: looking up a specific key (variable name) and comparing the variables in two checkpoints.
In case it can be of some help: https://gist.github.com/fvisin/578089ae098424590d3f25567b6ee255
thank you for your great helpful code!!!
Thank you, this code is so coooooooooooool :)
I got an error: ValueError: GraphDef cannot be larger than 2GB.
@batzner
Thanks for your cool code, but I got this error,
saver.save(sess, checkpoint.model_checkpoint_path) AttributeError: 'NoneType' object has no attribute 'model_checkpoint_path'
however, the conversion was actually done.
Why does this happen?
Thank you
I wrote a loop with the rename function:
for i in range(1,7): # expaned_conv_i
node_name = V_HEAD_EX + "_" + str(i + n_to_add) + "/"
to_node_name = V_UPPER_HEAD_EX + "_" + str(i) + "/"
rename(checkpointdir, node_name, to_node_name, dry_run=dry_run)
but the variables get duplicated with a "_1" suffix:
Renaming MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_mean to
MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_mean
Renaming MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_mean_1 to
MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_mean_1
Renaming MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_variance to
MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_variance
Renaming MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_variance_1 to
MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_variance_1
Does anyone know why this is happening?
Thanks @bver and @spacetrain, I changed it!