Last active
December 15, 2018 22:44
-
-
Save kykosic/a7e25ae7e225d8dbd44fcfc17db4c88d to your computer and use it in GitHub Desktop.
Save timeline traces of hardware metadata in Keras model.fit() calls.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Keras Callback object for creating model run hardware profile timelines. | |
This callback is very slow, and should only be used for debugging. | |
""" | |
import os | |
import json | |
import logging | |
import tensorflow as tf | |
from tensorflow.python.client.timeline import Timeline | |
def add_batch_id(batch_id, trace_dict):
    """Tag every complete ("X" phase) event of a Chrome trace with a batch id.

    Args:
        batch_id: Value stored under ``args['batch_id']`` of each complete event.
        trace_dict (dict): Parsed Chrome trace containing a ``'traceEvents'``
            list. It is modified in place; nothing is returned.
    """
    for event in trace_dict['traceEvents']:
        # Use .get so an event lacking a phase is skipped instead of raising.
        if event.get('ph') == 'X':
            # 'args' is optional on trace events, so create it when absent.
            event.setdefault('args', {})['batch_id'] = batch_id
class Profiler(tf.keras.callbacks.Callback):
    """Keras callback that records per-batch hardware timelines as a Chrome trace.

    At the end of every training batch the step statistics currently held in
    ``run_metadata`` are rendered to Chrome trace events and tagged with a
    running batch index. When training ends, all batches are merged into a
    single JSON file at ``output_file``.

    Rendering a trace for every batch is slow; use this callback for
    debugging only.
    """

    def __init__(self, run_metadata, output_file):
        """
        Args:
            run_metadata (tf.RunMetadata): RunMetadata object read after each
                batch to extract timeline events. NOTE(review): presumably the
                same object the model's run calls write step stats into --
                confirm with the caller.
            output_file (str): Location to store JSON timeline events in
                Chrome trace format. Missing parent directories are created.
        """
        # Let the base Callback initialise its own attributes first.
        super(Profiler, self).__init__()
        self.run_metadata = run_metadata
        self.output_file = output_file
        # Parsed per-batch trace dicts. Initialised here as well so the batch
        # and end hooks work even if on_train_begin was never invoked.
        self._events = []

    def on_train_begin(self, logs=None):
        """Discard traces recorded by any previous training run."""
        self._events = []

    def on_batch_end(self, batch, logs=None):
        """Snapshot the current ``run_metadata`` step stats as trace events.

        The trace is rendered immediately instead of keeping a Timeline that
        references ``run_metadata.step_stats``: the same RunMetadata object is
        read again after every batch, so a deferred render could observe a
        later batch's statistics rather than this one's.

        ``batch`` is ignored; events are tagged with a running index counting
        every batch recorded since training began.
        """
        timeline = Timeline(self.run_metadata.step_stats)
        trace_dict = json.loads(timeline.generate_chrome_trace_format())
        add_batch_id(len(self._events), trace_dict)
        self._events.append(trace_dict)

    def on_train_end(self, logs=None):
        """Merge all per-batch traces into one Chrome trace and write it out.

        The first batch's trace is kept whole. Later batches contribute only
        events that carry a ``'ts'`` timestamp; their untimestamped events
        (presumably process/thread metadata) are dropped as duplicates.
        NOTE(review): this assumes every batch assigns the same pids to the
        same devices -- confirm if the device set can change between batches.

        With no recorded batches, ``{'traceEvents': []}`` is written rather
        than a bare ``{}``.
        """
        if not self._events:
            self.write_timeline({'traceEvents': []})
            return
        first = self._events[0]
        # Copy instead of mutating the stored first trace, so calling this
        # method more than once does not duplicate events.
        merged = dict(first)
        merged_events = list(first['traceEvents'])
        for trace_dict in self._events[1:]:
            merged_events.extend(
                event for event in trace_dict['traceEvents'] if 'ts' in event)
        merged['traceEvents'] = merged_events
        self.write_timeline(merged)

    def write_timeline(self, timeline_dict):
        """Serialise ``timeline_dict`` as JSON to ``self.output_file``.

        Parent directories of the output file are created if missing.
        """
        dirname = os.path.dirname(os.path.realpath(self.output_file))
        try:
            os.makedirs(dirname)
        except OSError:
            # Tolerate a directory that already exists (possibly created
            # concurrently); re-raise any other failure.
            if not os.path.isdir(dirname):
                raise
        with open(self.output_file, 'w') as f:
            json.dump(timeline_dict, f)
        logging.info("Profile trace saved to %s", self.output_file)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment