Created March 12, 2019 16:56
Save ycanerol/04e6a2dc64565b301c1dbd80372d4054 to your computer and use it in GitHub Desktop.
Multi-STA browser with slider
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib as mpl | |
from matplotlib.widgets import Slider | |
import iofuncs as iof | |
import plotfuncs as plf | |
def stabrowser(stas, frame_duration=None, cmap=None, centerzero=True):
    """
    Return an interactive plot to browse multiple spatiotemporal STAs
    at the same time.

    Each cell's STA is drawn in its own subplot (filled row by row). A
    slider below the grid selects which time frame is shown for all
    cells simultaneously.

    Parameters
    ----------
    stas : numpy.ndarray
        Array containing STAs. The first dimension indexes individual
        cells, the last dimension indexes time. Each ``stas[k, ..., t]``
        slice is passed to ``imshow``.
    frame_duration : float, optional
        Time between consecutive frames in seconds. When given, the figure
        title shows the time of the selected frame in milliseconds.
    cmap : str or matplotlib colormap, optional
        Colormap to use. Defaults to ``iof.config('colormap')``.
    centerzero : bool
        If True, the color limits are symmetric around zero
        (``[-max|stas|, +max|stas|]``), suitable for diverging colormaps.
        Otherwise the data minimum and maximum are used. NaN values are
        ignored in both cases.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure holding the STA grid and the slider.
    slider_t : matplotlib.widgets.Slider
        The frame slider. Keep a reference to it, otherwise it is garbage
        collected and stops responding.

    Raises
    ------
    ValueError
        If the active matplotlib backend is not a Qt backend.

    Example
    -------
    >>> print(stas.shape)  # (nrcells, xpixels, ypixels, time)
    (36, 75, 100, 40)
    >>> fig, slider = stabrowser(stas, frame_duration=1/60)

    Notes
    -----
    The slider is returned so that the caller keeps a reference to it and
    it stays interactive; the dummy variable `_` can also be used.
    """
    # Backend names may be reported as e.g. 'Qt5Agg', 'QtAgg' or 'qtagg'.
    if not mpl.get_backend().lower().startswith('qt'):
        raise ValueError('Switch to a Qt backend to see the animation.')
    if cmap is None:
        cmap = iof.config('colormap')

    if centerzero:
        vmax = np.nanmax(np.abs(stas))
        vmin = -vmax
    else:
        # NaN-aware, like the branch above, so a single NaN pixel does not
        # turn both color limits into NaN.
        vmax, vmin = np.nanmax(stas), np.nanmin(stas)
    imshowkwargs = dict(cmap=cmap, vmax=vmax, vmin=vmin)

    ncells = stas.shape[0]
    nframes = stas.shape[-1]
    rows, cols = plf.numsubplots(ncells)
    # squeeze=False keeps `axes` two-dimensional even for a 1xN or Nx1 grid.
    fig, axes = plt.subplots(rows, cols, sharex=True, sharey=True,
                             squeeze=False)

    # Start at frame 5 if the STAs are long enough, else at the last frame.
    initial_frame = min(5, nframes - 1)

    axsl = fig.add_axes([0.25, 0.05, 0.65, 0.03])
    # For the slider to remain interactive, a reference to it should
    # be kept, so it is returned by the function
    slider_t = Slider(axsl, 'Frame before spike',
                      0, nframes - 1,
                      valinit=initial_frame,
                      valstep=1,
                      valfmt='%2.0f')

    # Cell k goes to axes.flat[k] (row-major). The grid can have more axes
    # than cells; the surplus ones are left empty and hidden.
    images = []
    for k, ax in enumerate(axes.flat):
        ax.set_axis_off()
        if k < ncells:
            images.append(ax.imshow(stas[k, ..., initial_frame],
                                    **imshowkwargs))

    def settitle(frame):
        """Show the time of `frame` in ms as the figure title, if known."""
        if frame_duration is not None:
            fig.suptitle(f'{frame*frame_duration*1000:4.0f} ms')

    def update(frame):
        """Slider callback: display `frame` for every cell and redraw."""
        frame = int(frame)
        for k, im in enumerate(images):
            im.set_data(stas[k, ..., frame])
        settitle(frame)
        fig.canvas.draw_idle()

    slider_t.on_changed(update)
    # Make the title match the initially displayed frame.
    settitle(initial_frame)

    fig.tight_layout()
    fig.subplots_adjust(wspace=.01, hspace=.01)
    return fig, slider_t
if __name__ == '__main__':
    # Demo: browse the STAs of one recording. The arguments to iof.load are
    # presumably an experiment identifier and a stimulus index; confirm
    # against iofuncs.
    data = iof.load('20180802', 6)  # Frozen noise
    stas = np.array(data['stas'])
    frame_duration = data['frame_duration']
    # Bind the slider (to `_`) so it is not garbage collected and stays
    # interactive.
    fig, _ = stabrowser(stas, frame_duration)
    # Without show(), running this file as a plain script exits before the
    # window can be used; in an interactive session this returns at once.
    plt.show()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment