Skip to content

Instantly share code, notes, and snippets.

@kenfehling
Forked from kylemcdonald/summary.py
Last active July 17, 2018 09:36
Show Gist options
  • Save kenfehling/57c43052bfa1efac5342cd55cabc9ca3 to your computer and use it in GitHub Desktop.
Pytorch model summary
from collections import OrderedDict
import torch as th
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import Parameter
# Column layout for the summary table: layer name (left-aligned),
# input shape, output shape, and parameter count (right-aligned).
fmt = ' {:<20} {:>15} {:>15} {:>10}'
def stringify_shape(shape):
    """Render a shape list compactly, e.g. [-1, 3, 224] -> '_,3,224'.

    Brackets and spaces are stripped first, then the variable batch
    dimension marker '-1' is shown as '_' (order of the substitutions
    matters: '-1' is only matched after the spaces are gone).
    """
    compact = str(shape)
    for old, new in (('[', ''), (']', ''), (' ', ''), ('-1', '_')):
        compact = compact.replace(old, new)
    return compact
def rnn_to_cnn_shape(x):
    """Cycle the leading dimension of *x* to the back: (d0, d1, d2) -> (d1, d2, d0).

    Presumably this maps RNN output layout to CNN layout (hence the name) —
    a single permute is equivalent to the original pair of transposes.
    """
    return x.permute(1, 2, 0)
def summary(input_size, model):
    """Print a Keras-style per-layer summary of a PyTorch model.

    Registers a temporary forward hook on every non-container module, runs
    one forward pass with random input(s) of ``input_size``, and tabulates
    each layer's input/output shape and parameter count.

    Args:
        input_size: a single shape (list/tuple of ints) for a one-input
            model, or a list/tuple of such shapes for a multi-input model.
        model: the ``nn.Module`` to summarize.

    Returns:
        OrderedDict mapping 'LayerType-N' to an OrderedDict with keys
        'input_shape', 'output_shape', 'nb_params' (a plain int), and —
        for layers with a weight — 'trainable'.
    """
    def register_hook(module):
        def hook(module, input, output):
            class_name = str(module.__class__).split('.')[-1].split("'")[0]
            module_idx = len(summary)
            m_key = '%s-%i' % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            if isinstance(input, (list, tuple)):  # RNN: inputs arrive packed
                input = input[0]
            summary[m_key]['input_shape'] = list(input[0].size())
            summary[m_key]['input_shape'][0] = -1  # batch dim is variable
            if isinstance(output, (list, tuple)):  # RNN: (output, hidden)
                output = rnn_to_cnn_shape(output[0])
            summary[m_key]['output_shape'] = list(output.size())
            summary[m_key]['output_shape'][0] = -1
            # Count parameters as plain ints via numel(). The original
            # accumulated th.prod(th.LongTensor(...)) tensors, which made
            # the totals tensors and crashed the final .item() calls
            # whenever a model had no (trainable) parameters at all.
            params = 0
            if hasattr(module, 'weight') and module.weight is not None:
                params += module.weight.numel()
                summary[m_key]['trainable'] = module.weight.requires_grad
            if hasattr(module, 'bias') and isinstance(module.bias, Parameter):
                params += module.bias.numel()
            summary[m_key]['nb_params'] = params
        # Skip containers and the root module so layers aren't double-counted.
        if not isinstance(module, (nn.Sequential, nn.ModuleList)) \
                and module is not model:
            hooks.append(module.register_forward_hook(hook))

    dtype = th.FloatTensor
    # Check if there are multiple inputs to the network.
    if isinstance(input_size[0], (list, tuple)):
        x = [Variable(th.rand(in_size)).type(dtype) for in_size in input_size]
    else:
        x = Variable(th.rand(input_size)).type(dtype)

    summary = OrderedDict()  # shadows the function name inside the closures
    hooks = []
    model.apply(register_hook)
    try:
        # Fix: a multi-input model needs the tensors unpacked as separate
        # arguments; the original passed the whole list as one argument.
        if isinstance(x, list):
            model(*x)
        else:
            model(x)
    finally:
        # Always detach the hooks, even if the forward pass raises.
        for h in hooks:
            h.remove()

    print('-------------------------------------------------------------------')
    head = fmt.format('Layer (type)', 'Input Shape', 'Output Shape', 'Param #')
    print(head)
    print('===================================================================')
    total_params = 0
    trainable_params = 0
    for layer in summary:
        total_params += summary[layer]['nb_params']
        if summary[layer].get('trainable'):
            trainable_params += summary[layer]['nb_params']
        print(fmt.format(
            layer,
            stringify_shape(summary[layer]['input_shape']),
            stringify_shape(summary[layer]['output_shape']),
            '{:,}'.format(summary[layer]['nb_params'])))
    print('===================================================================')
    print('Total params: {:,}'.format(total_params))
    print('Trainable params: {:,}'.format(trainable_params))
    print('Non-trainable params: {:,}'.format(total_params - trainable_params))
    print('-------------------------------------------------------------------')
    return summary
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment