Skip to content

Instantly share code, notes, and snippets.

@blink1073
Last active January 3, 2016 20:45
Show Gist options
  • Save blink1073/456f8912d5079738ce1b to your computer and use it in GitHub Desktop.
Save blink1073/456f8912d5079738ce1b to your computer and use it in GitHub Desktop.
Matplotlib Image Histogram Selector
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
from matplotlib.widgets import SpanSelector
class HistSelector(object):
    """Histogram of an image's values with a draggable intensity-range span.

    The histogram is drawn horizontally in ``hist_ax`` so that its vertical
    axis corresponds to pixel values.  Dragging a vertical span on it sets
    ``vmin``/``vmax`` of the norm of the image's colorbar.

    Parameters
    ----------
    hist_ax : axes
        Axes the histogram and the span selector are drawn into.
    im : image artist
        The image whose data is histogrammed.  It must have a colorbar
        attached (``im.colorbar``) by the time a span is selected.
    callback : callable, optional
        Called as ``callback(start, end)`` after every selection.
    """

    def __init__(self, hist_ax, im, callback=None):
        self.im = im
        self.hist_ax = hist_ax
        self.canvas = hist_ax.figure.canvas
        self.callback = callback
        self.patches = []   # histogram artists from the last update()
        self.span = None    # last applied (start, end), or None if never set
        # NOTE(review): `rectprops`, `span_stays` and `set_axis_bgcolor` are
        # the names used by the matplotlib release this was written against;
        # confirm them against the matplotlib version actually in use.
        self.sel = SpanSelector(hist_ax, self.onselect,
                                'vertical', useblit=True,
                                rectprops=dict(facecolor='b', alpha=0.15),
                                span_stays=True)
        hist_ax.format_coord = self.format_coord
        hist_ax.set_axis_bgcolor(hist_ax.figure.get_facecolor())
        self.update()
        # Strip all decorations: only the histogram shape should be visible.
        hist_ax.set_xticklabels([])
        hist_ax.set_yticklabels([])
        hist_ax.xaxis.set_major_locator(plt.NullLocator())
        hist_ax.yaxis.set_major_locator(plt.NullLocator())
        for spine in hist_ax.spines.values():
            spine.set_visible(False)

    def format_coord(self, x, y):
        """Return the status-bar text for a cursor position on the histogram.

        Only ``y`` (the pixel-value axis) is reported; it is formatted as a
        float when the image data is floating point and as an integer
        otherwise.
        """
        if self.im.get_array().dtype.kind == 'f':
            return 'z=%1.4f' % y
        else:
            return 'z=%d' % y

    def onselect(self, start, end):
        """Apply the selected value range to the colorbar norm and the image.

        The two endpoints may be given in either order.  They are clamped to
        the current y-limits of the histogram axes.  A selection covering
        less than 1% or more than 99% of the axis is treated as "reset":
        the full range is applied and the selection rectangles are hidden.
        """
        # Order the endpoints in one step; assigning `start` first and then
        # deriving `end` from the already-overwritten `start` would collapse
        # a reversed selection to zero width.
        start, end = min(start, end), max(start, end)
        ylim = self.hist_ax.get_ylim()
        start = max(start, ylim[0])
        end = min(end, ylim[1])
        fraction = (end - start) / (ylim[1] - ylim[0])
        if fraction < 0.01 or fraction > 0.99:
            start, end = ylim
            self.sel.rect.set_visible(False)
            self.sel.stay_rect.set_visible(False)
        cbar = self.im.colorbar
        self.span = (start, end)
        cbar.norm.vmin = start
        cbar.norm.vmax = end
        cbar.draw_all()
        self.im.set_norm(cbar.norm)
        self.canvas.draw_idle()
        if self.callback:
            self.callback(start, end)

    def update(self):
        """Redraw the histogram from the image's current data.

        Every 10th value of the flattened image array is histogrammed into
        100 bins.  The y-limits follow the image's color limits, and any
        previously selected span is re-applied afterwards.
        """
        data = self.im.get_array().ravel()[::10]
        ax = self.hist_ax
        # Drop the artists of the previous histogram.  Artist.remove() is
        # used instead of mutating ax.patches directly, so this does not
        # depend on ax.patches being a plain mutable list.
        for patch in self.patches:
            patch.remove()
        n, bins, self.patches = ax.hist(data, color='w', alpha=0.5,
                                        bins=100, histtype='stepfilled',
                                        orientation="horizontal")
        clim = self.im.get_clim()
        ax.set_ylim(clim[0], clim[1])
        ax.set_xlim(0, max(n) * 1.1)   # 10% headroom past the tallest bin
        if self.span:
            self.onselect(*self.span)
class ImageViewer(object):
    """Grayscale image display with a colorbar and optional histogram selector.

    Parameters
    ----------
    img : array
        Image to show.  ``format_coord`` unpacks ``img.shape`` into
        (rows, cols), so a 2-D array is expected.
    ax : axes, optional
        Axes to draw into.  A new figure with a single axes is created when
        omitted.
    show_hist : bool, optional
        Whether to place a :class:`HistSelector` between the image and its
        colorbar.
    """

    def __init__(self, img, ax=None, show_hist=True):
        self.img = img
        if ax is None:
            self.fig, self.ax = plt.subplots(1, 1)
        else:
            self.ax = ax
            self.fig = ax.figure
        self.ax.format_coord = self.format_coord
        self._show_hist = show_hist
        self.hs = None    # HistSelector, present only while show_hist is on
        self.cax = None   # axes holding the colorbar
        self.reset()

    @property
    def show_hist(self):
        """Whether the histogram selector is shown; setting it rebuilds the view."""
        return self._show_hist

    @show_hist.setter
    def show_hist(self, value):
        self._show_hist = value
        self.reset()

    def reset(self):
        """Clear the image axes and rebuild image, histogram and colorbar."""
        self.ax.clear()
        # Remove the auxiliary axes of the previous layout.  The colorbar
        # axes is handled independently of the histogram so it is also
        # cleaned up when no histogram was shown.
        if self.hs is not None:
            self.fig.delaxes(self.hs.hist_ax)
            self.hs = None
        if self.cax is not None:
            self.fig.delaxes(self.cax)
            self.cax = None
        self.im = self.ax.imshow(self.img, cmap='gray',
                                 interpolation='nearest')
        divider = make_axes_locatable(self.ax)
        if self.show_hist:
            hist_ax = divider.append_axes("right", size="10%", pad=0.03)
            self.hs = HistSelector(hist_ax, self.im)
        self.cax = divider.append_axes("right", size="5%", pad=0.03)
        # Create the colorbar on the viewer's own figure rather than on
        # whichever figure happens to be current in pyplot.
        self.fig.colorbar(self.im, cax=self.cax)

    def format_coord(self, x, y):
        """Return the status-bar text for a cursor position on the image.

        The pixel under the cursor is found by rounding ``x``/``y`` to the
        nearest column/row.  Inside the image the pixel value is appended as
        ``z`` (4 decimals for floating-point images, integer otherwise);
        outside it only the coordinates are reported.
        """
        numrows, numcols = self.img.shape
        col = int(x + 0.5)
        row = int(y + 0.5)
        if 0 <= col < numcols and 0 <= row < numrows:
            z = self.img[row, col]
            if self.img.dtype.kind == 'f':
                return 'x=%d, y=%d, z=%1.4f' % (x, y, z)
            else:
                return 'x=%d, y=%d, z=%d' % (x, y, z)
        else:
            return 'x=%d, y=%d' % (x, y)

    def imshow(self, img, **kwargs):
        """Replace the displayed image data and rescale the color limits.

        The color limits are taken from the min/max of every 10th row of
        ``img``.  Extra keyword arguments are accepted but not used.
        """
        # Keep the stored image in sync so format_coord reports values of
        # the image that is actually on screen.
        self.img = img
        self.im.set_array(img)
        data = img[::10]
        self.im.set_clim(data.min(), data.max())
        # There is no histogram to refresh when show_hist is off.
        if self.hs is not None:
            self.hs.update()
        self.fig.canvas.draw_idle()
        # TODO: handle when image changes extent
if __name__ == '__main__':
    # Demo: show the scikit-image "moon" sample and, every 100 ms, replace it
    # with the same image shifted by a random integer offset in [0, 10).
    from skimage.data import moon

    def refresh_frame():
        """Timer callback: push a randomly offset copy of the sample image."""
        frame = moon()
        iv.imshow(frame + np.random.randint(10))

    iv = ImageViewer(moon())
    timer = iv.fig.canvas.new_timer(interval=100)
    timer.add_callback(refresh_frame)
    timer.start()
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment