Skip to content

Instantly share code, notes, and snippets.

@cwindolf
Created May 8, 2026 14:50
Show Gist options
  • Select an option

  • Save cwindolf/7e3b08e428a6b7e30ed8024b9166e9c2 to your computer and use it in GitHub Desktop.

Select an option

Save cwindolf/7e3b08e428a6b7e30ed8024b9166e9c2 to your computer and use it in GitHub Desktop.
Spike feature scatterplot
import dartsort.vis as dartvis
import numpy as np
import matplotlib.pyplot as plt
def scatterplot_time_snip_spikes(
    t0: float,
    times_s: np.ndarray,
    amplitudes: np.ndarray,
    x_um: np.ndarray,
    depths_um: np.ndarray,
    labels: np.ndarray,
    entropy: np.ndarray,
    geom: np.ndarray,
    dt: float = 10.0,
    depth_lim: tuple[float, float] = (2500.0, 3000.0),
    amp_lim: tuple[float, float] | None = None,
    x_lim: tuple[float, float] | None = None,
    entropy_lim: tuple[float, float] = (0.0, 1.25),
    time_tick_interval: int = 5,
):
    """Scatter spike features in a small time snip.

    Draws a one-row, four-panel figure sharing the depth axis. Panels 0, 1
    and 3 are filled by ``dartvis.scatter_spike_features`` (presumably
    x/depth, amplitude/depth and time/depth, colored by ``labels`` -- confirm
    against dartsort). Panel 2 is an amplitude/depth scatter on a log x axis,
    colored by per-spike ``entropy`` with an inset colorbar.

    Parameters
    ----------
    t0
        Start of the time window; spikes with ``t0 <= times_s <= t0 + dt``
        are shown.
    times_s, amplitudes, x_um, depths_um, labels, entropy
        Per-spike arrays, all indexed by the same spike index.
    geom
        Probe geometry, forwarded to ``dartvis.recanim.draw_probe`` and drawn
        on panel 0.
    dt
        Length of the time window.
    depth_lim
        y limits applied to the shared depth axis.
    amp_lim
        Optional x limits for both amplitude panels (1 and 2).
    x_lim
        Optional x limits for panel 0.
    entropy_lim
        ``(low, high)`` range of the entropy color scale. Entropies are
        clipped to this range, and the same range defines the colorbar.
    time_tick_interval
        Spacing between x ticks on the time panel.

    Returns
    -------
    The matplotlib figure containing the four panels.
    """
    fig, axes = plt.subplots(
        nrows=1,
        ncols=4,
        sharey=True,
        gridspec_kw=dict(
            width_ratios=[1, 1, 1, 2],
            hspace=0.0,
            wspace=0.0,
        ),
        layout='constrained',
        figsize=(8, 5),
    )
    feature_cbar_ax = axes[2].inset_axes([0.7, 0.04, 0.05, 0.15])

    # Indices of the spikes that fall inside the closed window [t0, t0 + dt].
    in_window = np.flatnonzero((times_s >= t0) & (times_s <= t0 + dt))

    dartvis.scatter_spike_features(
        axes=axes[[0, 1, 3]],
        labels=labels,
        amplitudes=amplitudes,
        depths_um=depths_um,
        x=x_um,
        times_s=times_s,
        to_show=in_window,
        show_geom=False,
        s=3,
    )

    # Entropy panel. One Normalize instance is shared by the scatter and the
    # colorbar so the point colors always agree with the colorbar, whatever
    # entropy_lim is. (Feeding raw entropies to the colormap would saturate
    # everything above 1.0.)
    h0, h1 = entropy_lim
    entropy_norm = plt.Normalize(vmin=h0, vmax=h1)
    clipped_entropy = entropy[in_window].clip(h0, h1)
    # Draw in ascending entropy order so high-entropy spikes end up on top.
    draw_order = np.argsort(clipped_entropy)
    shown = in_window[draw_order]
    axes[2].scatter(
        amplitudes[shown],
        depths_um[shown],
        c=clipped_entropy[draw_order],
        cmap="viridis",
        norm=entropy_norm,
        alpha=0.5,
        s=3,
        lw=0,
    )
    axes[2].semilogx()

    # A standalone mappable keeps the colorbar opaque even though the points
    # themselves are drawn with alpha.
    entropy_scm = plt.cm.ScalarMappable(norm=entropy_norm, cmap="viridis")
    cbar = plt.colorbar(cax=feature_cbar_ax, mappable=entropy_scm)
    cbar.set_label('entropy (nats)', size='x-small')
    # NOTE(review): ticks are fixed at 0 and 1 regardless of entropy_lim.
    feature_cbar_ax.set_yticks([0, 1])
    feature_cbar_ax.tick_params(labelsize='x-small')

    dartvis.recanim.draw_probe(axes[0], geom)
    axes[0].set_ylim(depth_lim)
    axes[0].set_ylabel('depth (μm)')
    if x_lim is not None:
        axes[0].set_xlim(x_lim)
    if amp_lim is not None:
        axes[1].set_xlim(amp_lim)
        axes[2].set_xlim(amp_lim)
    axes[2].set_xlabel('amplitude (su)')

    # Time ticks: start at the first multiple of time_tick_interval at or
    # after t0 + 1 and stop short of t0 + dt - 1, keeping ticks off the edges.
    first_tick = time_tick_interval * int(np.ceil((t0 + 1) / time_tick_interval))
    # NOTE(review): bumps an odd first tick to the next even number, which
    # moves ticks off multiples of an odd interval -- confirm this is intended.
    first_tick = first_tick + (first_tick % 2)
    time_ticks = np.arange(first_tick, t0 + dt - 1, time_tick_interval)
    axes[3].set_xticks(time_ticks)

    return fig
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment