@jotterbach
Last active February 17, 2021 18:40
Utility to compute the flop count of various elementary neural network modules
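import numpy as np


def conv_out_shape(in_shape: tuple, kernel_size: int, dilation: int, stride: int, padding: int) -> tuple:
    """
    Convenience function for computing the spatial output shape of an image tensor when applying symmetric
    (square) kernels, strides, padding and dilation. `in_shape` holds only the spatial dimensions, e.g. `(H, W)`.

    Each dimension follows out = floor((d + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1),
    the same formula documented for `torch.nn.Conv2d`.
    """
    out_dim = []
    for d in in_shape:
        _d_out = d + 2 * padding - dilation * (kernel_size - 1) - 1
        _d_out = _d_out / stride
        _d_out = np.floor(_d_out + 1)
        out_dim.append(int(_d_out))
    return tuple(out_dim)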
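The output-shape formula can be checked by hand on a few familiar layer configurations. The sketch below is a standalone re-implementation for a single axis (the helper name `conv_out_dim` is hypothetical, not part of this gist). It uses integer floor division, which gives the same result as the `np.floor` expression above for integer inputs, and makes the effective kernel extent `dilation * (kernel_size - 1) + 1` explicit.

```python
def conv_out_dim(d: int, kernel_size: int, dilation: int = 1, stride: int = 1, padding: int = 0) -> int:
    """Output length along one spatial axis (hypothetical helper mirroring conv_out_shape)."""
    effective_kernel = dilation * (kernel_size - 1) + 1  # span covered by a dilated kernel
    return (d + 2 * padding - effective_kernel) // stride + 1


# "same" 3x3 convolution: spatial size is preserved
print(conv_out_dim(32, kernel_size=3, padding=1))              # 32
# stride-2 3x3 convolution halves the size
print(conv_out_dim(32, kernel_size=3, stride=2, padding=1))    # 16
# 7x7 stride-2 stem on a 224x224 image
print(conv_out_dim(224, kernel_size=7, stride=2, padding=3))   # 112
# dilation 2 widens a 3x3 kernel to a 5x5 footprint; padding 2 keeps the size
print(conv_out_dim(32, kernel_size=3, dilation=2, padding=2))  # 32
```

Subtracting `effective_kernel` is the same as subtracting `dilation * (kernel_size - 1) + 1`, so the expression is term-by-term identical to the one in `conv_out_shape`.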

def flops_per_conv(top_shape, in_channels, out_channels, kernel_size: int, dilation: int, stride: int, padding: int):
    """
    A convolution can be interpreted as a smart way to do weight-sharing. The basic idea is the `im2col` operation, which
    extracts strided patches of dimension `d_in = in_channels * kernel_size**2` from the image. Those patches are then
    multiplied by a matrix of size `d_in x out_channels`. As a consequence, the effective batch size for the convolution
    matrix is increased by the number of patches resulting from the `im2col` operation.

    `top_shape` is the spatial shape (e.g. `(H, W)`) of the input feature map; the kernel is assumed square and 2D.
    Returns `(output_shape, (weight_params, bias_params), flops)`, where one multiply-add counts as 2 FLOPs.

    see also:
    - https://petewarden.com/2015/04/20/why-gemm-is-at-the-heart-of-deep-learning/
    - https://indico.cern.ch/event/917049/contributions/3856417/attachments/2034165/3405345/Quantized_CNN_LLP.pdf
    """
    out_shape = conv_out_shape(top_shape, kernel_size, dilation, stride, padding)
    num_patches = int(np.prod(out_shape))  # one im2col patch per output position
    in_dim = kernel_size**2 * in_channels
    out_dim = out_channels
    weight_params = in_dim * out_dim
    bias_params = out_dim
    # Each patch is one row of a (num_patches x in_dim) @ (in_dim x out_dim) matmul, and the bias
    # is added once per output position, so both terms scale with num_patches.
    flops = 2 * num_patches * weight_params + num_patches * bias_params
    return (out_channels, *out_shape), (weight_params, bias_params), flops
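To make the `im2col` argument in the `flops_per_conv` docstring concrete, the following NumPy sketch unrolls a `(C, H, W)` input into a `(num_patches, C * k * k)` matrix, performs the convolution as a single matrix product, and checks the result against a naive loop. The function names `im2col`, `conv2d_im2col` and `conv2d_direct` are illustrative, not from this gist. It assumes a single image, square kernels and zero padding, and counts the bias once per output position.

```python
import numpy as np


def im2col(x, k, stride=1, padding=0, dilation=1):
    """Unroll k x k patches of x (shape (C, H, W)) into rows of a (num_patches, C*k*k) matrix."""
    c, h, w = x.shape
    xp = np.pad(x, ((0, 0), (padding, padding), (padding, padding)))  # zero padding
    span = dilation * (k - 1) + 1  # effective kernel extent
    h_out = (h + 2 * padding - span) // stride + 1
    w_out = (w + 2 * padding - span) // stride + 1
    rows = []
    for i in range(h_out):
        for j in range(w_out):
            r, s = i * stride, j * stride
            patch = xp[:, r:r + span:dilation, s:s + span:dilation]  # shape (C, k, k)
            rows.append(patch.reshape(-1))
    return np.stack(rows), (h_out, w_out)


def conv2d_im2col(x, weight, bias, stride=1, padding=0, dilation=1):
    """Convolution as one GEMM, (num_patches, d_in) @ (d_in, C_out), plus a FLOP count."""
    c_out, _, k, _ = weight.shape
    cols, (h_out, w_out) = im2col(x, k, stride, padding, dilation)
    w_mat = weight.reshape(c_out, -1).T  # (d_in, C_out), same (C, k, k) ordering as the patches
    out = cols @ w_mat + bias            # bias broadcast over all patches
    num_patches, d_in = cols.shape
    flops = 2 * num_patches * d_in * c_out + num_patches * c_out
    return out.T.reshape(c_out, h_out, w_out), flops


def conv2d_direct(x, weight, bias, stride=1, padding=0, dilation=1):
    """Reference cross-correlation (as used by most DL frameworks) with explicit index loops."""
    c_out, c_in, k, _ = weight.shape
    xp = np.pad(x, ((0, 0), (padding, padding), (padding, padding)))
    span = dilation * (k - 1) + 1
    h_out = (x.shape[1] + 2 * padding - span) // stride + 1
    w_out = (x.shape[2] + 2 * padding - span) // stride + 1
    out = np.zeros((c_out, h_out, w_out))
    for o in range(c_out):
        for i in range(h_out):
            for j in range(w_out):
                acc = bias[o]
                for c in range(c_in):
                    for u in range(k):
                        for v in range(k):
                            acc += weight[o, c, u, v] * xp[c, i * stride + u * dilation, j * stride + v * dilation]
                out[o, i, j] = acc
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((3, 8, 8))          # C_in=3, 8x8 image
weight = rng.standard_normal((4, 3, 3, 3))  # C_out=4, 3x3 kernels
bias = rng.standard_normal(4)

y, flops = conv2d_im2col(x, weight, bias, stride=2, padding=1)
print(y.shape, flops)  # (4, 4, 4) 3520
print(np.allclose(y, conv2d_direct(x, weight, bias, stride=2, padding=1)))  # True
```

Here the 16 patches of dimension 27 give 2 * 16 * 27 * 4 = 3456 FLOPs for the matrix product and 16 * 4 = 64 for the bias. Production im2col kernels vectorize the patch extraction instead of looping in Python; the loops are kept only for readability.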
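
def flops_per_dense(in_dim, out_dim):
    """
    FLOP count for a single input vector passing through a fully connected layer `y = W x + b`, counting each of the
    `in_dim * out_dim` multiply-adds as 2 FLOPs and each of the `out_dim` bias additions as 1.
    Returns `(output_shape, (weight_params, bias_params), flops)`.

    see also: https://indico.cern.ch/event/917049/contributions/3856417/attachments/2034165/3405345/Quantized_CNN_LLP.pdf
    """
    weight_params = in_dim * out_dim
    bias_params = out_dim
    flops = 2 * weight_params + bias_params
    return (out_dim,), (weight_params, bias_params), flops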
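The closed form `2 * weight_params + bias_params` follows from a particular counting convention. The loop below (a hypothetical `dense_forward_counted` helper, written only to tally operations) accumulates each output from zero, so every weight contributes one multiply and one add, and each bias contributes one more add. Starting the accumulator at the bias instead would give exactly `2 * in_dim * out_dim`; different FLOP counters follow different conventions, so discrepancies of `out_dim` are expected when comparing tools.

```python
import numpy as np


def dense_forward_counted(x, W, b):
    """Compute y = W @ x + b with explicit loops and count floating-point operations."""
    out_dim, in_dim = W.shape
    y = np.zeros(out_dim)
    flops = 0
    for o in range(out_dim):
        acc = 0.0
        for i in range(in_dim):
            acc += W[o, i] * x[i]  # one multiply + one add
            flops += 2
        y[o] = acc + b[o]  # one bias add per output
        flops += 1
    return y, flops


in_dim, out_dim = 5, 3
rng = np.random.default_rng(0)
x = rng.standard_normal(in_dim)
W = rng.standard_normal((out_dim, in_dim))
b = rng.standard_normal(out_dim)

y, flops = dense_forward_counted(x, W, b)
print(flops)                      # 33 == 2 * 5 * 3 + 3
print(np.allclose(y, W @ x + b))  # True
```

For a batch of inputs, every term is simply multiplied by the batch size, which is the same scaling the convolution count applies with the number of `im2col` patches.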