Requirements: plotly, pandas and PyTorch.
import pandas as pd
import plotly.express as px
def weights_dataframe(model, max_levels=10):
    """Tabulate the parameters of a PyTorch module, one row per parameter.

    Args:
        model: any object with a ``named_parameters()`` method yielding
            ``(dotted_name, tensor)`` pairs, e.g. a ``torch.nn.Module``.
        max_levels: maximum number of times each dotted parameter name is
            split, so names yield at most ``max_levels + 1`` path levels; any
            deeper remainder stays joined in the last level. As with pandas
            ``str.split``, 0, -1 and None mean "split on every dot".

    Returns:
        A DataFrame with columns ``path_0`` ... ``path_k`` (the name
        components; shorter names are padded with None), ``n_weights`` (number
        of elements in the tensor) and ``names`` (the shape as e.g. ``"2x3"``).

    Raises:
        ValueError: if ``model`` exposes no parameters.
    """
    records = [
        (name, param.numel(), "x".join(str(dim) for dim in param.shape))
        for name, param in model.named_parameters()
    ]
    if not records:
        # An empty table would otherwise fail later with an opaque error.
        raise ValueError("model has no parameters to plot")
    names, n_weights, shapes = zip(*records)

    # `n` must be passed by keyword: since pandas 2.0 every argument of
    # str.split except `pat` is keyword-only.
    paths = pd.Series(list(names), dtype=object).str.split(
        ".", n=max_levels, expand=True
    )
    paths.columns = [f"path_{i}" for i in range(paths.shape[1])]
    paths["n_weights"] = list(n_weights)
    paths["names"] = list(shapes)
    return paths


def plot_weights_treemap(model, max_levels=10):
    """Display a PyTorch module hierarchy as a treemap diagram.

    Each cell's area is proportional to the number of elements in the
    corresponding parameter tensor; hovering a leaf shows the tensor shape.

    Args:
        model: a module with a ``named_parameters()`` method.
        max_levels: maximum number of splits applied to each dotted parameter
            name (see ``weights_dataframe``).

    Returns:
        The plotly figure produced by ``plotly.express.treemap``.

    Raises:
        ValueError: if ``model`` exposes no parameters.
    """
    paths = weights_dataframe(model, max_levels=max_levels)
    path_cols = [col for col in paths.columns if col.startswith("path_")]
    # Padding Nones only ever appear in trailing levels, which px.treemap
    # accepts for paths of unequal depth.
    return px.treemap(paths, path=path_cols, values="n_weights", hover_name="names")
if __name__ == "__main__":
    import torch

    # Any torch.nn.Module works here; this small network is only a
    # placeholder for the model you want to inspect.
    model = torch.nn.Sequential(
        torch.nn.Linear(16, 32),
        torch.nn.ReLU(),
        torch.nn.Linear(32, 4),
    )
    # A figure is auto-rendered only as the last expression of a notebook
    # cell; in a plain script it must be shown explicitly.
    plot_weights_treemap(model).show()
Example with the Hugging Face Transformer-XL model:
And with the GPT-2 model:
Copyright (c) 2021 Martin Sotir. All rights reserved. This work is licensed under the terms of the MIT license. For a copy, see https://opensource.org/licenses/MIT.