import torch

# Dict to store hooks and flop count
data_dict = {'conv_flops': 0, 'hooks': []}

def count_conv_flops(module, input, output):
    # Flop contribution from channelwise connections
    flops_c = module.out_channels * module.in_channels // module.groups
    # Flop contribution from the number of spatial locations we convolve over
    flops_s = output.size(2) * output.size(3)
    # Flop contribution from the number of mult-adds at each location
    flops_f = module.kernel_size[0] * module.kernel_size[1]
    # Accumulate per-sample mult-adds (the batch dimension is ignored)
    data_dict['conv_flops'] += flops_c * flops_s * flops_f
    return
    
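# Quick sanity check of the formula above, with purely illustrative numbers:
# a 3x3 convolution mapping 64 -> 128 channels (groups=1) over a 32x32 output
# feature map contributes 128 * 64 * 3 * 3 * 32 * 32 = 75,497,472 mult-adds
# per sample.
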
def add_hooks(m):
    # Register the FLOP-counting hook on every Conv2d layer
    if isinstance(m, torch.nn.Conv2d):
        data_dict['hooks'].append(m.register_forward_hook(count_conv_flops))
    return

def count_flops(model, x):
    # Reset the running count and any hook handles from a previous call
    data_dict['conv_flops'] = 0
    data_dict['hooks'] = []
    # Note whether we need to return the model to training mode afterwards
    set_train = model.training
    model.eval()
    model.apply(add_hooks)
    # Run a forward pass without building the autograd graph
    with torch.no_grad():
        model(x)
    for hook in data_dict['hooks']:
        hook.remove()
    if set_train:
        model.train()
    return data_dict['conv_flops']
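
# Example usage -- a minimal sketch assuming torchvision is available; the
# resnet18 model and 224x224 input are just illustrative choices, and any
# nn.Module containing Conv2d layers works the same way.
if __name__ == '__main__':
    from torchvision.models import resnet18
    model = resnet18()
    x = torch.randn(1, 3, 224, 224)
    print('Conv FLOPs (mult-adds per sample): {:,}'.format(count_flops(model, x)))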