Last active
September 9, 2016 19:21
-
-
Save mattbierbaum/2202970fcd2d3bd2c431fdde5de522e3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pylab as pl | |
class Encirculator(object):
    """Interactive matplotlib tool for fitting circles to clicked points on an image."""

    def __init__(self, image, cmap='bone', size=14, max_number_circles=100):
        """
        A class which builds a matplotlib interface to facilitate the drawing
        of circles on an image (typically to measure the sizes and positions of
        objects). To create the circles, you label the edges with points via
        mouse clicks, and a fitted circle is drawn over the image to guide
        further point placements. Additional circles may be drawn by
        incrementing / decrementing the current circle counter.

        In order to operate quickly, there are hot keys. In particular:

            mouse left  : add point
            mouse right : remove the point nearest the click
            q : decrement to previous circle
            w : increment to next circle
            c : clear all circles
            e : switch to circle laying mode (mouse left / right click to add / sub)
            r : switch to normal matplotlib mode (zoom, pan, etc)

        Parameters:
        -----------
        image : 2D numpy array
            The image to perform circle overlays on top.

        cmap : string
            Colormap in which to display the data.

        size : float
            Width of the window in inches; the height is scaled by the image's
            rows / cols ratio so the window matches the image's aspect ratio.

        max_number_circles : int
            Number of circle slots available; the current-circle index is
            kept within ``0 .. max_number_circles - 1``.
        """
        self.max_number_circles = max_number_circles
        self.image = image
        self.shape = image.shape
        self.cmap = cmap

        # figsize is (width, height): scale the height by rows / cols so the
        # figure has the same aspect ratio as the image.
        ratio = float(self.shape[0]) / float(self.shape[1])
        self.fig = pl.figure(figsize=(size, size * ratio))
        self.ax = self.fig.add_axes([0, 0, 1, 1])

        # one list of clicked (x, y) points per circle slot
        self.points = [[] for _ in range(self.max_number_circles)]
        self.curr = 0
        self.mode = 'normal'
        # connection ids returned by mpl_connect, so they can be disconnected
        self._calls = []

        self.register_events()
        self.draw(restore_xylim=False)

    def draw(self, restore_xylim=True):
        """
        Redraw the image, every clicked point and every fitted circle.

        Parameters:
        -----------
        restore_xylim : bool
            If True, keep the current axis limits (zoom / pan) across the
            redraw, since ``cla()`` resets them.
        """
        xlim = self.ax.get_xlim()
        ylim = self.ax.get_ylim()

        # cla() removes all previously added artists (including the circles),
        # so ``ax.artists`` must not be reset by hand -- it is read-only in
        # recent matplotlib versions.
        self.ax.cla()
        self.ax.imshow(self.image, cmap=self.cmap, origin='lower',
                       interpolation='none')
        self.ax.set_xticks([])
        self.ax.set_yticks([])

        for i, lst in enumerate(self.points):
            if not lst:
                continue
            pts = np.asarray(lst)
            self.ax.plot(pts[:, 0], pts[:, 1], 'wo', ms=4)

            x, y, r = self.fit_circle(lst)
            # green outline for every fitted circle
            self.ax.add_artist(
                pl.Circle((x, y), r, edgecolor='g', facecolor='none'))
            if i == self.curr:
                # translucent red fill marks the circle currently being edited
                self.ax.add_artist(
                    pl.Circle((x, y), r, edgecolor='none', facecolor='r',
                              alpha=0.3))

        if restore_xylim:
            self.ax.set_xlim(xlim)
            self.ax.set_ylim(ylim)
        # redraw this figure specifically, not whichever one pyplot considers
        # current
        self.fig.canvas.draw_idle()

    def register_events(self):
        """
        (Re)connect the canvas callbacks appropriate for ``self.mode``.

        Previously connected callbacks are disconnected first. In 'normal'
        (circle-laying) mode both key presses and mouse clicks are handled;
        in 'nav' mode only key presses are, leaving the mouse to matplotlib's
        own navigation tools.
        """
        canvas = self.fig.canvas
        for cid in self._calls:
            canvas.mpl_disconnect(cid)
        self._calls = []

        if self.mode == 'normal':
            self._calls.append(
                canvas.mpl_connect('key_press_event', self.key_press_event))
            self._calls.append(
                canvas.mpl_connect('button_press_event', self.mouse_press))
        if self.mode == 'nav':
            self._calls.append(
                canvas.mpl_connect('key_press_event', self.key_press_event))

    def _pt(self, event):
        """Return the data coordinates of a mouse event as an (x, y) array."""
        return np.array([event.xdata, event.ydata])

    def mouse_press(self, event):
        """
        Add (left click) or remove (right click) a point of the current circle.

        Clicks that carry no data coordinates (outside the axes) are ignored
        so that no ``None`` coordinates end up in the point lists.
        """
        if event.xdata is None or event.ydata is None:
            return

        pts = self.points[self.curr]
        if event.button == 1:
            # left click: add a new point to the current circle
            pts.append(self._pt(event))
            self.draw()
        elif event.button == 3:
            # right click (matplotlib numbers it 3; 2 is the middle button):
            # remove the point of the current circle nearest to the click
            if not pts:
                return
            d2 = ((np.asarray(pts) - self._pt(event)) ** 2).sum(axis=-1)
            pts.pop(int(d2.argmin()))
            self.draw()

    def key_press_event(self, event):
        """
        Handle the hot keys (q / w / c / r / e) and redraw.

        The current-circle index is always clamped to a valid index into
        ``self.points`` afterwards.
        """
        # stored but not read elsewhere in this class; handy for inspection
        self.event = event
        key = event.key

        if key == 'q':
            # decrement the current circle
            self.curr -= 1
        elif key == 'w':
            # increment the current circle
            self.curr += 1
        elif key == 'c':
            self.points = [[] for _ in range(self.max_number_circles)]
        elif key == 'r':
            self.mode = 'nav'
            print('Mode is now "nav"')
            self.register_events()
        elif key == 'e':
            self.mode = 'normal'
            print('Mode is now "normal"')
            self.register_events()

        # valid indices are 0 .. max_number_circles - 1
        self.curr = max(min(self.curr, self.max_number_circles - 1), 0)
        self.draw()

    def fit_circle(self, points):
        """
        Least-squares fit of a circle to a set of 2D points.

        Parameters:
        -----------
        points : sequence of (x, y) pairs
            Must be non-empty.

        Returns:
        --------
        numpy array [x, y, r]
            Center and radius minimising the sum of squared differences
            between each point's distance to the center and r.
        """
        import scipy.optimize as opt

        pts = np.asarray(points, dtype=float)

        def dist2circle(params):
            x, y, r = params
            dist = np.sqrt(((pts - np.array([x, y])) ** 2).sum(axis=-1))
            return ((dist - r) ** 2).sum()

        # Start from the centroid, using the mean distance to it as the radius
        # guess. (The spread of those distances is ~0 for points evenly spread
        # around a circle, so it is not a usable radius guess.)
        com = pts.mean(axis=0)
        r0 = np.sqrt(((pts - com) ** 2).sum(axis=-1)).mean()
        x0 = np.array([com[0], com[1], r0])
        return opt.minimize(dist2circle, x0).x
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment