Utility to compute the flop count of various elementary neural network modules
import numpy as np


def conv_out_shape(in_shape: tuple, kernel_size: int, dilation: int, stride: int, padding: int):
    """
    Convenience function for computing the output shape of an image tensor when applying
    symmetric kernels, strides, etc. `in_shape` holds the spatial dimensions only.
    """
    out_dim = []
    for d in in_shape:
        # Standard convolution arithmetic:
        # floor((d + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1
        _d_out = d + 2 * padding - dilation * (kernel_size - 1) - 1
        _d_out = _d_out / stride
        _d_out = np.floor(_d_out + 1)
        out_dim.append(int(_d_out))
    return tuple(out_dim)


def flops_per_conv(top_shape, in_channels, out_channels, kernel_size: int, dilation: int, stride: int, padding: int):
    """
    A convolution can be interpreted as a smart way to do weight-sharing. The basic idea is the `im2col`
    operation, which extracts strided patches of dimension `d_in = in_channels * kernel_size**2` from the
    image. Those patches are then multiplied by a matrix of size `d_in x out_channels`. As a consequence,
    the effective batch size of that matrix multiplication grows by the number of patches produced by
    the `im2col` operation.

    `top_shape` is the spatial shape of the input, e.g. `(height, width)`. The `kernel_size**2` term
    assumes a square 2D kernel.

    Returns `(output shape, (weight_params, bias_params), flops)`.

    see also:
    - https://petewarden.com/2015/04/20/why-gemm-is-at-the-heart-of-deep-learning/
    - https://indico.cern.ch/event/917049/contributions/3856417/attachments/2034165/3405345/Quantized_CNN_LLP.pdf
    """
    out_shape = conv_out_shape(top_shape, kernel_size, dilation, stride, padding)
    num_patches = int(np.prod(out_shape))  # one patch per output position
    in_dim = kernel_size**2 * in_channels
    out_dim = out_channels
    weight_params = in_dim * out_dim
    bias_params = out_dim
    # Each weight contributes one multiply and one add per patch.
    # Note: the bias is counted once per output channel here, not once per output element.
    flops = 2 * num_patches * weight_params + bias_params
    return (out_channels, *out_shape), (weight_params, bias_params), flops


def flops_per_dense(in_dim, out_dim):
    """
    Returns `(output shape, (weight_params, bias_params), flops)` for a fully connected layer
    applied to a single input vector.

    see also: https://indico.cern.ch/event/917049/contributions/3856417/attachments/2034165/3405345/Quantized_CNN_LLP.pdf
    """
    weight_params = in_dim * out_dim
    bias_params = out_dim
    # One multiply and one add per weight, plus one add per bias term.
    flops = 2 * weight_params + bias_params
    return (out_dim,), (weight_params, bias_params), flops
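
The `im2col` picture in the `flops_per_conv` docstring can be checked numerically. The following is a minimal sketch, not part of the gist: the `im2col` helper is a hypothetical, loop-based illustration written for this example, and the 8x8 input, stride 2, and padding 1 are arbitrary assumptions. It unfolds an image into patch rows, applies the convolution as a single matrix multiplication, and counts two operations (one multiply, one add) per weight per patch.

```python
import numpy as np


def im2col(image, kernel_size, stride=1, padding=0, dilation=1):
    """Hypothetical helper: unfold a (channels, height, width) array into patch rows."""
    channels, height, width = image.shape
    # Zero-pad the two spatial axes only.
    padded = np.pad(image, ((0, 0), (padding, padding), (padding, padding)))
    span = dilation * (kernel_size - 1) + 1  # extent covered by a dilated kernel
    out_h = (height + 2 * padding - span) // stride + 1
    out_w = (width + 2 * padding - span) // stride + 1
    rows = []
    for i in range(out_h):
        for j in range(out_w):
            top, left = i * stride, j * stride
            patch = padded[:, top:top + span:dilation, left:left + span:dilation]
            rows.append(patch.reshape(-1))  # length: channels * kernel_size**2
    return np.stack(rows), (out_h, out_w)


in_channels, out_channels, kernel_size = 3, 16, 3
image = np.ones((in_channels, 8, 8))
patches, out_shape = im2col(image, kernel_size, stride=2, padding=1)

weights = np.ones((in_channels * kernel_size**2, out_channels))
bias = np.zeros(out_channels)
output = patches @ weights + bias  # one matrix multiplication over all patches

num_patches, in_dim = patches.shape
matmul_flops = 2 * num_patches * in_dim * out_channels
print(out_shape, patches.shape, output.shape, matmul_flops)
# (4, 4) (16, 27) (16, 16) 13824
```

With the same settings, `flops_per_conv((8, 8), 3, 16, 3, 1, 2, 1)` reports this matrix-multiplication term plus `bias_params`, i.e. 13824 + 16 = 13840.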