Created
February 28, 2023 19:24
-
-
Save manzt/b5c6c2d05a11a86862e2e68dc7a6461e to your computer and use it in GitHub Desktop.
live matplotlib views with jupyter-scatter
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "3f71b5d6-2128-42d9-a444-53203ae5a42f", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>mass</th>\n", | |
" <th>speed</th>\n", | |
" <th>pval</th>\n", | |
" <th>group</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0.843422</td>\n", | |
" <td>0.939455</td>\n", | |
" <td>0.828880</td>\n", | |
" <td>B</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>0.956246</td>\n", | |
" <td>0.708909</td>\n", | |
" <td>0.820758</td>\n", | |
" <td>A</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>0.588491</td>\n", | |
" <td>0.356468</td>\n", | |
" <td>0.664645</td>\n", | |
" <td>A</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>0.441721</td>\n", | |
" <td>0.405219</td>\n", | |
" <td>0.382708</td>\n", | |
" <td>A</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>0.598023</td>\n", | |
" <td>0.943465</td>\n", | |
" <td>0.397512</td>\n", | |
" <td>A</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" mass speed pval group\n", | |
"0 0.843422 0.939455 0.828880 B\n", | |
"1 0.956246 0.708909 0.820758 A\n", | |
"2 0.588491 0.356468 0.664645 A\n", | |
"3 0.441721 0.405219 0.382708 A\n", | |
"4 0.598023 0.943465 0.397512 A" | |
] | |
}, | |
"execution_count": 1, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%matplotlib widget\n", | |
"import matplotlib.pyplot as plt\n", | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"import seaborn as sns\n", | |
"\n", | |
"import jscatter\n", | |
"import traitlets\n", | |
"import ipywidgets\n", | |
"\n", | |
"# Create some example data: N_POINTS rows of uniform [0, 1) values.\n", | |
"# The seed is fixed so Restart and Run All regenerates the same points.\n", | |
"SEED = 42\n", | |
"N_POINTS = 500\n", | |
"rng = np.random.default_rng(SEED)\n", | |
"\n", | |
"raw = rng.random((N_POINTS, 4))\n", | |
"df = pd.DataFrame(raw, columns=['mass', 'speed', 'pval', 'group'])\n", | |
"# Turn the 4th column into a two-level label: values above 0.5 become 'B',\n", | |
"# the rest 'A' (the same split as `chr(65 + round(c))`, but vectorized).\n", | |
"df['group'] = np.where(df['group'] > 0.5, 'B', 'A')\n", | |
"df.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "2cd08802-0951-4092-a023-f0ffd82b2b70", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# Interactive scatter plot of `mass` (x axis) against `speed` (y axis).\n", | |
"scatter = jscatter.Scatter(\n", | |
"    data=df,\n", | |
"    x=\"mass\",\n", | |
"    y=\"speed\",\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "c80934c7-896c-4e31-8f12-1b8bdfecdfd4", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<traitlets.traitlets.directional_link at 0x10e0fe790>" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Wire up the Scatter plot with an additional \"piece\" of derived state:\n", | |
"# `selection_df`, the data rows of the currently selected points.\n", | |
"scatter.widget.add_traits(\n", | |
"    selection_df=traitlets.Any(allow_none=True),\n", | |
")\n", | |
"\n", | |
"\n", | |
"def selected_rows(selection):\n", | |
"    \"\"\"Return the rows of the scatter's data at the given positional indices.\"\"\"\n", | |
"    # NOTE(review): `_data` is a private jscatter attribute, presumably used instead\n", | |
"    # of `df` so the rows follow whatever data the scatter currently holds.\n", | |
"    return scatter._data.iloc[selection]\n", | |
"\n", | |
"\n", | |
"# Any time `selection` (the indices) changes, recompute `selection_df`.\n", | |
"# Keeping the link in a variable avoids echoing its repr as cell output and\n", | |
"# allows removing it later with `selection_link.unlink()`.\n", | |
"selection_link = traitlets.dlink(\n", | |
"    source=(scatter.widget, \"selection\"),\n", | |
"    target=(scatter.widget, \"selection_df\"),\n", | |
"    transform=selected_rows,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "c70d5f89-d9a6-41d8-8df6-68acb11d8870", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "6d0d3d28cb7342f8b0edc0a10d013091", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"HBox(children=(HBox(children=(VBox(children=(Button(icon='arrows', layout=Layout(width='36px'), style=ButtonSt…" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"class ECDF(traitlets.HasTraits):\n", | |
"    \"\"\"Matplotlib ECDF plot of one column that redraws whenever `data` changes.\n", | |
"\n", | |
"    Parameters\n", | |
"    ----------\n", | |
"    x : str\n", | |
"        Name of the DataFrame column to plot.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    # DataFrame to plot; `None` (the default) leaves the axes empty.\n", | |
"    data = traitlets.Any(allow_none=True)\n", | |
"\n", | |
"    def __init__(self, x: str):\n", | |
"        self.x = x\n", | |
"        # With interactive mode off the figure is not auto-displayed; its\n", | |
"        # canvas is embedded in the HBox layout instead.\n", | |
"        with plt.ioff():\n", | |
"            self.fig, self.ax = plt.subplots(1, 1)\n", | |
"        self.fig.canvas.header_visible = False\n", | |
"        super().__init__()\n", | |
"\n", | |
"    @traitlets.observe(\"data\")\n", | |
"    def _on_data_change(self, change):\n", | |
"        \"\"\"Redraw the ECDF for the new `data` value.\"\"\"\n", | |
"        self.ax.clear()\n", | |
"        data = change[\"new\"]\n", | |
"        # The trait allows None; there is no column to look up then, so the\n", | |
"        # axes are simply left empty.\n", | |
"        if data is not None:\n", | |
"            sns.ecdfplot(x=self.x, data=data, ax=self.ax)\n", | |
"            self.ax.set_title(f\"{self.x} (n = {len(data)})\")\n", | |
"        self.fig.canvas.draw_idle()\n", | |
"        self.fig.canvas.flush_events()\n", | |
"\n", | |
"\n", | |
"# One ECDF panel per column, each fed by the scatter's current selection.\n", | |
"ecdf_mass = ECDF(x=\"mass\")\n", | |
"ecdf_speed = ECDF(x=\"speed\")\n", | |
"ecdf_links = [\n", | |
"    traitlets.dlink((scatter.widget, \"selection_df\"), (panel, \"data\"))\n", | |
"    for panel in (ecdf_mass, ecdf_speed)\n", | |
"]\n", | |
"\n", | |
"for panel in (ecdf_mass, ecdf_speed):\n", | |
"    panel.fig.canvas.layout.width = \"30%\"\n", | |
"\n", | |
"ipywidgets.HBox([\n", | |
"    scatter.show(),\n", | |
"    ecdf_mass.fig.canvas,\n", | |
"    ecdf_speed.fig.canvas,\n", | |
"])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "57ff9ee8-81a0-4b2e-9529-bffa47d1a36c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "f7daf6bd-6054-401f-8b4e-ac52d0d81754", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment