Skip to content

Instantly share code, notes, and snippets.

@davesque
Created November 1, 2017 16:54
Show Gist options
  • Save davesque/84937fcb0f8cbc1103f35f9d4923a630 to your computer and use it in GitHub Desktop.
Save davesque/84937fcb0f8cbc1103f35f9d4923a630 to your computer and use it in GitHub Desktop.
Identifying keypoints with dilation
#!/usr/bin/env python
"""Interactive demo: identifying keypoints with dilation.

Loads a 2-D grid from ``dan.csv``, raises a configurable base to the power of
each value, interpolates, Gaussian-blurs and grey-dilates the result, and
marks every pixel whose blurred value equals its dilated value. Sliders for
the base, the blur sigma and the dilation radius recompute the plots live.
"""
import os

import django
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import Button, Slider
# Import from the public ``scipy.ndimage`` namespace: the ``filters`` and
# ``morphology`` submodules are deprecated aliases in current SciPy releases.
from scipy.ndimage import gaussian_filter, grey_dilation

# Django is configured before the project import below.
# NOTE(review): presumably ``passport.numpy_utils`` needs configured Django
# settings at import time, which is why this ordering is kept -- confirm.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'passport.settings_local')
django.setup()

from passport.numpy_utils import plot_multiple, grid_interpolate  # noqa: E402
# Upsampling factor handed to ``grid_interpolate``; the blur sigma and the
# dilation radius below are multiplied by it as well.
factor = 4

# Starting values for the three tunable parameters.
default_base = 1.01
default_sigma = 0.75
default_dilate = 2

# Calculate images: each stage is derived from the previous one.
raw = np.loadtxt('dan.csv', delimiter=',')
delogged = default_base ** raw
interpolated = grid_interpolate(delogged, factor=factor)
gaussian = gaussian_filter(
    interpolated,
    sigma=default_sigma * factor,
    mode='nearest',
)

# Disk-shaped footprint: True wherever the offset from the centre lies
# within ``dil`` pixels.
dil = int(round(default_dilate * factor))
I, J = np.mgrid[-dil:dil + 1, -dil:dil + 1]
footprint = np.sqrt(I * I + J * J) <= dil
dilated = grey_dilation(gaussian, footprint=footprint)

# Calculate points
# A pixel left unchanged by the dilation is the maximum of its footprint
# neighbourhood. ``nonzero`` returns the indices per axis; they are reversed
# so the axis-1 index comes first (presumably the x coordinate for plotting).
points = np.array(np.nonzero(gaussian == dilated))[::-1]

# Per-axis ratio between the raw and interpolated index ranges, reversed to
# match ``points`` and shaped (2, 1) so it broadcasts across every point.
interp_res = np.array(interpolated.shape, dtype=float)
raw_res = np.array(raw.shape, dtype=float)
clip = np.array([1.0, 1.0])
interp_to_raw = ((raw_res - clip) / (interp_res - clip))[::-1].reshape(2, 1)
raw_points = points * interp_to_raw
# Draw plots
fig, ax = plt.subplots()

# One tuple per panel, in display order.
# NOTE(review): presumably (title, image, points, extra); the meaning of the
# trailing ``None`` is defined by ``plot_multiple`` and is not visible here.
panels = [
    ('Raw', raw, raw_points, None),
    ('De-logged', delogged, raw_points, None),
    ('Interpolated', interpolated, points, None),
    ('Gaussian', gaussian, points, None),
    ('Dilated', dilated, points, None),
]
plots = plot_multiple(panels)

# Define widgets
# Each widget gets its own small axes; the three slider strips share a
# background colour and are stacked above the reset button.
axcolor = 'lightgoldenrodyellow'
axbase = plt.axes([0.6, 0.2, 0.3, 0.03], facecolor=axcolor)
axdilate = plt.axes([0.6, 0.15, 0.3, 0.03], facecolor=axcolor)
axsigma = plt.axes([0.6, 0.1, 0.3, 0.03], facecolor=axcolor)
axreset = plt.axes([0.8, 0.025, 0.1, 0.04])

base = Slider(
    ax=axbase,
    label='Log base',
    valmin=1,
    valmax=1.03,
    valinit=default_base,
)
dilate = Slider(
    ax=axdilate,
    label='Dilation',
    valmin=0,
    valmax=5,
    valinit=default_dilate,
)
sigma = Slider(
    ax=axsigma,
    label='Sigma',
    valmin=0,
    valmax=5,
    valinit=default_sigma,
)
button = Button(
    ax=axreset,
    label='Reset',
    color=axcolor,
    hovercolor='0.975',
)
# Event handlers
def update(val):
    """Recompute the derived images and keypoints from the current sliders.

    Used as the ``on_changed`` callback of the ``base``, ``dilate`` and
    ``sigma`` sliders. ``val`` (the value passed by the slider that fired) is
    ignored: all three sliders are re-read on every call.
    """
    log_base = base.val
    blur_sigma = sigma.val * factor
    radius = int(round(dilate.val * factor))

    # Update image plots
    powered = log_base ** raw
    upsampled = grid_interpolate(powered, factor=factor)
    smoothed = gaussian_filter(upsampled, sigma=blur_sigma, mode='nearest')

    # Disk-shaped footprint of the requested radius.
    rows, cols = np.mgrid[-radius:radius + 1, -radius:radius + 1]
    disk = np.sqrt(rows ** 2 + cols ** 2) <= radius
    maxed = grey_dilation(smoothed, footprint=disk)

    # NOTE(review): the indexing implies ``plots`` is a flat sequence that
    # alternates image artist / point artists per panel -- confirm against
    # ``plot_multiple``. Slot 0 is never touched here.
    image_slots = ((2, powered), (4, upsampled), (6, smoothed), (8, maxed))
    for slot, image in image_slots:
        plots[slot].set_data(image)

    # Update point plots
    # Pixels unchanged by the dilation are the maxima of their neighbourhood.
    peaks = np.array(np.where(smoothed == maxed))[::-1, ...]
    raw_peaks = peaks * interp_to_raw

    point_slots = (
        (1, raw_peaks),
        (3, raw_peaks),
        (5, peaks),
        (7, peaks),
        (9, peaks),
    )
    for slot, coords in point_slots:
        plots[slot][0].set_data(coords[0], coords[1])

    fig.canvas.draw_idle()
# Every slider drives the same callback, registered in the original order.
for _slider in (base, dilate, sigma):
    _slider.on_changed(update)
def reset(event):
    """Reset the ``base``, ``dilate`` and ``sigma`` sliders, in that order.

    Used as the click callback of the reset button; ``event`` is ignored.
    """
    for slider in (base, dilate, sigma):
        slider.reset()
# Clicking the button invokes ``reset``.
button.on_clicked(reset)
# Display the figure with its widgets.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment