Last active
May 16, 2020 04:06
-
-
Save ianhi/f54d5dc0066337ae4ad2cff4d3f30029 to your computer and use it in GitHub Desktop.
A class that manages a matplotlib lasso selector for an image in a Jupyter notebook.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
from matplotlib.widgets import LassoSelector | |
from matplotlib.path import Path | |
class image_lasso_selector:
    """Lasso-select regions of an RGB image inside a Jupyter notebook.

    Each completed lasso marks the enclosed pixels with 1 in ``self.mask``
    and tints them red (with opacity ``mask_alpha``) in the displayed figure.
    Requires an interactive matplotlib backend such as ipympl.
    """

    def __init__(self, img, mask_alpha=.75, figsize=(10, 10)):
        """
        img must have shape (X, Y, 3)

        Parameters
        ----------
        img : ndarray, shape (X, Y, 3)
            RGB image to annotate. Following matplotlib's ``imshow``
            convention, float images are expected in [0, 1] and integer
            images in [0, 255].
        mask_alpha : float
            Opacity of the red tint applied to selected pixels.
        figsize : tuple
            Figure size passed to ``plt.figure``.
        """
        self.img = img
        self.mask_alpha = mask_alpha

        # Build the figure with interactive mode off
        # (see https://github.com/matplotlib/matplotlib/issues/17013), then
        # restore the caller's previous interactive state -- even on error --
        # instead of unconditionally switching interactive mode on.
        was_interactive = plt.isinteractive()
        plt.ioff()
        try:
            self.fig = plt.figure(figsize=figsize)
            self.ax = self.fig.gca()
            self.displayed = self.ax.imshow(img)
        finally:
            if was_interactive:
                plt.ion()

        lineprops = {'color': 'black', 'linewidth': 1, 'alpha': 0.8}
        try:
            # Current matplotlib names the line-style argument ``props``;
            # ``lineprops`` was deprecated in 3.5 and later removed.
            self.lasso = LassoSelector(self.ax, self.onselect,
                                       props=lineprops, useblit=False)
        except TypeError:
            # Older matplotlib (< 3.5) only accepts ``lineprops``.
            self.lasso = LassoSelector(self.ax, self.onselect,
                                       lineprops=lineprops, useblit=False)
        self.lasso.set_visible(True)

        # (x, y) = (column, row) of every pixel in row-major order, so the
        # boolean result of Path.contains_points reshapes straight to the image.
        self.pix = self._pixel_coords(self.img.shape[:2])
        # 1 where a pixel has ever been selected, 0 elsewhere.
        self.mask = np.zeros(self.img.shape[:2])
        # Most recent lasso vertices and the pixels they enclosed
        # (nothing selected until the first lasso completes).
        self.verts = None
        self.indices = np.zeros(self.img.shape[:2], dtype=bool)

    @staticmethod
    def _pixel_coords(shape):
        """Return a (rows * cols, 2) array of (column, row) pixel coordinates.

        Points are listed in row-major order, matching ``reshape(shape)``.
        """
        rows, cols = shape
        col_grid, row_grid = np.meshgrid(np.arange(cols), np.arange(rows))
        return np.column_stack((col_grid.ravel(), row_grid.ravel()))

    @staticmethod
    def _blend(pixels, alpha):
        """Alpha-composite a red overlay of opacity ``alpha`` over ``pixels``.

        ``pixels`` has shape (k, 3). The red channel of the overlay is scaled
        to the image's value range: 1.0 for float images, 255 for integer
        images. Integer input is rounded and returned in its original dtype;
        float input yields a float result.
        """
        pixels = np.asarray(pixels)
        is_integer = np.issubdtype(pixels.dtype, np.integer)
        full_scale = 255.0 if is_integer else 1.0
        overlay = np.array([full_scale, 0.0, 0.0]) * alpha
        # Straight (non-premultiplied) alpha compositing:
        # https://en.wikipedia.org/wiki/Alpha_compositing#Straight_versus_premultiplied
        blended = overlay + pixels * (1 - alpha)
        if is_integer:
            blended = np.rint(blended).astype(pixels.dtype)
        return blended

    def onselect(self, verts):
        """LassoSelector callback: record the lasso and tint the enclosed pixels."""
        self.verts = verts
        path = Path(verts)
        self.indices = path.contains_points(self.pix, radius=0).reshape(self.mask.shape)
        self.draw_with_mask()

    def draw_with_mask(self):
        """Mark ``self.indices`` in ``self.mask`` and redraw those pixels tinted red."""
        array = self.displayed.get_array().data
        self.mask[self.indices] = 1
        # Blend against the untouched source image so that re-selecting
        # already-tinted pixels does not compound the tint.
        array[self.indices] = self._blend(self.img[self.indices], self.mask_alpha)
        self.displayed.set_data(array)
        self.fig.canvas.draw_idle()

    def _ipython_display_(self):
        """Display the interactive figure canvas when this object is shown in Jupyter."""
        # NOTE(review): ``display`` is not imported here; this relies on
        # IPython providing it as a builtin in the notebook environment.
        display(self.fig.canvas)
When debugging this class in a notebook, I found it necessary to use an ipywidgets.Output
widget in order to see errors raised inside functions used as matplotlib callbacks, e.g.:
from ipywidgets import Output
out = Output()

class image_lasso_selector:
    ...
    def draw_with_mask(self):
        with out:
            array = self.displayed.get_array().data
            ...

display(out)
obj = image_lasso_selector(img)
obj
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This is a stripped-down version of: https://github.com/ianhi/AC295-final-project-JWI/blob/2bacc09c06228c1eb49130ec5aaeff660f921033/lib/labelling.py#L152
Example usage:
After installing ipympl (https://github.com/matplotlib/ipympl),
run the following in a Jupyter notebook cell (`obj` must be the last line of the cell so that the notebook displays it):