Skip to content

Instantly share code, notes, and snippets.

@alexlib
Forked from arwoll/mask_creator.py
Last active January 6, 2021 23:25
Show Gist options
  • Save alexlib/11fb4f7ca6e3d8a69c989d0ced614e96 to your computer and use it in GitHub Desktop.
Save alexlib/11fb4f7ca6e3d8a69c989d0ced614e96 to your computer and use it in GitHub Desktop.
Tool to create a polygon mask in Matplotlib
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# updated from: https://gist.github.com/arwoll/295b6b821c91666714ecd5031144620d \n",
"# see previous versions on gist.github.com"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"Interactive tool to draw mask on an image or image-like array.\n",
"Adapted from matplotlib/examples/event_handling/poly_editor.py\n",
"\"\"\"\n",
"import numpy as np\n",
"\n",
"import matplotlib as mpl\n",
"mpl.use('tkagg')\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.patches import Polygon\n",
"# dist_point_to_segment was deprecated. Included directly below \n",
"# from matplotlib.mlab import dist_point_to_segment\n",
"from matplotlib.path import Path\n",
"\n",
"class MaskCreator(object):\n",
"    \"\"\"An interactive polygon editor used to draw a mask on an axes.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    ax : matplotlib.axes.Axes\n",
"        Axes on which the polygon is drawn (typically showing an image).\n",
"    poly_xy : list of (float, float), optional\n",
"        List of (x, y) data coordinates used as vertices of the polygon.\n",
"        If None, the rectangle returned by ``default_vertices(ax)`` is used.\n",
"    max_ds : float\n",
"        Max pixel distance to count as a vertex hit.\n",
"\n",
"    Key-bindings\n",
"    ------------\n",
"    't' : toggle vertex markers on and off. When vertex markers are on,\n",
"          you can move and delete them.\n",
"    'd' : delete the vertex under the cursor. The first/last vertex, which\n",
"          closes the polygon, cannot be deleted, and at least three\n",
"          vertices are always kept.\n",
"    'i' : insert a vertex at the cursor. You must be within max_ds of the\n",
"          line connecting two existing vertices.\n",
"\n",
"    After the figure is closed, call ``get_mask(shape)`` to rasterize\n",
"    the polygon into a boolean mask.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, ax, poly_xy=None, max_ds=10):\n",
"        self.showverts = True\n",
"        self.max_ds = max_ds\n",
"        if poly_xy is None:\n",
"            poly_xy = default_vertices(ax)\n",
"        # Both artists are animated: they are drawn by hand (blitted) in\n",
"        # draw_callback and motion_notify_callback.\n",
"        self.poly = Polygon(poly_xy, animated=True,\n",
"                            fc='y', ec='none', alpha=0.4)\n",
"\n",
"        ax.add_patch(self.poly)\n",
"        ax.set_clip_on(False)\n",
"        ax.set_title(\"Click and drag a point to move it; \"\n",
"                     \"'i' to insert; 'd' to delete.\\n\"\n",
"                     \"Close figure when done.\")\n",
"        self.ax = ax\n",
"\n",
"        x, y = zip(*self.poly.xy)\n",
"        self.line = plt.Line2D(x, y, color='none', marker='o', mfc='r',\n",
"                               alpha=0.2, animated=True)\n",
"        self._update_line()\n",
"        self.ax.add_line(self.line)\n",
"\n",
"        self.poly.add_callback(self.poly_changed)\n",
"        self._ind = None  # index of the vertex being dragged, if any\n",
"        self.background = None  # pixel buffer saved by draw_callback\n",
"\n",
"        canvas = self.poly.figure.canvas\n",
"        canvas.mpl_connect('draw_event', self.draw_callback)\n",
"        canvas.mpl_connect('button_press_event', self.button_press_callback)\n",
"        canvas.mpl_connect('button_release_event',\n",
"                           self.button_release_callback)\n",
"        canvas.mpl_connect('key_press_event', self.key_press_callback)\n",
"        canvas.mpl_connect('motion_notify_event',\n",
"                           self.motion_notify_callback)\n",
"        self.canvas = canvas\n",
"\n",
"    def get_mask(self, shape):\n",
"        \"\"\"Return a boolean mask that is True inside the polygon.\n",
"\n",
"        Parameters\n",
"        ----------\n",
"        shape : tuple of int\n",
"            Image shape; only the first two entries (rows, columns) are\n",
"            used, so ``img.shape`` of a colour image works as well.\n",
"\n",
"        Returns\n",
"        -------\n",
"        ndarray of bool, shape (rows, columns)\n",
"            Pixel (row, col) is tested at data point (x=col, y=row).\n",
"            NOTE(review): this assumes the image was drawn with imshow's\n",
"            default extent, where pixel centres sit on integer coordinates.\n",
"        \"\"\"\n",
"        h, w = shape[:2]\n",
"        y, x = np.mgrid[:h, :w]\n",
"        points = np.transpose((x.ravel(), y.ravel()))\n",
"\n",
"        # self.verts survives closing the figure (see _update_line).\n",
"        path = Path(self.verts)\n",
"        mask = path.contains_points(points)\n",
"\n",
"        return mask.reshape(h, w)\n",
"\n",
"    def poly_changed(self, poly):\n",
"        \"\"\"Artist callback, run whenever the polygon's properties change.\n",
"\n",
"        Re-applies the vertex-marker line's own visibility so that it stays\n",
"        independent of the polygon's visibility state.\n",
"        \"\"\"\n",
"        vis = self.line.get_visible()\n",
"        self.line.set_visible(vis)\n",
"\n",
"    def draw_callback(self, event):\n",
"        \"\"\"On every full redraw, cache the background and blit the artists.\"\"\"\n",
"        self.background = self.canvas.copy_from_bbox(self.ax.bbox)\n",
"        self.ax.draw_artist(self.poly)\n",
"        self.ax.draw_artist(self.line)\n",
"        self.canvas.blit(self.ax.bbox)\n",
"\n",
"    def button_press_callback(self, event):\n",
"        \"\"\"On left-click inside the axes, pick the vertex under the cursor.\"\"\"\n",
"        ignore = (not self.showverts or event.inaxes is None or\n",
"                  event.button != 1)\n",
"        if ignore:\n",
"            return\n",
"        self._ind = self.get_ind_under_cursor(event)\n",
"\n",
"    def button_release_callback(self, event):\n",
"        \"\"\"On left-button release, stop dragging the active vertex.\"\"\"\n",
"        ignore = not self.showverts or event.button != 1\n",
"        if ignore:\n",
"            return\n",
"        self._ind = None\n",
"\n",
"    def key_press_callback(self, event):\n",
"        \"\"\"Handle the 't' (toggle), 'd' (delete) and 'i' (insert) keys.\"\"\"\n",
"        if not event.inaxes:\n",
"            return\n",
"        if event.key == 't':\n",
"            self.showverts = not self.showverts\n",
"            self.line.set_visible(self.showverts)\n",
"            if not self.showverts:\n",
"                self._ind = None\n",
"        elif event.key == 'd':\n",
"            ind = self.get_ind_under_cursor(event)\n",
"            if ind is None:\n",
"                return\n",
"            # Index 0 and the last index hold the same closing vertex.\n",
"            if ind == 0 or ind == self.last_vert_ind:\n",
"                print(\"Cannot delete root node\")\n",
"                return\n",
"            # last_vert_ind equals the number of distinct vertices; keep at\n",
"            # least three so the polygon still encloses an area.\n",
"            if self.last_vert_ind <= 3:\n",
"                print(\"Cannot delete: a polygon needs at least 3 vertices\")\n",
"                return\n",
"            self.poly.xy = [tup for i, tup in enumerate(self.poly.xy)\n",
"                            if i != ind]\n",
"            self._update_line()\n",
"        elif event.key == 'i':\n",
"            # Compare in display (pixel) coordinates so max_ds is in pixels.\n",
"            xys = self.poly.get_transform().transform(self.poly.xy)\n",
"            p = event.x, event.y  # cursor coords\n",
"            for i, (s0, s1) in enumerate(zip(xys[:-1], xys[1:])):\n",
"                if dist_point_to_segment(p, s0, s1) <= self.max_ds:\n",
"                    # New vertex goes between vertices i and i + 1.\n",
"                    self.poly.xy = np.insert(np.asarray(self.poly.xy), i + 1,\n",
"                                             (event.xdata, event.ydata),\n",
"                                             axis=0)\n",
"                    self._update_line()\n",
"                    break\n",
"        self.canvas.draw()\n",
"\n",
"    def motion_notify_callback(self, event):\n",
"        \"\"\"While dragging with the left button, move the active vertex.\"\"\"\n",
"        ignore = (not self.showverts or event.inaxes is None or\n",
"                  event.button != 1 or self._ind is None)\n",
"        if ignore:\n",
"            return\n",
"        x, y = event.xdata, event.ydata\n",
"\n",
"        if self._ind == 0 or self._ind == self.last_vert_ind:\n",
"            # First and last entries are the same closing vertex: move\n",
"            # both so the polygon stays closed.\n",
"            self.poly.xy[0] = x, y\n",
"            self.poly.xy[self.last_vert_ind] = x, y\n",
"        else:\n",
"            self.poly.xy[self._ind] = x, y\n",
"        self._update_line()\n",
"\n",
"        if self.background is None:\n",
"            # No draw_event yet, so there is no cached background to\n",
"            # blit onto; request a full redraw instead.\n",
"            self.canvas.draw_idle()\n",
"            return\n",
"        self.canvas.restore_region(self.background)\n",
"        self.ax.draw_artist(self.poly)\n",
"        self.ax.draw_artist(self.line)\n",
"        self.canvas.blit(self.ax.bbox)\n",
"\n",
"    def _update_line(self):\n",
"        \"\"\"Sync the vertex markers and cached vertices with the polygon.\"\"\"\n",
"        # Save verts because the polygon gets deleted when the figure is\n",
"        # closed; get_mask reads them afterwards.\n",
"        self.verts = self.poly.xy\n",
"        self.last_vert_ind = len(self.poly.xy) - 1\n",
"        xy = np.asarray(self.poly.xy)\n",
"        # Pass x and y explicitly instead of a one-shot zip iterator.\n",
"        self.line.set_data(xy[:, 0], xy[:, 1])\n",
"\n",
"    def get_ind_under_cursor(self, event):\n",
"        \"\"\"Return the index of the vertex nearest the cursor, or None.\n",
"\n",
"        Distances are measured in display (pixel) coordinates; a vertex\n",
"        only counts when it is closer than ``max_ds``. Ties resolve to the\n",
"        lowest index.\n",
"        \"\"\"\n",
"        xy = np.asarray(self.poly.xy)\n",
"        xyt = self.poly.get_transform().transform(xy)\n",
"        d = np.hypot(xyt[:, 0] - event.x, xyt[:, 1] - event.y)\n",
"        ind = int(np.argmin(d))\n",
"        if d[ind] >= self.max_ds:\n",
"            return None\n",
"        return ind\n",
"\n",
"\n",
"def default_vertices(ax):\n",
"    \"\"\"Default to rectangle that has a quarter-width/height border.\n",
"\n",
"    Returns four (x, y) corners in data coordinates, inset from each side\n",
"    of the current axes limits by a quarter of the axis span. The inset\n",
"    follows the sign of the span, so inverted axes (e.g. the y-axis after\n",
"    imshow with origin='upper') are handled too.\n",
"    \"\"\"\n",
"    xa, xb = ax.get_xlim()\n",
"    ya, yb = ax.get_ylim()\n",
"    # True division rather than floor division (//), which gives no inset\n",
"    # at all for spans smaller than 4 data units (e.g. default 0..1 axes).\n",
"    dx = (xb - xa) / 4.0\n",
"    dy = (yb - ya) / 4.0\n",
"    x1, x2 = xa + dx, xb - dx\n",
"    y1, y2 = ya + dy, yb - dy\n",
"    return ((x1, y1), (x1, y2), (x2, y2), (x2, y1))\n",
"\n",
"\n",
"def mask_creator_demo():\n",
"    \"\"\"Demo: draw a mask on a random image, then darken what lies outside.\n",
"\n",
"    Pixels outside the drawn polygon are reduced by 100 grey levels\n",
"    (clipped to 0..255) and the result is shown in a new figure.\n",
"    \"\"\"\n",
"    img = np.random.uniform(0, 255, size=(100, 100))\n",
"    ax = plt.subplot(111)\n",
"    ax.imshow(img)\n",
"\n",
"    mc = MaskCreator(ax)\n",
"    # With a blocking backend this returns once the editor window is\n",
"    # closed, so the polygon is final when get_mask is called.\n",
"    plt.show()\n",
"\n",
"    mask = mc.get_mask(img.shape)\n",
"    img[~mask] = np.uint8(np.clip(img[~mask] - 100., 0, 255))\n",
"    plt.figure()  # keep the result off the editor's axes\n",
"    plt.imshow(img)\n",
"    plt.title('Region outside of mask is darkened')\n",
"    plt.show()\n",
"\n",
"\n",
"def dist(x, y):\n",
"    \"\"\"\n",
"    Return the Euclidean distance between two points.\n",
"\n",
"    *x* and *y* are coordinate sequences of equal length (tuples, lists\n",
"    or arrays); both are converted to float arrays first.\n",
"    \"\"\"\n",
"    d = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)\n",
"    return np.sqrt(np.dot(d, d))\n",
"\n",
"def dist_point_to_segment(p, s0, s1):\n",
"    \"\"\"\n",
"    Get the distance of a point to a segment.\n",
"\n",
"    *p*, *s0*, *s1* are *xy* sequences (tuples, lists or arrays); they\n",
"    are converted to float arrays, so plain tuples work for all three.\n",
"    A degenerate segment (s0 == s1) yields the distance from p to s0.\n",
"\n",
"    This algorithm from\n",
"    http://geomalgorithms.com/a02-_lines.html\n",
"    \"\"\"\n",
"    p = np.asarray(p, dtype=float)\n",
"    s0 = np.asarray(s0, dtype=float)\n",
"    s1 = np.asarray(s1, dtype=float)\n",
"    v = s1 - s0\n",
"    w = p - s0\n",
"    c1 = np.dot(w, v)\n",
"    if c1 <= 0:\n",
"        # p projects onto the line at or before s0.\n",
"        return dist(p, s0)\n",
"    c2 = np.dot(v, v)\n",
"    if c2 <= c1:\n",
"        # p projects onto the line at or beyond s1.\n",
"        return dist(p, s1)\n",
"    # Otherwise measure to the perpendicular foot pb inside the segment.\n",
"    b = c1 / c2\n",
"    pb = s0 + b * v\n",
"    return dist(p, pb)\n",
"\n",
"# if __name__ == '__main__':\n",
"# mask_creator_demo()\n",
"\n",
"def rgb2gray(rgb):\n",
"    \"\"\"Convert an RGB(A) image to a 2-D grayscale float array.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    rgb : array_like, shape (M, N, C) with C >= 3, or (M, N)\n",
"        Colour image; channels after the third (e.g. alpha) are ignored.\n",
"        A 2-D input is taken to be grayscale already and is returned as\n",
"        a float copy, so single-channel images can be passed as well.\n",
"\n",
"    Returns\n",
"    -------\n",
"    ndarray, shape (M, N)\n",
"        0.2989 * R + 0.5870 * G + 0.1140 * B\n",
"    \"\"\"\n",
"    rgb = np.asarray(rgb)\n",
"    if rgb.ndim == 2:\n",
"        return np.array(rgb, dtype=float)\n",
"    r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]\n",
"    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b\n",
"\n",
"    return gray"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots()\n",
"# Grayscale version of the PNG to be masked. For a quick test without\n",
"# the file, use: img = np.random.uniform(0, 255, size=(100, 100))\n",
"img = rgb2gray(plt.imread('103019596-babffb00-450c-11eb-86a3-0a3c96d54007.png'))\n",
"ax.imshow(img)\n",
"\n",
"# Edit the polygon interactively; close the window when done.\n",
"mc = MaskCreator(ax)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mask = mc.get_mask(img.shape)\n",
"# Black out (zero) every pixel inside the drawn polygon.\n",
"img[mask] = 0\n",
"plt.figure()\n",
"plt.imshow(img)\n",
"plt.title('Region inside of mask is blacked out')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:openpiv] *",
"language": "python",
"name": "conda-env-openpiv-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment