Created
August 9, 2019 20:15
-
-
Save yrevar/765cf3456af4119c9fcabf8667dadb4f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"from matplotlib import cm as cm, pyplot as plt, gridspec as gridspec\n", | |
"from matplotlib.patches import ConnectionPatch\n", | |
"%matplotlib inline\n", | |
"\n", | |
"class MatplotlibGridDisplay:\n", | |
" \n", | |
" def __init__(self, rows, cols):\n", | |
" self.rows, self.cols = rows, cols\n", | |
" self.axes_order = {}\n", | |
" \n", | |
" @staticmethod\n", | |
" def _prepare_axis(ax):\n", | |
" \n", | |
" ax.set_xticks([])\n", | |
" ax.set_yticks([])\n", | |
" for sp in ax.spines.values():\n", | |
" sp.set_visible(False)\n", | |
" if ax.is_first_row():\n", | |
" ax.spines['top'].set_visible(True)\n", | |
" if ax.is_last_row():\n", | |
" ax.spines['bottom'].set_visible(True)\n", | |
" if ax.is_first_col():\n", | |
" ax.spines['left'].set_visible(True)\n", | |
" if ax.is_last_col():\n", | |
" ax.spines['right'].set_visible(True)\n", | |
" \n", | |
" return ax\n", | |
" \n", | |
" def _xy_to_rowcol(self, x, y):\n", | |
" \"\"\"Converts (x, y) to (row, col).\n", | |
"\n", | |
" \"\"\"\n", | |
" return self.rows - y, x - 1\n", | |
"\n", | |
" def _rowcol_to_xy(self, row, col):\n", | |
" \"\"\"Converts (row, col) to (x, y).\n", | |
"\n", | |
" \"\"\"\n", | |
" return col + 1, self.rows - row\n", | |
" \n", | |
" def connect_axes(self, fig, ax1, ax2, order=\"forward\"):\n", | |
" \n", | |
" axis_center = (0., 0.)\n", | |
" if order == \"forward\":\n", | |
" con = ConnectionPatch(xyA=axis_center, xyB=axis_center, \n", | |
" coordsA=\"data\", coordsB=\"data\",\n", | |
" axesA=ax1, axesB=ax2, color=\"red\", \n", | |
" mutation_scale=40, arrowstyle=\"->\", \n", | |
" shrinkB=5)\n", | |
" ax1.add_artist(con)\n", | |
" else:\n", | |
" con = ConnectionPatch(xyA=axis_center, xyB=axis_center, \n", | |
" coordsA=\"data\", coordsB=\"data\",\n", | |
" axesA=ax2, axesB=ax1, color=\"red\", \n", | |
" mutation_scale=40, arrowstyle=\"<-\",\n", | |
" shrinkB=5)\n", | |
" ax2.add_artist(con)\n", | |
" \n", | |
" ax1.plot(*axis_center,'ro',markersize=10)\n", | |
" ax2.plot(*axis_center,'ro',markersize=10)\n", | |
" \n", | |
" def add_trajectory(self, fig, axes_grid, traj):\n", | |
" \n", | |
" x_list, y_list = tuple(zip(*traj)) # [(x, y), ..] -> [x, ...], [y, ...]\n", | |
" for idx in range(len(x_list)-1):\n", | |
" ax1 = axes_grid[(x_list[idx], y_list[idx])]\n", | |
" ax2 = axes_grid[(x_list[idx+1], y_list[idx+1])]\n", | |
" ax1.set_zorder(-2*idx+1)\n", | |
" ax2.set_zorder(-2*idx)\n", | |
" self.connect_axes(fig, ax1, ax2)\n", | |
" \n", | |
" def add_trajectories(self, fig, axes_grid, traj_lst):\n", | |
" \n", | |
" if traj_lst is not None:\n", | |
" for traj in traj_lst:\n", | |
" self.add_trajectory(fig, axes_grid, traj)\n", | |
" \n", | |
" def render(self, data, phi_shape, traj_lst=None,\n", | |
" interpolation=\"None\", cmap=cm.viridis, vmin=None, vmax=None):\n", | |
" \n", | |
" H, W, D = data.shape\n", | |
" # Setup axes grid\n", | |
" fig = plt.figure(figsize=(W*2, H*2))\n", | |
" gs = gridspec.GridSpec(H, W)\n", | |
" gs.update(wspace=0., hspace=0., left = 0., right = 1., bottom = 0., top = 1.)\n", | |
" axes_grid = {}\n", | |
" \n", | |
" for row in range(H):\n", | |
" for col in range(W):\n", | |
" ax = plt.Subplot(fig, gs[row, col])\n", | |
" ax.imshow(data[row, col].reshape(*phi_shape), vmin=vmin, vmax=vmax)\n", | |
" fig.add_subplot(self._prepare_axis(ax))\n", | |
" axes_grid[self._rowcol_to_xy(row, col)] = ax\n", | |
" \n", | |
" self.add_trajectories(fig, axes_grid, traj_lst)\n", | |
" return fig, ax" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(<Figure size 288x288 with 4 Axes>,\n", | |
" <matplotlib.axes._subplots.AxesSubplot at 0x113c27a20>)" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAATUAAAE1CAYAAACGH3cEAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAC81JREFUeJzt3X+M1/V9wPHXV2947iZMDgGtlIJybrRyGgmCS2ytirRbFmpxXbJ1zhhnoouelCV0poJNtqaLVrbWTrKRVdPOrbOItjXcubX/LI1xI4GKVg/kl9Xy69DjvEVx8Nkf7uRQ5L7f7/H9wesej7/0m/eXvAIfn/f6fr7fL5aKogiALE5r9AAAJ5OoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpBKSyWHx5XOKFqjrVazkNxg29vRMnFio8fgFHXolV/uL4rinJHOVRS11miLy0tXVz8VY9rWv5rf6BE4he3oWraznHNefgKpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBpUYNLAQERRNHoMTkDUoAJf+5d/ix/evyqu2fy8uDUpUYMK/NnNfxrfXHhN3PVUt7g1qZZGDwCnkuK006JnzsXx9Cc+Htdufj7ueqo77lzfE3+7aGH8+8dnR5RKjR5xzBO1Ueos9sbtsTE+Fgffe2xHjI8H45LYVJrcwMmopZMVt/m9W2Ll2nXRsXvPe4/1Tp0SK69fHM90zKrV+KmVigpW5/GlicXlpatrOM6p5Y+KF+LGeCEiIoZfwkO/ow/H7PheaXbd52pWWx+Y3+gRaqZ05Ehcu/n5uHP903HktFJZcfvz9T2xdH3Pu88f9vjQ9fONRQvjW4sW1m7oU8yOrmUbiqKYO9I599Sq1FnsjRvjhSjFsRdk/P+/lyLixnghOou99R+Ouhva3H5vWVdZ99zm926Jpet7Tnj9LF3fE/N7t9Rh+lxErUq3x8ayzt1W5jlyKDduK9euK+vXW1HmOY5yT61KH4uDH/gJ+36liJgRB2NW8Xo9Rmp6ra/8stEj1NVrZ58dy//whrh868vx5Sd+GMuf/FGs+eSV8ejvLIiO3XvKun4uGnavjfKIWh10xYZGj9AU3v7XFxs9QsMUEfHR/X3xlXVPxmPzRrwtxCiIWh3cXrqm0SM0ha3L8r5R8GEu3b4j7ux+OiYODsaKJdfHY/Pmxjst/rOrJb+7VdoR40d8CVpExPYYX6+RaCJDMbtwz5548Npr4pZ5Nx0Ts96pU0Z8CVpExEtTp9R81my8UVClB+OSss59u8xz5HDp9h3xnYf+Ib75yHeje87FcdXdy+PRK+Z/YDtbef3isn69e8s8x1E2tSptKk2Oh4vZI35OzQdwx4aRNrP3e6ZjVnxj0cIRP6fmA7iVE7VR+F5pdmwuJsVtsTFmDPtGwfYYH9/2jYIxodKYDfetRQvjv2fOiBVr1x3zLudLU6fEvb5RUDXfKKBuMn2j4P0xG+0bAE/e90D85RduiM3Tzj+JU+ZS7jcKbGpQgTm7dsXSp7qr2syoD38aUIEbnvmv6J5zsZg1MX8qUIGv/MHnGz0CI/CRDiAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSET
UgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSaank8P+c8Xa88YUFtZrllPbIjx6Iv16wJF5sn9boUZrWmt9f3egRmtb0f9wf91y5NgYubm30KE3r6q6YXs65ija1M35jYnXTAIzeznIOefkJpCJqQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAqogakImpQB7++7VCU/rdo2PPHElGDOpixqi/Oe7S/que2vH44LluyK1oOHjnJU+UkalAHu24+O6Y/dCBKb1cepo+ueT32Ljor3pl4eg0my0fUoA4GOlvjzd86I877/sGKntfy+uH4yD+/ETtvm1ijyfIRNaiT7Xe0V7ytDW1pb53/azWcLBdRgzqpdFuzpVVH1KCOKtnWbGnVETWoo3K3NVta9UQN6qycbc2WVj1RgzobaVuzpY2OqEEDnGhbs6WNjqhBA3zYtmZLGz1RgwY53rZmSxs9UYMGef+2dvpBW9rJIGpNbsJbb8bpRw43egxqZGhbiyNFTH1iwJZ2EohakzpvoC/u/tn347F1X4/zB/oaPQ41MrStjdt/OCZ3v2lLOwlErckMxezhH6+KvjPPiiWLl8fOCZMbPRY1tP2O9mjdezgOLDjTlnYStDR6AN513kBf3PTcf8Sndj0XP7joiliyeHn0t7Y1eizqYKCzNfZf1RbbuiY1epQURG2ULvvVllj27ONxQf+eeOTHqyIi4uUJU+K+eZ+LDefOGvH5Yja2/ebPBqPj3n3RtuVQtP90MCIiBmeNi94V58QbV7gOqiFqo3Dzpp64dVN3RESUhj1+Qf+e+PunH4rVndfFms6Fx32umDH97/pi5qp375cOv37athyKS//41djW1R4772hvzHCnMFGr0mW/2hK3buo+5mIcMvTYrZu6Y+PkGcdsbGJGxLsb2sxVfSe8fmau6ov+ua02tgp5o6BKy559vKxzX3p2XUQc/w2Ahy79jKCNUR337ivv3FfLO8dRNrUqXdC/57g/ZYcrRcSF/bvjb376TzF399b4yfQ58eVPfjHeHHdmnDt4IM4dPFCPUZvGWc+91egRmkbblkNlXT9tvYfqMU4qolYHV72yOfadOT4u6ns1Og681uhxGuYjO20d1J6o1cGffLYrbvl5T3QceDUe/sSn44lZl8eh08fe55G+vnJ1o0doGlfN7G30CGm5p1allydMiZH+17JFRGydMDV+MWlaLP30zfEXn7opFrz2Uqx9/Gtxw4v/GeMOv1OPUWlCg7PGlXX9DHaMq8c4qYhale6b97myzt0/b/F7/yxuDOldcU555+4p7xxHiVqVNpw7K1Z3XhdFxAd+4g49trrzuuN+AFfceOOKttjW1X7C62dbV7uPc1TBPbVRWNO5MDZOnhFfenZdXNi/+73Ht06YGvfPWzziNwqG4vbb+1+JW37eEzdu/smYvuc21uy8oz3657ZGx1f3HfMu52DHuOi9xzcKqlUqipFe2R/VNmlaMft376rhOGPbUNyyvqHgjQJG4+qZvRuKopg70jkvP5vIh70s9fepQflErQkNj9sle7f7+9SgAu6pNbFfTJoWd1/5xUaPAacUmxqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiq
gBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqZSKoij/cKm0LyJ21m4ckpserh+qN70oinNGOlRR1ACanZefQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAq/wd9PCIrVuNL8gAAAABJRU5ErkJggg==\n", | |
"text/plain": [ | |
"<Figure size 288x288 with 4 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"m = MatplotlibGridDisplay(2, 2)\n", | |
"m.render(\n", | |
" data = np.array([[0., 0.5], [0.3, 0.9]])[:,:,np.newaxis],\n", | |
" phi_shape=(1,1), traj_lst=[[(1,1), (1,2), (2,2), (2,1), (1,1)]],\n", | |
" vmin=0., vmax=1.)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python [conda env:irl]", | |
"language": "python", | |
"name": "conda-env-irl-py" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
A workaround. One limitation is that it only works when `shrinkA == shrinkB`.