Skip to content

Instantly share code, notes, and snippets.

@stas00
Created October 15, 2021 03:16
Show Gist options
  • Save stas00/4cd1651d1c8f01196ea322c733bde46c to your computer and use it in GitHub Desktop.
Save stas00/4cd1651d1c8f01196ea322c733bde46c to your computer and use it in GitHub Desktop.
tensorboard rename event tags (based on https://stackoverflow.com/a/60080531/9201239)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# this script renames event names in tensorboard log files
# it does the rename in place (so make backups!)
#
# example:
#
# find . -name "*.tfevents*" -exec tb-rename-events.py {} "iteration-time" "iteration-time/iteration-time" \;
#
# more than one old tag can be remapped to one new tag - use `;` as a separator:
#
# tb-rename-events.py events.out.tfevents.1 "training loss;validation loss" "loss"
#
# this script is derived from https://stackoverflow.com/a/60080531/9201239
#
import sys
from pathlib import Path
import os
# Use this if you want to avoid using the GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import tensorflow as tf
from tensorflow.core.util.event_pb2 import Event
def rename_events(input_file, old_tags, new_tag):
    """Rewrite a TensorBoard event file in place, renaming summary tags.

    Every record of ``input_file`` is parsed as an ``Event``, any summary
    value whose tag is found in ``old_tags`` gets ``new_tag`` instead, and
    the event is re-serialized into ``<input_file>.new``. Once all records
    have been written, the new file replaces the original.

    Args:
        input_file: Path of the event file to rewrite (str or path-like).
        old_tags: Container of tag names to be replaced (tested with ``in``).
        new_tag: Tag name written in place of any matching old tag.

    Raises:
        Whatever TensorFlow / the OS raises while reading or writing; in that
        case the temporary ``.new`` file is removed and the original file is
        left untouched.
    """
    # Accept pathlib.Path as well as str; the temp name is built by
    # string concatenation below.
    input_file = os.fspath(input_file)
    new_file = input_file + ".new"
    try:
        # Make a record writer
        with tf.io.TFRecordWriter(new_file) as writer:
            # Iterate event records
            for rec in tf.data.TFRecordDataset([input_file]):
                # Read event
                ev = Event()
                ev.MergeFromString(rec.numpy())
                # `ev.summary` is a sub-message accessor rather than a
                # presence flag, so ask explicitly whether it is set.
                if ev.HasField("summary"):
                    for v in ev.summary.value:
                        # Rename only the tags that were asked for
                        if v.tag in old_tags:
                            v.tag = new_tag
                # Every event is copied over, renamed or not
                writer.write(ev.SerializeToString())
    except BaseException:
        # Do not leave a half-written temp file next to the real log.
        try:
            os.remove(new_file)
        except OSError:
            pass
        raise
    # os.replace overwrites an existing destination on every platform;
    # os.rename raises on Windows when the destination already exists.
    os.replace(new_file, input_file)
def rename_events_dir(input_file, old_tags, new_tag):
    """Rename tags in a single event file by delegating to ``rename_events``.

    Despite the ``_dir`` suffix, the first argument is forwarded unchanged
    as one input file; no directory traversal happens here.
    """
    rename_events(input_file, old_tags, new_tag)
if __name__ == '__main__':
    # Exactly three arguments are expected after the script name:
    # the event file, the `;`-separated old tags, and the new tag.
    if len(sys.argv) == 4:
        input_file, old_tags, new_tag = sys.argv[1:]
    else:
        usage = f'{sys.argv[0]} <input file> <old tags> <new tag>'
        print(usage, file=sys.stderr)
        sys.exit(1)

    # Several old tags may be folded into the one new tag.
    old_tags = old_tags.split(';')
    rename_events_dir(input_file, old_tags, new_tag)
    print('Done')
@boxiXia
Copy link

boxiXia commented Apr 18, 2022

Thanks for the gist! I made a minor change to support renaming multiple sets of names, across multiple directories:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# this script renames event names in tensorboard log files
# it does the rename in place (so make backups!)
#
# example:
#
# find . -name "*.tfevents*" -exec python tb-rename-events.py {} --set old/name/1=new/name/1 old/name/2=new/name/2 \;
#
# python tb-rename-events.py "path/to/tensorboard/logs_1" "path/to/tensorboard/logs_2" ... \
#     --ext tb_extension_name_1  tb_extension_name_2 ... \
#     --set old/name/1=new/name/1 old/name/2=new/name/2 ...
#
# this script is derived from https://stackoverflow.com/a/60080531/9201239
# and from https://gist.github.com/stas00/4cd1651d1c8f01196ea322c733bde46c
#

import ray
from itertools import chain
from tensorflow.core.util.event_pb2 import Event
import tensorflow as tf
import os
import argparse
import glob
# from joblib import Parallel, delayed
# Hide all GPUs from TensorFlow; rewriting event files does not need one.
# Alternative (must be set before tensorflow is imported):
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
tf.config.set_visible_devices([], 'GPU')


@ray.remote
def rename_events(input_file, tag_mappings):
    """Rewrite a TensorBoard event file in place, renaming summary tags.

    Every record of ``input_file`` is parsed as an ``Event``; any summary
    value whose tag is a key of ``tag_mappings`` is given the mapped name,
    and the event is re-serialized into ``<input_file>.new``. Once all
    records have been written, the new file replaces the original.

    Args:
        input_file: Path of the event file to rewrite (str or path-like).
        tag_mappings: Mapping of old tag name -> new tag name.

    Raises:
        Whatever TensorFlow / the OS raises while reading or writing; in that
        case the temporary ``.new`` file is removed and the original file is
        left untouched.
    """
    print(f"rename: {input_file}")
    # Accept pathlib.Path as well as str; the temp name is built by
    # string concatenation below.
    input_file = os.fspath(input_file)
    new_file = input_file + ".new"
    try:
        # Make a record writer
        with tf.io.TFRecordWriter(new_file) as writer:
            # Iterate event records
            for rec in tf.data.TFRecordDataset([input_file]):
                # Read event
                ev = Event()
                ev.MergeFromString(rec.numpy())
                # `ev.summary` is a sub-message accessor rather than a
                # presence flag, so ask explicitly whether it is set.
                if ev.HasField("summary"):
                    for v in ev.summary.value:
                        # Rename only the tags that have a mapping
                        if v.tag in tag_mappings:
                            v.tag = tag_mappings[v.tag]
                # Every event is copied over, renamed or not
                writer.write(ev.SerializeToString())
    except BaseException:
        # Do not leave a half-written temp file next to the real log.
        try:
            os.remove(new_file)
        except OSError:
            pass
        raise
    # os.replace overwrites an existing destination on every platform;
    # os.rename raises on Windows when the destination already exists.
    os.replace(new_file, input_file)


# https://stackoverflow.com/questions/27146262/create-variable-key-value-pairs-with-argparse-python
# https://gist.github.com/fralau/061a4f6c13251367ef1d9a9a99fb3e8d
class ParseDict(argparse.Action):
    """argparse action that collects ``KEY=VALUE`` tokens into a dict.

    Each token is split on its first ``=``; whitespace around the key is
    stripped, the value is kept verbatim (so it may itself contain ``=``).
    Repeated uses of the option accumulate into the same dict, later keys
    overriding earlier ones.

    A token without ``=`` is reported through ``argparse.ArgumentError``,
    which ``parse_args`` turns into the usual usage message and exit code 2.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # Copy so that a dict supplied as `default=` is never mutated in place.
        d = dict(getattr(namespace, self.dest, None) or {})

        # With nargs=None argparse hands over one bare string; treat it as a
        # single token instead of iterating over its characters.
        if isinstance(values, str):
            values = [values]

        for item in values or ():
            key, sep, value = item.partition("=")
            if not sep:
                raise argparse.ArgumentError(
                    self, f"expected KEY=VALUE, got {item!r}")
            # we remove blanks around keys, as is logical
            d[key.strip()] = value

        setattr(namespace, self.dest, d)


def collect_event_files(paths, exts):
    """Return the files to rewrite, in discovery order and without duplicates.

    Args:
        paths: Iterable of file or directory paths. A path that is an
            existing file is taken as-is; anything else is searched
            recursively for ``*.<ext>``.
        exts: Extension patterns, with or without a leading dot (``"*"``
            matches any extension). A single string is treated as one
            pattern.

    Returns:
        List of file paths. Directories that happen to match a pattern are
        skipped, and a file matched by several patterns is listed once so it
        is never rewritten by two workers at the same time.
    """
    if isinstance(exts, str):
        exts = [exts]
    found = {}  # dict used as an insertion-ordered set
    for path in paths:
        if os.path.isfile(path):
            found[path] = None
            continue
        for ext in exts:
            pattern = os.path.join(path, "**", f"*.{ext.lstrip('.')}")
            for candidate in glob.glob(pattern, recursive=True):
                if os.path.isfile(candidate):
                    found[candidate] = None
    return list(found)


if __name__ == '__main__':
    # example usage:
    #
    #   python tb-rename-events.py "path/to/tensorboard/logs_1" "path/to/tensorboard/logs_2" ... \
    #       --ext tb_extension_name_1  tb_extension_name_2 ... \
    #       --set old/name/1=new/name/1 old/name/2=new/name/2 ...
    parser = argparse.ArgumentParser(
        description="Rename event tags in TensorBoard log files, in place.")
    parser.add_argument('path', metavar='PATH', type=str, nargs='+',
                        help='tensorboard event file, or directory to search '
                             'recursively for event files')
    parser.add_argument(
        "--ext",
        metavar=".EXT",
        type=str,
        nargs='+',
        default=["*"],
        help="extension(s) of the files to search for, with or without the "
        "leading dot; for example, to include *.abc and *.def pass "
        "--ext abc def (default: any extension)"
    )
    parser.add_argument(
        "--set",
        metavar="KEY=VALUE",
        nargs="+",
        # Without any mapping there is nothing to rename, and the workers
        # would receive None instead of a dict.
        required=True,
        help="Set a number of key-value pairs "
        "(do not put spaces before or after the = sign). "
        "If a value contains spaces, you should define "
        "it with double quotes: "
        'foo="this is a sentence". Note that '
        "values are always treated as strings.",
        action=ParseDict,
    )
    args = parser.parse_args()
    print(args)

    paths = collect_event_files(args.path, args.ext)
    tag_mappings = args.set
    print(tag_mappings)
    print(paths)
    if not paths:
        # Nothing to do, so don't spin up ray at all.
        print("no matching event files found")
    else:
        ray.init()
        # Store the mapping once and hand every task a reference to it.
        tag_mappings_id = ray.put(tag_mappings)
        task_list = [rename_events.remote(path, tag_mappings_id)
                     for path in paths]
        # Block until every file has been rewritten (re-raises task errors).
        ray.get(task_list)

@stas00
Copy link
Author

stas00 commented May 19, 2022

Thank you for sharing your improvements, @boxiXia!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment