Created
May 8, 2026 14:50
-
-
Save cwindolf/7e3b08e428a6b7e30ed8024b9166e9c2 to your computer and use it in GitHub Desktop.
Spike feature scatterplot
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import dartsort.vis as dartvis | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
def scatterplot_time_snip_spikes(
    t0: float,
    times_s: np.ndarray,
    amplitudes: np.ndarray,
    x_um: np.ndarray,
    depths_um: np.ndarray,
    labels: np.ndarray,
    entropy: np.ndarray,
    geom: np.ndarray,
    dt: float = 10.0,
    depth_lim: tuple[float, float] = (2500.0, 3000.0),
    amp_lim: tuple[float, float] | None = None,
    x_lim: tuple[float, float] | None = None,
    entropy_lim: tuple[float, float] = (0.0, 1.25),
    time_tick_interval: int = 5,
):
    """Scatter spike features for the spikes in a short time window.

    Draws a 1x4 row of panels sharing the depth (y) axis. Panels 0, 1 and 3
    are filled by ``dartvis.scatter_spike_features`` (given ``labels``); the
    limits/ticks set below treat their x axes as x position, amplitude and
    time respectively. Panel 2 re-plots amplitude vs. depth on a log-x axis,
    colored by ``entropy``, with a small colorbar inset.

    Parameters
    ----------
    t0 : float
        Start of the time window, in the units of ``times_s``
        (presumably seconds).
    times_s, amplitudes, x_um, depths_um, labels, entropy : np.ndarray
        Per-spike arrays indexed in parallel.
    geom : np.ndarray
        Probe geometry, passed to ``dartvis.recanim.draw_probe`` on panel 0.
    dt : float
        Window length; spikes with ``t0 <= t <= t0 + dt`` are shown.
    depth_lim : (float, float)
        y limits (depth, μm), shared by all panels.
    amp_lim : (float, float) or None
        x limits for both amplitude panels (1 and 2); autoscaled if None.
    x_lim : (float, float) or None
        x limits for panel 0; autoscaled if None.
    entropy_lim : (float, float)
        Entropy values are clipped to this range, which also spans the
        colormap and the colorbar.
    time_tick_interval : int
        Spacing of the time ticks on panel 3. Ticks sit on multiples of this
        value from ``t0 + 1`` (inclusive) up to ``t0 + dt - 1`` (exclusive).

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the four panels.
    """
    fig, axes = plt.subplots(
        nrows=1,
        ncols=4,
        sharey=True,
        gridspec_kw=dict(width_ratios=[1, 1, 1, 2], hspace=0.0, wspace=0.0),
        layout="constrained",
        figsize=(8, 5),
    )
    feature_cbar_ax = axes[2].inset_axes([0.7, 0.04, 0.05, 0.15])

    # Indices of spikes inside the closed window [t0, t0 + dt]. An explicit
    # comparison is clearer than ``times == times.clip(...)`` and avoids
    # depending on clip's dtype handling for integer times.
    to_show = np.flatnonzero((times_s >= t0) & (times_s <= t0 + dt))

    dartvis.scatter_spike_features(
        axes=axes[[0, 1, 3]],
        labels=labels,
        amplitudes=amplitudes,
        depths_um=depths_um,
        x=x_um,
        times_s=times_s,
        to_show=to_show,
        show_geom=False,
        s=3,
    )

    # Entropy panel: sort ascending so higher-entropy points are drawn last
    # (on top).
    h0, h1 = entropy_lim
    entropy_norm = plt.Normalize(vmin=h0, vmax=h1)
    hcs = entropy[to_show].clip(h0, h1)
    order = np.argsort(hcs)
    shown = to_show[order]
    # Colormaps expect input in [0, 1]; map through the same norm as the
    # colorbar so point colors agree with it. Without this, every entropy
    # above 1 would saturate to the top color.
    axes[2].scatter(
        amplitudes[shown],
        depths_um[shown],
        c=plt.cm.viridis(entropy_norm(hcs[order])),
        alpha=0.5,
        s=3,
        lw=0,
    )
    axes[2].semilogx()

    entropy_scm = plt.cm.ScalarMappable(norm=entropy_norm, cmap="viridis")
    cbar = plt.colorbar(cax=feature_cbar_ax, mappable=entropy_scm)
    cbar.set_label("entropy (nats)", size="x-small")
    # Tick the whole-number entropies inside the range ([0, 1] for the
    # default limits); fall back to the endpoints if the range has none.
    cbar_ticks = np.arange(np.ceil(h0), np.floor(h1) + 1)
    if cbar_ticks.size == 0:
        cbar_ticks = np.array([h0, h1])
    feature_cbar_ax.set_yticks(cbar_ticks)
    feature_cbar_ax.tick_params(labelsize="x-small")

    dartvis.recanim.draw_probe(axes[0], geom)
    axes[0].set_ylim(depth_lim)
    axes[0].set_ylabel("depth (μm)")
    if x_lim is not None:
        axes[0].set_xlim(x_lim)
    if amp_lim is not None:
        axes[1].set_xlim(amp_lim)
        axes[2].set_xlim(amp_lim)
    axes[2].set_xlabel("amplitude (su)")

    # Time ticks on multiples of time_tick_interval, kept at least one time
    # unit away from the window edges. (Bumping the first tick to an even
    # value would knock it off this grid whenever the interval is odd.)
    first_tick = time_tick_interval * int(np.ceil((t0 + 1) / time_tick_interval))
    time_ticks = np.arange(first_tick, t0 + dt - 1, time_tick_interval)
    axes[3].set_xticks(time_ticks)

    return fig
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment