Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save yrevar/432df272d3b800900ffbcbafa0ed6516 to your computer and use it in GitHub Desktop.
Save yrevar/432df272d3b800900ffbcbafa0ed6516 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from matplotlib import gridspec\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def visualize_img_activations(img, activations, cmap=\"gray\",\n",
"                              prefix=\"\", luminance_scale=(None, None),\n",
"                              fontsize=24):\n",
"    \"\"\"Plot an input image beside a grid of its activation maps.\n",
"\n",
"    The current matplotlib figure is divided into two regions: a square\n",
"    block on the left showing ``img`` and, to its right, a near-square grid\n",
"    with one cell per activation map, filled row by row.\n",
"\n",
"    Args:\n",
"        img: 2-D array-like of shape (h, w).\n",
"        activations: 3-D array-like of shape (k, ah, aw); one map per index\n",
"            along the first axis. k may be 0, in which case only the input\n",
"            image is drawn.\n",
"        cmap: Colormap handed to ``plt.imshow`` for every panel.\n",
"        prefix: Text placed before the word Activations in the grid title.\n",
"        luminance_scale: ``(vmin, vmax)`` pair applied to every panel. A\n",
"            bound left as ``None`` falls back to matplotlib's default\n",
"            per-panel autoscaling.\n",
"        fontsize: Font size of the Input label and of the grid title.\n",
"\n",
"    Raises:\n",
"        ValueError: If ``img.shape`` does not have exactly two entries or\n",
"            ``activations.shape`` does not have exactly three (raised by\n",
"            the shape unpacking below).\n",
"    \"\"\"\n",
"    # Unpacking doubles as a rank check: img must be 2-D, activations 3-D.\n",
"    _h, _w = img.shape\n",
"    k, _ah, _aw = activations.shape\n",
"    vmin, vmax = luminance_scale\n",
"\n",
"    # Activation grid: smallest near-square layout with at least k cells.\n",
"    # max(1, ...) keeps the layout valid when there are no activations.\n",
"    nrows = max(1, int(np.ceil(np.sqrt(k))))\n",
"    ncols = (k + nrows - 1) // nrows  # integer ceil(k / nrows)\n",
"\n",
"    # The input image spans a square block of grid cells; the block is\n",
"    # halved once the grid reaches 16 rows.\n",
"    img_rows = nrows if nrows < 16 else nrows // 2\n",
"    img_cols = img_rows\n",
"\n",
"    grid_shape = (nrows, ncols + img_cols)\n",
"\n",
"    # None is imshow's default for vmin/vmax, so both bounds are passed\n",
"    # unconditionally. A truthiness test on vmin would silently discard a\n",
"    # scale whose lower bound is 0.\n",
"    imshow_kwargs = dict(cmap=cmap, interpolation=None, vmin=vmin, vmax=vmax)\n",
"\n",
"    # Input image panel.\n",
"    plt.subplot2grid(grid_shape, (0, 0), rowspan=img_rows, colspan=img_cols)\n",
"    plt.imshow(img, **imshow_kwargs)\n",
"    plt.xticks([])\n",
"    plt.yticks([])\n",
"    plt.ylabel(\"Input\", fontsize=fontsize)\n",
"\n",
"    # Activation panels, row-major; the title goes on the middle column of\n",
"    # the first row.\n",
"    title_col = ncols // 2\n",
"    for a_idx in range(k):\n",
"        r, c = divmod(a_idx, ncols)\n",
"        plt.subplot2grid(grid_shape, (r, img_cols + c), rowspan=1, colspan=1)\n",
"        if r == 0 and c == title_col:\n",
"            plt.title(\"{} {:s}\".format(prefix, \"Activations\"), fontsize=fontsize)\n",
"        plt.imshow(activations[a_idx], **imshow_kwargs)\n",
"        plt.xticks([])\n",
"        plt.yticks([])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "irl_python3",
"language": "python",
"name": "irl_python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment