Created
March 26, 2021 18:46
-
-
Save amaarora/c562da34b95d97f8254960bdca6a12d1 to your computer and use it in GitHub Desktop.
Plot SPPs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "metadata": { | |
| "gradient": {}, | |
| "trusted": false | |
| }, | |
| "id": "satellite-superior", | |
| "cell_type": "code", | |
| "source": "import timm \nimport torch \nimport torchvision\nfrom matplotlib import pyplot as plt", | |
| "execution_count": 1, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": false | |
| }, | |
| "id": "preceding-space", | |
| "cell_type": "code", | |
def _non_channel_dims(output):
    """Return every dimension index of `output` except dim 1 (the channel dim)."""
    if output.dim() < 2:
        raise ValueError(
            f"expected activations with a channel dimension (dim >= 2), "
            f"got shape {tuple(output.shape)}"
        )
    # For NCHW activations this is [0, 2, 3]; other ranks (N, C), (N, C, L),
    # (N, C, D, H, W) reduce over all of their non-channel dimensions too.
    return [0, *range(2, output.dim())]


def avg_sq_ch_mean(model, input, output):
    """Calculate the average channel squared mean of output activations.

    Written with the forward-hook signature; `model` and `input` are unused.
    Each channel's mean is taken over every dimension except dim 1, squared,
    and the squares are averaged across channels.

    Returns:
        float: mean over channels of (per-channel mean) ** 2.

    Raises:
        ValueError: if `output` has fewer than 2 dimensions.
    """
    ch_mean = output.mean(dim=_non_channel_dims(output))
    return torch.mean(ch_mean ** 2).item()


def avg_ch_var(model, input, output):
    """Calculate the average channel variance of output activations.

    Written with the forward-hook signature; `model` and `input` are unused.
    Each channel's variance (torch's default, unbiased estimator) is taken
    over every dimension except dim 1, then averaged across channels.

    Returns:
        float: mean over channels of the per-channel variance.

    Raises:
        ValueError: if `output` has fewer than 2 dimensions.
    """
    ch_var = output.var(dim=_non_channel_dims(output))
    return torch.mean(ch_var).item()
| "execution_count": 2, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": false | |
| }, | |
| "id": "romance-oakland", | |
| "cell_type": "code", | |
class ActivationStatsHook:
    """Iterates through each of `model`'s modules and, if the module's class
    name is present in `layer_names`, registers `hook_fns` as forward hooks on
    that module and stores the activation stats inside `self.stats`.

    Arguments:
        model (nn.Module): model from which we will extract the activation stats
        layer_names (List[str]): class names of the modules to hook.
            Example, `BasicBlock`, `Bottleneck`
        hook_fns (List[Callable], optional): hook functions to be registered at
            every module in `layer_names`. Each is called as
            `hook_fn(module, input, output)` and its return value is appended to
            `self.stats[hook_fn.__name__]`. Defaults to
            `[avg_sq_ch_mean, avg_ch_var]`.

    Hooks stay attached to `model` until `remove()` is called. The object can
    also be used as a context manager, which removes the hooks on exit.

    Inspiration from https://docs.fast.ai/callback.hook.html.
    """

    def __init__(self, model, layer_names, hook_fns=None):
        if hook_fns is None:
            # Build a fresh list per instance instead of a shared mutable default.
            hook_fns = [avg_sq_ch_mean, avg_ch_var]
        self.model = model
        self.layer_names = layer_names
        self.hook_fns = hook_fns
        # One list of per-call results for each hook function, keyed by its name.
        self.stats = {hook_fn.__name__: [] for hook_fn in hook_fns}
        # Handles returned by `register_forward_hook`, kept so hooks can be removed.
        self.handles = []
        for hook_fn in hook_fns:
            self.register_hook(layer_names, hook_fn)

    def _create_hook(self, hook_fn):
        """Wrap `hook_fn` so its result is appended to the matching stats list."""
        def append_activation_stats(module, input, output):
            out = hook_fn(module, input, output)
            self.stats[hook_fn.__name__].append(out)
        return append_activation_stats

    def register_hook(self, layer_names, hook_fn):
        """Register `hook_fn` on every module whose class name is in `layer_names`."""
        for layer in self.model.modules():
            if layer.__class__.__name__ not in layer_names:
                continue
            handle = layer.register_forward_hook(self._create_hook(hook_fn))
            self.handles.append(handle)

    def remove(self):
        """Detach every hook this object registered; collected `stats` are kept."""
        for handle in self.handles:
            handle.remove()
        self.handles = []

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # Always detach the hooks; exceptions from the `with` body propagate.
        self.remove()
| "execution_count": 3, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": false | |
| }, | |
| "id": "incomplete-jason", | |
| "cell_type": "code", | |
def extract_spp_stats(model,
                      layer_names,
                      hook_fns=None,
                      input_shape=(8, 3, 224, 224)):
    """Extract average square channel mean and variance of activations during
    a forward pass to plot Signal Propagation Plots (SPPs).

    Arguments:
        model (nn.Module): network to probe.
        layer_names (List[str]): class names of the modules whose outputs are
            measured, e.g. `NormFreeBlock`, `Bottleneck`.
        hook_fns (List[Callable], optional): statistics to record, each called
            as `hook_fn(module, input, output)`. Defaults to
            `[avg_sq_ch_mean, avg_ch_var]`.
        input_shape (Sequence[int]): shape of the random N(0, 1) input batch.

    Returns:
        dict: maps each hook function's `__name__` to the list of values it
        returned, one entry per hooked module call, in forward execution order.

    Paper: https://arxiv.org/abs/2101.08692
    """
    if hook_fns is None:
        # Fresh list per call instead of a shared mutable default.
        hook_fns = [avg_sq_ch_mean, avg_ch_var]
    x = torch.normal(0., 1., input_shape)
    hook = ActivationStatsHook(model, layer_names, hook_fns)
    # Only the activations are inspected, so skip building the autograd graph.
    # NOTE: the model runs in whatever train/eval mode it is currently in; in
    # training mode BatchNorm layers use batch statistics and update their
    # running statistics.
    with torch.no_grad():
        _ = model(x)
    # NOTE(review): the hooks registered by ActivationStatsHook are not detached
    # here, so they stay on `model` and fire on every later forward pass.
    return hook.stats
| "execution_count": 4, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "gradient": {}, | |
| "trusted": false | |
| }, | |
| "id": "filled-singles", | |
| "cell_type": "code", | |
# Probe a freshly created NFNet-F0 and draw one curve per recorded statistic;
# the x-axis is the order in which the hooked NormFreeBlocks ran.
m = timm.create_model('nfnet_f0')
stats = extract_spp_stats(
    m,
    layer_names=['NormFreeBlock'],
    hook_fns=[avg_sq_ch_mean, avg_ch_var],
)
for stat_name in ('avg_sq_ch_mean', 'avg_ch_var'):
    plt.plot(stats[stat_name], label=stat_name)
plt.legend();
| "execution_count": 5, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": "<Figure size 432x288 with 1 Axes>" | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "gradient": {}, | |
| "trusted": false | |
| }, | |
| "id": "foreign-damages", | |
| "cell_type": "code", | |
# Probe a freshly created torchvision ResNet-50 and draw one curve per recorded
# statistic; the x-axis is the order in which the hooked Bottlenecks ran.
m = torchvision.models.resnet50()
stats = extract_spp_stats(
    m,
    layer_names=['Bottleneck'],
    hook_fns=[avg_sq_ch_mean, avg_ch_var],
)
for stat_name in ('avg_sq_ch_mean', 'avg_ch_var'):
    plt.plot(stats[stat_name], label=stat_name)
plt.legend();
| "execution_count": 6, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": "<Figure size 432x288 with 1 Axes>" | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3", | |
| "language": "python" | |
| }, | |
| "language_info": { | |
| "name": "python", | |
| "version": "3.8.5", | |
| "mimetype": "text/x-python", | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "pygments_lexer": "ipython3", | |
| "nbconvert_exporter": "python", | |
| "file_extension": ".py" | |
| }, | |
| "gist": { | |
| "id": "", | |
| "data": { | |
| "description": "Plot SPPs", | |
| "public": true | |
| } | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment