import torch # Dict to store hooks and flop count data_dict = {'conv_flops' : 0, 'hooks' :[]} def count_conv_flops(self, input, output): # Flop contribution from channelwise connections flops_c = self.out_channels * self.in_channels / self.groups # Flop contribution from number of spatial locations we convolve over flops_s = output.size(2) * output.size(3) # Flop contribution from number of mult-adds at each location flops_f = self.kernel_size[0] * self.kernel_size[1] data_dict['conv_flops'] += flops_c * flops_s * flops_f return def add_hooks(m): if isinstance(m, torch.nn.Conv2d): data_dict['hooks'] += [m.register_forward_hook(count_conv_flops)] return def count_flops(model, x): data_dict['conv_flops'] = 0 # Note if we need to return the model to training mode set_train = model.training model.eval() model.apply(add_hooks) out = model(torch.autograd.Variable(x.data, volatile=True)) for hook in data_dict['hooks']: hook.remove() if set_train: model.train() return data_dict['conv_flops']