Skip to content

Instantly share code, notes, and snippets.

@blink1073
Last active June 18, 2018 18:31
Show Gist options
  • Save blink1073/6ecb48889d3c7526f3c5 to your computer and use it in GitHub Desktop.
Save blink1073/6ecb48889d3c7526f3c5 to your computer and use it in GitHub Desktop.
Improved RectangleSelector and LassoSelector widgets, plus new EllipseSelector, PaintSelector, and LineSelector widgets.
import numpy as np
from matplotlib.widgets import AxesWidget
from matplotlib.patches import Rectangle, Ellipse
from matplotlib.lines import Line2D
import matplotlib.colors as mcolors
from matplotlib.transforms import blended_transform_factory
# Colormap used by PaintSelector to colour its label overlay.  The overlay
# starts as all zeros ('white') and painted labels are 1, 2, ... which map,
# via NoNorm, to the following entries.
LABELS_CMAP = mcolors.ListedColormap([
    'white',
    'red',
    'dodgerblue',
    'gold',
    'greenyellow',
    'blueviolet',
])
class SelectionWidget(AxesWidget):
    """Base class for selection widgets.

    Connects the canvas events to private dispatchers (`_press`,
    `_release`, `_onmove`, `_on_scroll`, `_on_key_press`).  These filter
    each event through `ignore` and forward the rest to the overridable
    hooks `press`, `release`, `onmove`, `on_scroll` and `on_key_press`.

    Parameters
    ----------
    ax : :class:`matplotlib.axes.Axes`
        The parent axes for the widget.
    onselect : callable
        Stored as ``self.onselect``; subclasses call it when a selection
        is made.
    useblit : bool
        Use blitting for redraws, if the canvas supports it.
    button : int, sequence of int or None
        Mouse button(s) that operate the widget; ``None`` accepts any button.
    """

    def __init__(self, ax, onselect, useblit=True, button=None):
        AxesWidget.__init__(self, ax)

        self.visible = True
        self.connect_event('motion_notify_event', self._onmove)
        self.connect_event('button_press_event', self._press)
        self.connect_event('button_release_event', self._release)
        self.connect_event('draw_event', self.update_background)
        self.connect_event('key_press_event', self._on_key_press)
        self.connect_event('scroll_event', self._on_scroll)

        # for activation / deactivation
        self.active = True

        self.background = None
        self.artists = []

        self.onselect = onselect
        self.useblit = useblit and self.canvas.supports_blit

        if button is None:
            self.validButtons = None
        elif isinstance(button, int):
            self.validButtons = [button]
        else:
            # Any other iterable of buttons (list, tuple, ...).  Previously
            # only lists were handled; anything else left `validButtons`
            # unset, which made `ignore` raise AttributeError.
            self.validButtons = list(button)

        # will save the data (position at mouseclick)
        self.eventpress = None
        # will save the data (pos. at mouserelease)
        self.eventrelease = None

    def update_background(self, event):
        """Cache the axes background for blitting (called on canvas draws)."""
        if self.useblit:
            self.background = self.canvas.copy_from_bbox(self.ax.bbox)

    def ignore(self, event):
        """Return *True* if *event* should be ignored.

        Side effects: an event without a ``button`` attribute gets
        ``button = None``.  An event that continues an ongoing drag with the
        same button, but lies outside the axes, has its ``xdata``/``ydata``
        clamped to the axes bounds.
        """
        if not self.active:
            return True

        # If canvas was locked
        if not self.canvas.widgetlock.available(self):
            return True

        if not hasattr(event, 'button'):
            event.button = None

        # Only react to mouse events triggered with a desired button.  Key
        # and scroll events carry no mouse-button number (None or
        # 'up'/'down'), so they are exempt.  Otherwise a selector restricted
        # to e.g. button=1 could never receive them.
        if (self.validButtons is not None and
                getattr(event, 'name', None) not in ('key_press_event',
                                                     'scroll_event') and
                event.button not in self.validButtons):
            return True

        # If no button was pressed yet ignore the event if it was out
        # of the axes
        if self.eventpress is None:
            return event.inaxes != self.ax

        # If a button was pressed and this event uses the same button but
        # lies outside the axes, clamp the data coordinates to the axes
        # boundaries so the drag continues along the edge.
        if event.button == self.eventpress.button and event.inaxes != self.ax:
            xdata, ydata = self.ax.transData.inverted().transform_point(
                (event.x, event.y))
            x0, x1 = self.ax.get_xbound()
            y0, y1 = self.ax.get_ybound()
            event.xdata = min(x1, max(x0, xdata))
            event.ydata = min(y1, max(y0, ydata))
            return False

        # If a button was pressed, check if the release-button is the
        # same.
        return (event.inaxes != self.ax or
                event.button != self.eventpress.button)

    def update(self):
        """Redraw the widget's artists; always returns False.

        With blitting: restore the cached background (if any), draw every
        artist in ``self.artists`` and blit the axes bbox.  Without
        blitting: request an idle redraw of the whole canvas.
        """
        if self.useblit:
            if self.background is not None:
                self.canvas.restore_region(self.background)
            for artist in self.artists:
                self.ax.draw_artist(artist)
            self.canvas.blit(self.ax.bbox)
        else:
            self.canvas.draw_idle()
        return False

    def _press(self, event):
        """Record and dispatch a button press that is not ignored."""
        if not self.ignore(event):
            self.eventpress = event
            self.press(event)

    def press(self, event):
        """Button press handler"""
        pass

    def _release(self, event):
        """Dispatch a release that ends a recorded press, then reset state."""
        if not self.ignore(event) and self.eventpress is not None:
            self.eventrelease = event
            # TODO: handle minimum size - add a clear method
            self.release(event)
            self.eventpress = None
            self.eventrelease = None

    def release(self, event):
        """Button release event"""
        pass

    def _onmove(self, event):
        """Dispatch a motion event that is not ignored."""
        if not self.ignore(event):
            self.onmove(event)

    def onmove(self, event):
        """Cursor motion event"""
        pass

    def _on_scroll(self, event):
        """Dispatch a scroll event that is not ignored."""
        if not self.ignore(event):
            self.on_scroll(event)

    def on_scroll(self, event):
        """Mouse scroll event"""
        pass

    def _on_key_press(self, event):
        """Dispatch a key press that is not ignored."""
        if not self.ignore(event):
            self.on_key_press(event)

    def on_key_press(self, event):
        """Key press event"""
        pass

    def set_active(self, active):
        """
        Use this to activate / deactivate the Selector
        from your program with a boolean parameter *active*.
        """
        self.active = active

    def get_active(self):
        """Get status of active mode (boolean variable)"""
        return self.active

    def set_visible(self, visible):
        """Set the visibility of our artists"""
        for artist in self.artists:
            artist.set_visible(visible)

    def draw_rubberband(self, x0, x1, y0, y1):
        """Draw a box using the native toolkit given data coordinates.

        Only canvases that provide a ``drawRectangle`` method can draw a
        native rubberband.  On any other canvas this is a no-op instead of
        raising AttributeError.
        """
        draw_rectangle = getattr(self.canvas, 'drawRectangle', None)
        if draw_rectangle is None:
            return
        height = self.canvas.figure.bbox.height
        x0, y0 = self.ax.transData.transform([x0, y0])
        x1, y1 = self.ax.transData.transform([x1, y1])
        # Flip y: the toolkit rectangle is measured from the figure top.
        y1 = height - y1
        y0 = height - y0
        w = abs(x1 - x0)
        h = abs(y1 - y0)
        rect = [int(val) for val in (min(x0, x1), min(y0, y1), w, h)]
        draw_rectangle(rect)
class ToolHandles(object):
    """Control handles for canvas tools.

    Parameters
    ----------
    ax : :class:`matplotlib.axes.Axes`
        Matplotlib axes where tool handles are displayed.
    x, y : 1D arrays
        Coordinates of control handles.
    marker : str
        Shape of marker used to display handle. See `matplotlib.pyplot.plot`.
    marker_props : dict
        Additional marker properties. See :class:`matplotlib.lines.Line2D`.
    useblit : bool
        Passed as ``animated`` to the marker line, i.e. whether the handles
        are drawn only by blitting.
    """

    def __init__(self, ax, x, y, marker='o', marker_props=None, useblit=True):
        self.ax = ax

        props = dict(marker=marker, markersize=7, mfc='w', ls='none',
                     alpha=0.5, visible=False)
        props.update(marker_props if marker_props is not None else {})
        self._markers = Line2D(x, y, animated=useblit, **props)
        self.ax.add_line(self._markers)
        self.artist = self._markers

    @property
    def x(self):
        return self._markers.get_xdata()

    @property
    def y(self):
        return self._markers.get_ydata()

    def set_data(self, pts, y=None):
        """Set x and y positions of handles.

        Call either as ``set_data(x, y)`` with two coordinate sequences (or
        two scalars for a single handle), or as ``set_data(pts)`` where
        *pts* unpacks into an x row and a y row (e.g. a 2 x N array).
        """
        if y is None:
            x, y = pts
        else:
            x = pts
        # A single handle (e.g. a rectangle's center) arrives as scalars;
        # promote to 1-D because newer Line2D.set_data rejects scalars.
        self._markers.set_data(np.atleast_1d(x), np.atleast_1d(y))

    def set_visible(self, val):
        self._markers.set_visible(val)

    def set_animated(self, val):
        self._markers.set_animated(val)

    def closest(self, x, y):
        """Return index and pixel distance of the handle closest to (x, y).

        *x*, *y* are pixel (display) coordinates; handle positions are
        converted from data to pixel coordinates before measuring.
        """
        # column_stack also copes with 0-d coordinates, yielding one row.
        pts = np.column_stack((self.x, self.y))
        # Transform data coordinates to pixel coordinates.
        pts = self.ax.transData.transform(pts)
        dist = np.hypot(pts[:, 0] - x, pts[:, 1] - y)
        return np.argmin(dist), np.min(dist)
class RectangleSelector(SelectionWidget):
    """Select a rectangular region of an axes.

    Press and drag to draw a new shape.  Once a shape is drawn:

    * drag a corner or edge handle to resize it;
    * drag the center handle, or drag anywhere with *alt* or *space* held,
      to move it.

    While drawing a new shape, *shift* makes it square (in pixels), and
    *control* uses the press position as the center instead of a corner.

    On release, ``onselect(eventpress, eventrelease)`` is called.  The
    events' ``xdata``/``ydata`` (and pixel ``x``/``y``) are replaced by
    the lower and upper bounds of the shape's extents.

    Parameters
    ----------
    ax : :class:`matplotlib.axes.Axes`
        Axes to draw on.
    onselect : callable
        Selection callback, see above.
    drawtype : {'patch', 'box', 'line', 'none'}
        Draw the shape as a patch ('box' is an alias), as a line, or as an
        invisible line ('none').
    minspanx, minspany : float or None
        Selections smaller than these (in *spancoords* units) do not call
        *onselect*.
    useblit : bool
        Use blitting for redraws, if the canvas supports it.
    lineprops, rectprops : dict or None
        Properties of the line / patch artist.
    spancoords : {'data', 'pixels'}
        Units of *minspanx* and *minspany*.
    button : int, sequence of int or None
        Mouse button(s) that operate the selector.
    maxdist : float
        Maximum distance in pixels at which a press grabs a handle.
    marker_props : dict or None
        Extra :class:`~matplotlib.lines.Line2D` properties for the handles.

    Raises
    ------
    ValueError
        If *spancoords* is not 'data' or 'pixels'.
    """

    _shape_klass = Rectangle

    def __init__(self, ax, onselect, drawtype='patch',
                 minspanx=None, minspany=None, useblit=True,
                 lineprops=None, rectprops=None, spancoords='data',
                 button=1, maxdist=10, marker_props=None):
        SelectionWidget.__init__(self, ax, onselect=onselect, useblit=useblit,
                                 button=button)

        self.to_draw = None
        self.visible = True
        # Set on every press; initialised so they always exist.
        self.active_handle = None
        self._extents_on_press = None

        if drawtype == 'box':  # backwards compatibility
            drawtype = 'patch'
        if drawtype == 'none':
            drawtype = 'line'  # draw a line but make it
            self.visible = False  # invisible

        if drawtype == 'patch':
            if rectprops is None:
                rectprops = dict(facecolor='white', edgecolor='black',
                                 alpha=0.5, fill=False)
            self.rectprops = rectprops
            self.to_draw = self._shape_klass(
                (0, 0), 0, 1, visible=False,
                **self._artist_props(self.rectprops))
            self.ax.add_patch(self.to_draw)
        if drawtype == 'line':
            if lineprops is None:
                lineprops = dict(color='black', linestyle='-',
                                 linewidth=2, alpha=0.5)
            self.lineprops = lineprops
            self.to_draw = Line2D([0, 0, 0, 0, 0], [0, 0, 0, 0, 0],
                                  visible=False,
                                  **self._artist_props(self.lineprops))
            self.ax.add_line(self.to_draw)

        self.minspanx = minspanx
        self.minspany = minspany

        # Raise instead of assert: assertions vanish under `python -O`.
        if spancoords not in ('data', 'pixels'):
            raise ValueError('spancoords must be "data" or "pixels"')
        self.spancoords = spancoords
        self.drawtype = drawtype

        self.maxdist = maxdist

        # Handle edge colour follows the shape's edge colour (red if none
        # is given); explicit *marker_props* take precedence.
        if rectprops is None:
            props = dict(mec='r')
        else:
            props = dict(mec=rectprops.get('edgecolor', 'r'))
        if marker_props is not None:
            props.update(marker_props)

        self._corner_order = ['NW', 'NE', 'SE', 'SW']
        xc, yc = self.corners
        self._corner_handles = ToolHandles(self.ax, xc, yc, marker_props=props,
                                           useblit=self.useblit)

        self._edge_order = ['W', 'N', 'E', 'S']
        xe, ye = self.edge_centers
        self._edge_handles = ToolHandles(self.ax, xe, ye, marker='s',
                                         marker_props=props,
                                         useblit=self.useblit)

        xc, yc = self.center
        self._center_handle = ToolHandles(self.ax, [xc], [yc], marker='s',
                                          marker_props=props,
                                          useblit=self.useblit)

        self.artists = [self.to_draw, self._center_handle.artist,
                        self._corner_handles.artist,
                        self._edge_handles.artist]

    def _artist_props(self, props):
        """Return a copy of *props* with ``animated`` defaulting to useblit.

        Animated artists are left out of normal canvas draws and painted by
        `update` instead.  This keeps the shape out of the cached blit
        background, consistent with the handles.
        """
        props = dict(props)
        props.setdefault('animated', self.useblit)
        return props

    @property
    def _rect_bbox(self):
        """Return ``(x0, y0, width, height)`` of the drawn shape."""
        if self.drawtype == 'patch':
            x0 = self.to_draw.get_x()
            y0 = self.to_draw.get_y()
            width = self.to_draw.get_width()
            height = self.to_draw.get_height()
            return x0, y0, width, height
        else:
            x, y = self.to_draw.get_data()
            x0, x1 = min(x), max(x)
            y0, y1 = min(y), max(y)
            return x0, y0, x1 - x0, y1 - y0

    @property
    def corners(self):
        """Corners of rectangle from lower left, moving clockwise."""
        x0, y0, width, height = self._rect_bbox
        xc = x0, x0 + width, x0 + width, x0
        yc = y0, y0, y0 + height, y0 + height
        return xc, yc

    @property
    def edge_centers(self):
        """Midpoint of rectangle edges from left, moving clockwise."""
        x0, y0, width, height = self._rect_bbox
        w = width / 2.
        h = height / 2.
        xe = x0, x0 + w, x0 + width, x0 + w
        ye = y0 + h, y0, y0 + h, y0 + height
        return xe, ye

    @property
    def center(self):
        """Center of rectangle"""
        x0, y0, width, height = self._rect_bbox
        return x0 + width / 2., y0 + height / 2.

    @property
    def extents(self):
        """Return (xmin, xmax, ymin, ymax)."""
        x0, y0, width, height = self._rect_bbox
        xmin, xmax = sorted([x0, x0 + width])
        ymin, ymax = sorted([y0, y0 + height])
        return xmin, xmax, ymin, ymax

    @extents.setter
    def extents(self, extents):
        # Update displayed shape
        self.draw_shape(extents)
        # Update displayed handles
        self._corner_handles.set_data(*self.corners)
        self._edge_handles.set_data(*self.edge_centers)
        self._center_handle.set_data(*self.center)
        self.set_visible(self.visible)
        # Redraw shape and handles.  Previously they were only shown through
        # the native rubberband, which needs a canvas with drawRectangle.
        self.update()
        if self.eventpress:
            self.draw_rubberband(*extents)

    def draw_shape(self, extents):
        """Place the shape artist on *extents* = (x0, x1, y0, y1)."""
        x0, x1, y0, y1 = extents
        xmin, xmax = sorted([x0, x1])
        ymin, ymax = sorted([y0, y1])

        if self.drawtype == 'patch':
            self.to_draw.set_x(xmin)
            self.to_draw.set_y(ymin)
            self.to_draw.set_width(xmax - xmin)
            self.to_draw.set_height(ymax - ymin)

        elif self.drawtype == 'line':
            self.to_draw.set_data([xmin, xmin, xmax, xmax, xmin],
                                  [ymin, ymax, ymax, ymin, ymin])

    def release(self, event):
        """Check the minimum span, then report the extents to onselect."""
        self._extents_on_press = None

        # release coordinates, button, ...
        self.eventrelease = event

        # calculate dimensions of box or line get values in the right
        # order
        if self.spancoords == 'data':
            xmin, ymin = self.eventpress.xdata, self.eventpress.ydata
            xmax, ymax = self.eventrelease.xdata, self.eventrelease.ydata
        elif self.spancoords == 'pixels':
            xmin, ymin = self.eventpress.x, self.eventpress.y
            xmax, ymax = self.eventrelease.x, self.eventrelease.y
        else:
            raise ValueError('spancoords must be "data" or "pixels"')

        if xmin > xmax:
            xmin, xmax = xmax, xmin
        if ymin > ymax:
            ymin, ymax = ymax, ymin

        spanx = xmax - xmin
        spany = ymax - ymin
        xproblems = self.minspanx is not None and spanx < self.minspanx
        yproblems = self.minspany is not None and spany < self.minspany

        # check if drawn distance (if it exists) is not too small in
        # neither x nor y-direction
        if self.drawtype in ['patch', 'line'] and (xproblems or yproblems):
            return

        # update the eventpress and eventrelease with the resulting extents
        x1, x2, y1, y2 = self.extents
        self.eventpress.xdata = x1
        self.eventpress.ydata = y1
        xy1 = self.ax.transData.transform_point([x1, y1])
        self.eventpress.x, self.eventpress.y = xy1

        self.eventrelease.xdata = x2
        self.eventrelease.ydata = y2
        xy2 = self.ax.transData.transform_point([x2, y2])
        self.eventrelease.x, self.eventrelease.y = xy2

        # call desired function
        self.onselect(self.eventpress, self.eventrelease)
        self.update()

        return False

    def press(self, event):
        """on button press event"""
        # make the drawn box/line visible get the click-coordinates,
        # button, ...
        self.set_visible(self.visible)
        self._set_active_handle(event)

        if self.active_handle is None:
            # Clear previous rectangle before drawing new rectangle.
            self.set_visible(False)
            self.update()
            self.set_visible(self.visible)

    def _set_active_handle(self, event):
        """Set active handle based on the location of the mouse event"""
        # Note: event.xdata/ydata in data coordinates, event.x/y in pixels
        c_idx, c_dist = self._corner_handles.closest(event.x, event.y)
        e_idx, e_dist = self._edge_handles.closest(event.x, event.y)
        m_idx, m_dist = self._center_handle.closest(event.x, event.y)

        if event.key in ['alt', ' ']:
            self.active_handle = 'C'
            self._extents_on_press = self.extents

        # Set active handle as closest handle, if mouse click is close enough.
        elif m_dist < self.maxdist:
            self.active_handle = 'C'
        elif c_dist > self.maxdist and e_dist > self.maxdist:
            self.active_handle = None
            return
        elif c_dist < e_dist:
            self.active_handle = self._corner_order[c_idx]
        else:
            self.active_handle = self._edge_order[e_idx]

        # Save coordinates of rectangle at the start of handle movement.
        x1, x2, y1, y2 = self.extents
        # Switch variables so that only x2 and/or y2 are updated on move.
        if self.active_handle in ['W', 'SW', 'NW']:
            x1, x2 = x2, event.xdata
        if self.active_handle in ['N', 'NW', 'NE']:
            y1, y2 = y2, event.ydata
        self._extents_on_press = x1, x2, y1, y2

    def onmove(self, event):
        """Resize, move or draw the shape while the button is held."""
        if self.eventpress is None:
            return
        key = self.eventpress.key or ''

        # resize an existing shape
        if self.active_handle and self.active_handle != 'C':
            x1, x2, y1, y2 = self._extents_on_press
            if self.active_handle in ['E', 'W'] + self._corner_order:
                x2 = event.xdata
            if self.active_handle in ['N', 'S'] + self._corner_order:
                y2 = event.ydata

        # move existing shape
        elif self.active_handle == 'C':
            x1, x2, y1, y2 = self._extents_on_press
            dx = event.xdata - self.eventpress.xdata
            dy = event.ydata - self.eventpress.ydata
            x1 += dx
            x2 += dx
            y1 += dy
            y2 += dy

        # new shape
        else:
            center = [self.eventpress.xdata, self.eventpress.ydata]
            center_pix = [self.eventpress.x, self.eventpress.y]
            dx = (event.xdata - center[0]) / 2.
            dy = (event.ydata - center[1]) / 2.

            # square shape
            if 'shift' in key:
                dx_pix = abs(event.x - center_pix[0])
                dy_pix = abs(event.y - center_pix[1])
                # Scaling to a square needs a nonzero extent along both
                # axes (dy_pix == 0 used to divide by zero below).
                if not dx_pix or not dy_pix:
                    return
                maxd = max(dx_pix, dy_pix)
                if dx_pix < maxd:
                    dx *= maxd / dx_pix
                if dy_pix < maxd:
                    dy *= maxd / dy_pix

            # from center
            if key in ('control', 'ctrl+shift'):
                dx *= 2
                dy *= 2

            # from corner
            else:
                center[0] += dx
                center[1] += dy

            x1, x2, y1, y2 = (center[0] - dx, center[0] + dx,
                              center[1] - dy, center[1] + dy)

        self.extents = x1, x2, y1, y2
class EllipseSelector(RectangleSelector):
    """Select an elliptical region of an axes.

    Interaction, handles and callback are those of `RectangleSelector`;
    the ellipse is inscribed in the rectangle described by `extents`.
    """

    _shape_klass = Ellipse

    def draw_shape(self, extents):
        """Draw the ellipse inscribed in *extents* = (x0, x1, y0, y1)."""
        x1, x2, y1, y2 = extents
        xmin, xmax = sorted([x1, x2])
        ymin, ymax = sorted([y1, y2])
        center = [x1 + (x2 - x1) / 2., y1 + (y2 - y1) / 2.]
        a = (xmax - xmin) / 2.
        b = (ymax - ymin) / 2.

        if self.drawtype == 'patch':
            self.to_draw.center = center
            self.to_draw.width = 2 * a
            self.to_draw.height = 2 * b
        else:
            # Polyline approximation: 31 points, 12 degrees apart, closed.
            rad = np.arange(31) * 12 * np.pi / 180
            x = a * np.cos(rad) + center[0]
            y = b * np.sin(rad) + center[1]
            self.to_draw.set_data(x, y)

    @property
    def _rect_bbox(self):
        """Return ``(x0, y0, width, height)`` of the ellipse's bounding box."""
        if self.drawtype == 'patch':
            x, y = self.to_draw.center
            width = self.to_draw.width
            height = self.to_draw.height
            return x - width / 2., y - height / 2., width, height
        else:
            x, y = self.to_draw.get_data()
            x0, x1 = min(x), max(x)
            y0, y1 = min(y), max(y)
            return x0, y0, x1 - x0, y1 - y0

    @property
    def geometry(self):
        """Return ``(xcenter, ycenter, width, height)`` of the ellipse."""
        x0, y0, width, height = self._rect_bbox
        # The y center is offset by half the *height* (the width was used
        # before, giving a wrong center for non-circular ellipses).
        return x0 + width / 2., y0 + height / 2., width, height
class LassoSelector(SelectionWidget):
    """Selection curve of an arbitrary shape.

    Press and drag to draw a free-hand path.  On release the path is
    closed and ``onselect(verts)`` is called with the list of ``(x, y)``
    vertices.

    Releasing with *shift* held keeps the selection open: a vertex is
    fixed and a further vertex follows the cursor while *shift* stays
    held.  A later release without *shift* (or a move with neither a
    button nor *shift*) finishes the selection.
    """

    def __init__(self, ax, onselect, useblit=True, button=None,
                 lineprops=None):
        SelectionWidget.__init__(self, ax, onselect=onselect, useblit=useblit,
                                 button=button)

        self.verts = None

        # Copy so the caller's dict is not modified.  When blitting, draw the
        # line as an animated artist so it stays out of the cached background.
        props = dict(lineprops) if lineprops is not None else {}
        props.setdefault('animated', self.useblit)
        self.line = Line2D([], [], **props)
        self.line.set_visible(False)
        self.ax.add_line(self.line)
        self.artists = [self.line]

    def press(self, event):
        """Start a new path, or continue the open one if *shift* is held."""
        if event.key != 'shift' or self.verts is None:
            self.verts = [(event.xdata, event.ydata)]
        self.line.set_visible(True)

    def finish(self, event):
        """Close the path, redraw it, report it and reset."""
        self.verts.append(self.verts[0])
        self.line.set_data(list(zip(*self.verts)))
        self.update()
        self.onselect(self.verts)
        self.verts = None

    def release(self, event):
        """Finish the path, or add a polygon vertex when *shift* is held."""
        if self.verts is None:
            # The path was already finished (e.g. by onmove); nothing to do.
            return
        self.verts.append((event.xdata, event.ydata))
        if event.key != 'shift':
            self.finish(event)
        else:
            # Extra vertex that onmove moves with the cursor while shift
            # stays held.
            self.verts.append((event.xdata, event.ydata))
            self.line.set_data(list(zip(*self.verts)))
            self.update()

    def onmove(self, event):
        """Extend the path (button held) or move the floating vertex (shift)."""
        if self.verts is None:
            return
        if event.key == 'shift':
            self.verts[-1] = (event.xdata, event.ydata)
        elif event.button:
            self.verts.append((event.xdata, event.ydata))
        else:
            return self.finish(event)
        self.line.set_data(list(zip(*self.verts)))
        self.update()
class LineSelector(SelectionWidget):
    """Select a line segment on an axes.

    Press and drag to draw a new segment.  Drag an end handle to move that
    end, or drag the center handle (or anywhere with *alt*/*space* held) to
    translate the whole segment.  The scroll wheel and the '+'/'-' keys
    change the line width.  On release ``onselect(end_points)`` is called
    with the end points as a 2 x 2 integer array (one row per point).

    Parameters
    ----------
    ax : :class:`matplotlib.axes.Axes`
        Axes to draw on.
    onselect : callable
        Called on release with `geometry`.
    useblit : bool
        Use blitting for redraws, if the canvas supports it.
    button : int, sequence of int or None
        Mouse button(s) that operate the selector.
    maxdist : float
        Maximum distance in pixels at which a press grabs a handle.
    line_props : dict or None
        Extra :class:`~matplotlib.lines.Line2D` properties for the line.
    """

    def __init__(self, ax, onselect, useblit=True, button=1,
                 maxdist=10, line_props=None):
        super(LineSelector, self).__init__(ax, onselect,
                                           useblit=useblit, button=button)

        props = dict(color='r', linewidth=1, alpha=0.4, solid_capstyle='butt')
        props.update(line_props if line_props is not None else {})
        # Animate only when actually blitting: animated artists are skipped
        # by ordinary canvas draws, so with blitting off (or unsupported)
        # an always-animated line would never appear.
        props.setdefault('animated', self.useblit)
        self.linewidth = props['linewidth']
        self.maxdist = maxdist
        self._active_pt = None

        x = (0, 0)
        y = (0, 0)
        # Stored as floats: update_data adds float offsets in place, which
        # an integer array cannot hold.
        self._end_pts = np.transpose([x, y]).astype(float)
        # Defined up front so a center drag before any line is drawn works.
        self._center = self._end_pts.mean(axis=0)

        self._line = Line2D(x, y, visible=False, **props)
        self.ax.add_line(self._line)

        # Start, center and end handles, matching the order used by the
        # end_points setter and by update_data.
        self._handles = ToolHandles(self.ax, (0, 0, 0), (0, 0, 0),
                                    useblit=self.useblit)
        self._handles.set_visible(False)
        self.artists = [self._line, self._handles.artist]

    @property
    def end_points(self):
        """End points as a 2 x 2 integer array (coordinates truncated)."""
        return self._end_pts.astype(int)

    @end_points.setter
    def end_points(self, pts):
        self._end_pts = pts = np.asarray(pts, dtype=float)
        self._line.set_data(np.transpose(pts))
        self._line.set_linewidth(self.linewidth)

        self._center = center = (pts[1] + pts[0]) / 2.
        handle_pts = np.vstack((pts[0], center, pts[1])).T
        self._handles.set_data(handle_pts)

        self.set_visible(True)
        self.update()

    def press(self, event):
        """Grab a nearby handle, or start a new segment at the click."""
        idx, px_dist = self._handles.closest(event.x, event.y)
        if px_dist < self.maxdist:
            self._active_pt = idx
        else:
            self._active_pt = None

        if event.key in ['alt', ' ']:
            self._active_pt = 1

        self.set_visible(True)
        if self._active_pt is None:
            # New segment: both ends at the click; drag moves the first end.
            self._active_pt = 0
            x, y = event.xdata, event.ydata
            self._end_pts = np.array([[x, y], [x, y]], dtype=float)

    def release(self, event):
        self._active_pt = None
        self.onselect(self.geometry)

    def onmove(self, event):
        if self._active_pt is None:
            return
        self.update_data(event.xdata, event.ydata)

    def update_data(self, x=None, y=None):
        """Move the active point to (x, y) (if given) and redraw."""
        if x is not None:
            # check for center
            if self._active_pt == 1:
                xc, yc = self._center
                xo, yo = x - xc, y - yc
                self._end_pts += [xo, yo]
            elif self._active_pt == 0:
                self._end_pts[0, :] = x, y
            else:
                self._end_pts[1, :] = x, y
        self.end_points = self._end_pts

    @property
    def geometry(self):
        return self.end_points

    def on_scroll(self, event):
        if event.button == 'up':
            self._thicken_scan_line()
        elif event.button == 'down':
            self._shrink_scan_line()

    def on_key_press(self, event):
        if event.key == '+':
            self._thicken_scan_line()
        elif event.key == '-':
            self._shrink_scan_line()

    def _thicken_scan_line(self):
        self.linewidth += 1
        self.update_data()

    def _shrink_scan_line(self):
        if self.linewidth > 1:
            self.linewidth -= 1
            self.update_data()
class PaintSelector(SelectionWidget):
    """Paint integer labels into an overlay image.

    Pressing and dragging paints the current `label` into the square
    window of half-width `radius` around the cursor.  The scroll wheel and
    the '+'/'-' keys change the radius.  On release ``onselect(overlay)``
    is called with the ``uint8`` label array.

    Parameters
    ----------
    ax : :class:`matplotlib.axes.Axes`
        Axes to paint on.
    onselect : callable
        Called on release with `geometry` (the overlay array).
    overlay_shape : tuple of int
        ``(rows, cols)`` of the label overlay.
    radius : int
        Half-width of the square brush, in pixels.
    alpha : float
        Transparency of the overlay image.
    rect_props : dict or None
        Accepted for API compatibility; currently unused.
    useblit : bool
        Use blitting for redraws, if the canvas supports it.
    button : int, sequence of int or None
        Mouse button(s) that operate the selector.
    """

    def __init__(self, ax, onselect, overlay_shape, radius=5, alpha=0.3,
                 rect_props=None, useblit=False, button=None):
        super(PaintSelector, self).__init__(ax, onselect,
                                            useblit=useblit, button=button)
        # NOTE(review): *rect_props* is not used by this implementation
        # (presumably meant for a brush-cursor patch that does not exist).
        self.alpha = alpha
        self.cmap = LABELS_CMAP
        self._overlay_plot = None
        self._radius = radius
        self._cursor = [0, 0, 0]

        # These can only be set after initializing `_cursor`
        self.shape = overlay_shape
        self.label = 1
        self.radius = radius

        # The overlay image is the only artist redrawn by `update`.
        self.artists = [self._overlay_plot]

    @property
    def label(self):
        return self._label

    @label.setter
    def label(self, value):
        if value >= self.cmap.N:
            raise ValueError('Maximum label value = %s' % (self.cmap.N - 1))
        self._label = value

    @property
    def radius(self):
        return self._radius

    @radius.setter
    def radius(self, r):
        self._radius = r
        self._width = 2 * r + 1
        [x, y, _] = self._cursor
        self._cursor = [x, y, self._width]
        self.window = CenteredWindow(r, self._shape)
        self.update()

    @property
    def overlay(self):
        return self._overlay

    @overlay.setter
    def overlay(self, image):
        self._overlay = image
        if image is None:
            # Drop the overlay artist; there is no shape to update, so stop
            # here (previously this fell through to `image.shape`).
            if self._overlay_plot is not None:
                self._overlay_plot.remove()
                self._overlay_plot = None
                self.artists = []
            self.update()
            return
        if self._overlay_plot is None:
            # Animate only when blitting so draw_idle can paint it otherwise.
            props = dict(cmap=self.cmap, alpha=self.alpha,
                         norm=mcolors.NoNorm(), animated=self.useblit)
            self._overlay_plot = self.ax.imshow(image, **props)
            self.artists = [self._overlay_plot]
        else:
            self._overlay_plot.set_data(image)
        self._shape = image.shape
        # this triggers an update
        self.radius = self._radius

    @property
    def shape(self):
        return self._shape

    @shape.setter
    def shape(self, shape):
        self._shape = shape
        if self._overlay_plot is not None:
            # Pixel centers sit on integer coordinates, so an image of
            # `cols` columns spans -0.5 .. cols - 0.5 (same for rows).
            self._overlay_plot.set_extent((-0.5, shape[1] - 0.5,
                                           shape[0] - 0.5, -0.5))
            self.radius = self._radius
        self.overlay = np.zeros(shape, dtype='uint8')

    def press(self, event):
        self.update_cursor(event.xdata, event.ydata)
        self.update_overlay(event.xdata, event.ydata)

    def release(self, event):
        self.onselect(self.geometry)

    def onmove(self, event):
        self.update_cursor(event.xdata, event.ydata)
        if not self.eventpress:
            x, y, r = self._cursor
            self.draw_rubberband(x, x + r, y, y + r)
            return
        self.update_overlay(event.xdata, event.ydata)

    def on_scroll(self, event):
        if event.button == 'up':
            self.radius += 1
        elif event.button == 'down':
            self.radius = max(self.radius - 1, 1)

    def on_key_press(self, event):
        if event.key == '+':
            self.radius += 1
        elif event.key == '-':
            self.radius = max(self.radius - 1, 1)

    def update_overlay(self, x, y):
        """Paint the current label into the window centred at (x, y).

        *x*, *y* are data coordinates and are rounded to the nearest pixel
        index, because slice bounds must be integers.
        """
        row = int(round(y))
        col = int(round(x))
        overlay = self.overlay
        # Index with a tuple: NumPy rejects a list of slices.
        overlay[tuple(self.window.at(row, col))] = self.label
        # Note that overlay calls `update`
        self.overlay = overlay

    def update_cursor(self, x, y):
        x = x - self.radius - 1
        y = y - self.radius - 1
        self._cursor = [x, y, self._width]

    @property
    def geometry(self):
        return self.overlay
class CenteredWindow(object):
    """Window that creates slices of numpy arrays over 2D windows.

    The window is the square of half-width *radius* centred on a given
    ``(row, col)``, clipped to the array bounds.

    Examples
    --------
    >>> a = np.arange(16).reshape(4, 4)
    >>> w = CenteredWindow(1, a.shape)
    >>> a[w.at(1, 1)]
    array([[ 0,  1,  2],
           [ 4,  5,  6],
           [ 8,  9, 10]])
    >>> a[w.at(0, 0)]
    array([[0, 1],
           [4, 5]])
    >>> a[w.at(4, 3)]
    array([[14, 15]])
    """

    def __init__(self, radius, array_shape):
        self.radius = radius
        self.array_shape = array_shape

    def at(self, row, col):
        """Return ``(row_slice, col_slice)`` for the window at (row, col).

        Non-integer positions are rounded to the nearest index, since slice
        bounds must be integers.  Only the first two entries of
        ``array_shape`` are used, so e.g. (rows, cols, channels) also works.
        A tuple is returned because NumPy no longer accepts a list of
        slices as a multidimensional index.
        """
        h, w = self.array_shape[:2]
        r = self.radius
        row = int(round(float(row)))
        col = int(round(float(col)))
        xmin = max(0, col - r)
        xmax = min(w, col + r + 1)
        ymin = max(0, row - r)
        ymax = min(h, row + r + 1)
        return (slice(ymin, ymax), slice(xmin, xmax))
class SpanSelector(SelectionWidget):
    """
    Select a min/max range of the x or y axes for a matplotlib Axes

    Example usage::

        ax = subplot(111)
        ax.plot(x, y)

        def onselect(vmin, vmax):
            print(vmin, vmax)
        span = SpanSelector(ax, onselect, 'horizontal')

    *onmove_callback* is an optional callback that is called on mouse
    move within the span range
    """

    def __init__(self, ax, onselect, direction, minspan=None, useblit=False,
                 rectprops=None, onmove_callback=None, button=1):
        """
        Create a span selector in *ax*.  On release, *onselect* is called
        with::

            onselect(vmin, vmax)

        *direction* must be 'horizontal' or 'vertical'

        If *minspan* is not *None*, ignore events smaller than *minspan*

        The span rectangle is drawn with *rectprops*; default::

            rectprops = dict(facecolor='red', alpha=0.5)

        Set the visible attribute to *False* to hide the span rectangle;
        the callbacks still fire (use `set_active` to disable the
        selector entirely).

        Raises ValueError if *direction* is invalid.
        """
        # Unlike the classic SpanSelector there is no extra "newaxis", and
        # being invisible does not disable events -- that is what `active`
        # is for.
        # TODO: allow the user to hold shift and move the cursor
        SelectionWidget.__init__(self, ax, onselect, button=button,
                                 useblit=useblit)

        if rectprops is None:
            rectprops = dict(facecolor='red', alpha=0.5)

        # Raise instead of assert: assertions vanish under `python -O`.
        if direction not in ['horizontal', 'vertical']:
            raise ValueError(
                'Must choose horizontal or vertical for direction')
        self.direction = direction

        self.pressv = None

        self.rectprops = rectprops
        self.onmove_callback = onmove_callback
        self.minspan = minspan

        # Needed when dragging out of axes
        self.prev = (0, 0)

        if self.direction == 'horizontal':
            trans = blended_transform_factory(self.ax.transData,
                                              self.ax.transAxes)
            w, h = 0, 1
        else:
            trans = blended_transform_factory(self.ax.transAxes,
                                              self.ax.transData)
            w, h = 1, 0
        # Copy so the caller's dict is untouched; animate when blitting so
        # the span stays out of the cached background.
        props = dict(self.rectprops)
        props.setdefault('animated', self.useblit)
        self.rect = Rectangle((0, 0), w, h,
                              transform=trans,
                              visible=False,
                              **props)
        # Always attach the patch to the axes.  With useblit it used to be
        # left detached, so it was drawn without the axes clipping.
        self.ax.add_patch(self.rect)

        self.artists = [self.rect]

    def _event_value(self, event):
        """Return *event*'s coordinate along the span direction.

        Falls back to the last position recorded by `onmove` when the event
        has no data coordinate.  ``None`` is tested explicitly so that a
        coordinate of 0 is not mistaken for a missing one.
        """
        if self.direction == 'horizontal':
            value, fallback = event.xdata, self.prev[0]
        else:
            value, fallback = event.ydata, self.prev[1]
        return fallback if value is None else value

    def press(self, event):
        """on button press event"""
        self.rect.set_visible(self.visible)
        if self.direction == 'horizontal':
            self.pressv = event.xdata
        else:
            self.pressv = event.ydata
        return False

    def release(self, event):
        """on button release event"""
        self.rect.set_visible(self.visible)
        self.update()

        vmin = self.pressv
        vmax = self._event_value(event)
        if vmin > vmax:
            vmin, vmax = vmax, vmin
        span = vmax - vmin
        if self.minspan is not None and span < self.minspan:
            return
        self.onselect(vmin, vmax)
        return False

    def onmove(self, event):
        """on motion notify event"""
        if not self.eventpress:
            return
        x, y = event.xdata, event.ydata
        self.prev = x, y
        if self.direction == 'horizontal':
            v = x
        else:
            v = y

        if self.eventpress.key in ['alt', 'shift', ' ']:
            # TODO: center the window where the cursor is
            # make sure to handle onmove_callback as well
            return

        minv, maxv = v, self.pressv
        if minv > maxv:
            minv, maxv = maxv, minv
        if self.direction == 'horizontal':
            self.rect.set_x(minv)
            self.rect.set_width(maxv - minv)
        else:
            self.rect.set_y(minv)
            self.rect.set_height(maxv - minv)
        # Redraw the span.  Previously it was only shown through the native
        # rubberband, which needs a canvas with drawRectangle.
        self.update()

        if self.onmove_callback is not None:
            vmin = self.pressv
            vmax = self._event_value(event)
            if vmin > vmax:
                vmin, vmax = vmax, vmin
            self.onmove_callback(vmin, vmax)

        if self.direction == 'horizontal':
            bound = self.ax.get_ybound()
            self.draw_rubberband(minv, maxv, bound[0], bound[1])
        else:
            bound = self.ax.get_xbound()
            self.draw_rubberband(bound[0], bound[1], minv, maxv)
        return False
if __name__ == '__main__':
    # Interactive demo: show a random image and attach one selector to it.
    import matplotlib.pyplot as plt
    from numpy.random import rand

    def onselect(*args):
        """Print whatever the selector passes to its callback."""
        print(args)

    img = rand(100, 100)
    plt.imshow(img, cmap='winter')
    ax = plt.gca()

    # To try another selector, uncomment its line (and comment out the
    # PaintSelector line below):
    # rs = RectangleSelector(ax, onselect)
    # es = EllipseSelector(ax, onselect)
    # ls = LassoSelector(ax, onselect)
    # ls = LineSelector(ax, onselect)
    # ss = SpanSelector(ax, onselect, 'horizontal')
    ps = PaintSelector(ax, onselect, img.shape)

    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment