-
-
Save kenfehling/57c43052bfa1efac5342cd55cabc9ca3 to your computer and use it in GitHub Desktop.
Pytorch model summary
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import OrderedDict | |
import torch as th | |
import torch.nn as nn | |
from torch.autograd import Variable | |
from torch.nn import Parameter | |
# Row template for the summary table: layer name (left-aligned),
# input shape, output shape, parameter count (right-aligned columns).
fmt = ' {:<20} {:>15} {:>15} {:>10}'
def stringify_shape(shape):
    """Render a shape list compactly, e.g. [-1, 3, 32] -> '_,3,32'.

    Brackets and spaces are stripped and the -1 wildcard (variable
    batch dimension) is shown as '_'.
    """
    text = str(shape)
    for old, new in (('[', ''), (']', ''), (' ', ''), ('-1', '_')):
        text = text.replace(old, new)
    return text
def rnn_to_cnn_shape(x):
    """Cycle the first three dims of *x*: (d0, d1, d2) -> (d1, d2, d0).

    Presumably maps an RNN output laid out (seq, batch, feat) to the
    CNN-style (batch, feat, seq) — TODO confirm with callers.
    """
    return x.permute(1, 2, 0)
def summary(input_size, model):
    """Print a Keras-style per-layer summary of *model* and return it.

    A dummy random batch of ``input_size`` is pushed through the model;
    forward hooks record each layer's input/output shapes and parameter
    counts.

    Parameters
    ----------
    input_size : tuple of ints including the batch dimension, or a
        list/tuple of such tuples for a multi-input network.
    model : the ``nn.Module`` to inspect.

    Returns
    -------
    OrderedDict mapping ``'ClassName-i'`` to a per-layer OrderedDict
    with keys ``input_shape``, ``output_shape``, ``trainable`` (only
    when the layer has a weight), and ``nb_params`` (plain int).
    """
    def register_hook(module):
        def hook(module, input, output):
            class_name = str(module.__class__).split('.')[-1].split("'")[0]
            module_idx = len(summary)
            m_key = '%s-%i' % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            # Forward hooks always receive the positional args as a tuple,
            # so take the first arg; for RNNs that arg may itself be a
            # tuple (input, hidden) and needs one more unwrap.  (The old
            # code tested `input` for tuple-ness — always true — and then
            # indexed into the tensor, dropping a dimension from every
            # reported input shape.)
            inp = input[0]
            if isinstance(inp, (list, tuple)):  # RNN
                inp = inp[0]
            summary[m_key]['input_shape'] = list(inp.size())
            summary[m_key]['input_shape'][0] = -1  # batch dim is variable
            if isinstance(output, (list, tuple)):  # RNN: (output, hidden)
                output = rnn_to_cnn_shape(output[0])
            summary[m_key]['output_shape'] = list(output.size())
            summary[m_key]['output_shape'][0] = -1
            params = 0
            if hasattr(module, 'weight') and module.weight is not None:
                # Accumulate plain ints so the totals below always support
                # '{:,}' formatting (a 0-dim tensor would also make the
                # totals tensors, and `.item()` crashed when a total
                # stayed the Python int 0).
                params += int(th.prod(th.LongTensor(list(module.weight.size()))))
                summary[m_key]['trainable'] = module.weight.requires_grad
            if hasattr(module, 'bias') and isinstance(module.bias, Parameter):
                params += int(th.prod(th.LongTensor(list(module.bias.size()))))
            summary[m_key]['nb_params'] = params
        # Hook only leaf-ish layers: skip containers and the root module.
        if not isinstance(module, (nn.Sequential, nn.ModuleList)) and \
                module is not model:
            hooks.append(module.register_forward_hook(hook))

    dtype = th.FloatTensor
    # check if there are multiple inputs to the network
    if isinstance(input_size[0], (list, tuple)):
        x = [Variable(th.rand(in_size)).type(dtype) for in_size in input_size]
    else:
        x = Variable(th.rand(input_size)).type(dtype)
    # create properties
    summary = OrderedDict()
    hooks = []
    # register hook
    model.apply(register_hook)
    # make a forward pass; a multi-input model takes one tensor per input
    # (the old code passed the whole list as a single argument)
    if isinstance(x, list):
        model(*x)
    else:
        model(x)
    # remove these hooks
    for h in hooks:
        h.remove()
    print('-------------------------------------------------------------------')
    head = fmt.format('Layer (type)', 'Input Shape', 'Output Shape', 'Param #')
    print(head)
    print('===================================================================')
    total_params = 0
    trainable_params = 0
    for layer in summary:
        total_params += summary[layer]['nb_params']
        if summary[layer].get('trainable'):
            trainable_params += summary[layer]['nb_params']
        print(fmt.format(
            layer,
            stringify_shape(summary[layer]['input_shape']),
            stringify_shape(summary[layer]['output_shape']),
            '{:,}'.format(summary[layer]['nb_params'])))
    nontrainable_params = total_params - trainable_params
    print('===================================================================')
    # Totals are plain ints, so no `.item()` calls are needed (or safe).
    print('Total params: {:,}'.format(total_params))
    print('Trainable params: {:,}'.format(trainable_params))
    print('Non-trainable params: {:,}'.format(nontrainable_params))
    print('-------------------------------------------------------------------')
    return summary
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment