Skip to content

Instantly share code, notes, and snippets.

@kykosic
Last active December 15, 2018 22:44
Show Gist options
  • Save kykosic/a7e25ae7e225d8dbd44fcfc17db4c88d to your computer and use it in GitHub Desktop.
Save kykosic/a7e25ae7e225d8dbd44fcfc17db4c88d to your computer and use it in GitHub Desktop.
Save timeline traces of hardware metadata in Keras model.fit() calls.
"""
Keras Callback object for creating model run hardware profile timelines.
This callback is very slow, and should only be used for debugging.
"""
import os
import json
import logging
import tensorflow as tf
from tensorflow.python.client.timeline import Timeline
def add_batch_id(batch_id, trace_dict):
    """Tag every complete ('X') event in a chrome trace with ``batch_id``.

    The dict is modified in place and nothing is returned.

    Args:
        batch_id (int): Value stored under ``args['batch_id']`` of each event.
        trace_dict (dict): Chrome-trace dict holding a ``'traceEvents'`` list.
    """
    for event in trace_dict['traceEvents']:
        # Only 'X' (complete / op-duration) events are tagged; events of any
        # other phase, or without a phase, are left untouched.
        if event.get('ph') == 'X':
            # Create 'args' when an event lacks it instead of raising KeyError.
            event.setdefault('args', {})['batch_id'] = batch_id
class Profiler(tf.keras.callbacks.Callback):
    """Callback object for storing timeline profiles given run_metadata.

    After every batch the step stats currently held in ``run_metadata`` are
    converted to a chrome-trace dict and tagged with a running batch index.
    When training ends, all per-batch traces are merged into a single JSON
    file in chrome trace format.

    NOTE(review): ``run_metadata`` presumably has to be the same object that
    the model's session run fills in (e.g. passed through compile kwargs with
    tracing RunOptions); otherwise the traces will be empty. TODO confirm
    with callers.
    """

    def __init__(self, run_metadata, output_file):
        """
        Args:
            run_metadata (tf.RunMetadata): RunMetadata object for extracting timeline events.
            output_file (str): Location to store JSON timeline events in chrome trace format.
        """
        # Run the base Callback initializer so whatever state it sets up exists.
        super(Profiler, self).__init__()
        self.run_metadata = run_metadata
        self.output_file = output_file
        # Also created here so on_batch_end cannot hit a missing attribute if
        # it fires without a preceding on_train_begin.
        self._events = []

    def on_train_begin(self, logs=None):
        """Discard traces collected by any previous training run."""
        self._events = []

    def on_batch_end(self, batch, logs=None):
        """Snapshot the step stats of the batch that just finished.

        The trace is generated immediately rather than storing a ``Timeline``
        that wraps ``self.run_metadata.step_stats``. That is a sub-message of
        a proto presumably re-filled in place on every step, so deferring
        the conversion would depend on protobuf detaching the old sub-message.
        The batch id is a running counter across the whole run (not
        ``batch``), matching the original numbering.
        """
        trace_dict = json.loads(
            Timeline(self.run_metadata.step_stats).generate_chrome_trace_format())
        add_batch_id(len(self._events), trace_dict)
        self._events.append(trace_dict)

    def on_train_end(self, logs=None):
        """Merge all per-batch traces and write them to ``output_file``."""
        self.write_timeline(self._merge_traces(self._events))

    @staticmethod
    def _merge_traces(traces):
        """Combine a list of chrome-trace dicts into one trace dict.

        The first trace is kept whole (including its metadata events). Later
        traces only contribute events that carry a timestamp ('ts').
        """
        if not traces:
            # No batches ran: still emit a structurally valid chrome trace.
            return {'traceEvents': []}
        merged = traces[0]
        for trace_dict in traces[1:]:
            # Events without 'ts' are presumably per-trace metadata that
            # would duplicate the first trace's entries.
            merged['traceEvents'].extend(
                event for event in trace_dict['traceEvents'] if 'ts' in event)
        return merged

    def write_timeline(self, timeline_dict):
        """Dump ``timeline_dict`` as JSON to ``self.output_file``.

        Missing parent directories are created first.
        """
        dirname = os.path.dirname(os.path.realpath(self.output_file))
        try:
            os.makedirs(dirname)
        except OSError:
            # An already-existing directory (possibly created concurrently)
            # is fine; any other failure is re-raised.
            if not os.path.isdir(dirname):
                raise
        with open(self.output_file, 'w') as f:
            json.dump(timeline_dict, f)
        logging.info("Profile trace saved to %s", self.output_file)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment