Skip to content

Instantly share code, notes, and snippets.

@mhhennig
Last active September 9, 2024 17:34
Show Gist options
  • Save mhhennig/03fbccce6251f453bf0c53332a94ad45 to your computer and use it in GitHub Desktop.
Save mhhennig/03fbccce6251f453bf0c53332a94ad45 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/disk/scratch/mhennig/spikeinterface/HS2/testing/../herdingspikes/__init__.py 0.4.004+git.1a5d42e47ffc\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import os, sys\n",
"sys.path.insert(0,\"../\")\n",
"sys.path.insert(0,\"../../spikeinterface_dev/src\")\n",
"\n",
"import spikeinterface.full as si\n",
"\n",
"from spikeinterface.sortingcomponents.peak_detection import detect_peaks, detect_peak_methods\n",
"from spikeinterface.sortingcomponents.peak_selection import select_peaks\n",
"from spikeinterface.sortingcomponents.peak_localization import localize_peaks, localize_peak_methods\n",
"from spikeinterface.sortingcomponents.motion import estimate_motion, InterpolateMotionRecording\n",
"from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_node_pipeline\n",
"\n",
"import herdingspikes as hs\n",
"import numpy as np\n",
"from matplotlib import pyplot as plt\n",
"si.set_global_job_kwargs(n_jobs=-1, progress_bar=False, chunk_duration=\"1s\")\n",
"\n",
"import timeit\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"\n",
"import time\n",
"\n",
"print(hs.__file__, hs.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generate data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Set to True to (re)generate and save the simulated recording in the next cell.\n",
"GENERATE_RECORDING = False\n",
"# Folder name used both when saving the recording and when loading it back.\n",
"recording_name = \"simulated_neuropixels_recording\"\n",
"# Elapsed time (time.perf_counter differences, in seconds) per benchmark label.\n",
"times = dict()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"num_units = 200\n",
"if GENERATE_RECORDING:\n",
"    # 600 s at 30 kHz on a 4-column x 96-contact probe layout.\n",
"    rec, rec_drift, gt_sorting = si.generate_drifting_recording(\n",
"        num_units=num_units,\n",
"        duration=600.0,\n",
"        sampling_frequency=30000.0,\n",
"        generate_probe_kwargs=dict(\n",
"            num_columns=4,\n",
"            num_contact_per_column=[96] * 4,\n",
"            xpitch=16,\n",
"            ypitch=40,\n",
"            y_shift_per_column=[20, 0, 20, 0],\n",
"            contact_shapes=\"square\",\n",
"            contact_shape_params={\"width\": 12},\n",
"        ),\n",
"        generate_templates_kwargs=dict(\n",
"            ms_before=1.5,\n",
"            ms_after=3.0,\n",
"            mode=\"ellipsoid\",\n",
"            unit_params=dict(\n",
"                alpha=(150.0, 500.0),\n",
"                spatial_decay=(10, 45),\n",
"            ),\n",
"        ),\n",
"        generate_unit_locations_kwargs=dict(\n",
"            margin_um=10.0,\n",
"            minimum_z=6.0,\n",
"            maximum_z=25.0,\n",
"            minimum_distance=12.0,\n",
"            max_iteration=100,\n",
"            distance_strict=False,\n",
"        ),\n",
"        generate_sorting_kwargs=dict(\n",
"            firing_rates=(0.1, 4.0),\n",
"            refractory_period_ms=4.0,\n",
"        ),\n",
"        generate_noise_kwargs=dict(\n",
"            noise_levels=(5.0, 10.0),\n",
"            spatial_decay=25.0,\n",
"        ),\n",
"        seed=42,\n",
"    )\n",
"    # Only `rec` is written to disk; `rec_drift` and `gt_sorting` stay in memory.\n",
"    # (The former trailing bare `rec` was dropped: nested inside the `if`, it was never displayed.)\n",
"    rec.save_to_folder(recording_name, folder=recording_name, overwrite=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## SpikeInterface: peak detection + center-of-mass (COM) localization"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def run_node():\n",
" job_kwargs = {}\n",
" detect_kwargs = {}\n",
" select_kwargs={}\n",
" localize_peaks_kwargs={}\n",
" estimate_motion_kwargs={}\n",
" interpolate_motion_kwargs={}\n",
"\n",
" # maybe do this directly in the folder when not None, but might be slow on external storage\n",
" gather_mode = \"memory\"\n",
" # node detect\n",
" method = detect_kwargs.pop(\"method\", \"locally_exclusive\")\n",
" method_class = detect_peak_methods[method]\n",
" node0 = method_class(rec_in_mem, **detect_kwargs)\n",
"\n",
" node1 = ExtractDenseWaveforms(rec_in_mem, parents=[node0], ms_before=0.3, ms_after=1.8)\n",
"\n",
" # node detect + localize\n",
" method = localize_peaks_kwargs.pop(\"method\", \"center_of_mass\")\n",
" method_class = localize_peak_methods[method]\n",
" node2 = method_class(\n",
" rec_in_mem, parents=[node0, node1], return_output=True, **localize_peaks_kwargs\n",
" )\n",
" pipeline_nodes = [node0, node1, node2]\n",
"\n",
" peaks, peak_locations = run_node_pipeline(\n",
" rec_in_mem,\n",
" pipeline_nodes,\n",
" job_kwargs,\n",
" job_name=\"detect and localize\",\n",
" gather_mode=gather_mode,\n",
" gather_kwargs=None,\n",
" squeeze_output=False,\n",
" folder=None,\n",
" names=None,\n",
" )\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"if 'rec_in_mem' in globals():\n",
" del(rec_in_mem)\n",
"rec_in_mem = si.load_extractor(recording_name);\n",
"t0 = time.perf_counter()\n",
"run_node()\n",
"t1 = time.perf_counter()\n",
"\n",
"times['SI detect+COM'] = t1-t0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Herdingspikes"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"parameters = {\n",
" \"chunk_size\": 32000,\n",
" \"lowpass\": True,\n",
" \"rescale\": True,\n",
" \"rescale_value\": -1280.0,\n",
" \"common_reference\": \"average\",\n",
" # \"common_reference\": \"median\",\n",
" \"spike_duration\": 1.0,\n",
" \"amp_avg_duration\": 0.4,\n",
" \"threshold\": 8.0,\n",
" \"min_avg_amp\": 1.0,\n",
" \"AHP_thr\": 0.0,\n",
" \"neighbor_radius\": 90.0,\n",
" \"inner_radius\": 70.0,\n",
" \"peak_jitter\": 0.25,\n",
" \"rise_duration\": 0.26,\n",
" \"decay_filtering\": False,\n",
" \"decay_ratio\": 1.0,\n",
" \"localize\": True,\n",
" \"save_shape\": True,\n",
" \"out_file\": \"HS2_detected\",\n",
" \"left_cutout_time\": 0.3,\n",
" \"right_cutout_time\": 1.8,\n",
" \"verbose\": False,\n",
" \"clustering_bandwidth\": 4.0,\n",
" \"clustering_alpha\": 4.5,\n",
" \"clustering_n_jobs\": -1,\n",
" \"clustering_bin_seeding\": True,\n",
" \"clustering_min_bin_freq\": 4,\n",
" \"clustering_subset\": None,\n",
" \"pca_ncomponents\": 2,\n",
" \"pca_whiten\": True,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"if 'rec_in_mem' in globals():\n",
" del(rec_in_mem)\n",
"rec_in_mem = si.load_extractor(recording_name);\n",
"t0 = time.perf_counter()\n",
"det = hs.HSDetectionLightning(rec_in_mem, parameters)\n",
"det.DetectFromRaw()\n",
"t1 = time.perf_counter()\n",
"\n",
"times['herdingspikes detect'] = t1-t0"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reading spikes from detection\n",
"Fitting dimensionality reduction using all spikes...\n",
"...projecting...\n",
"...done\n",
"Clustering...\n",
"Clustering 211677 spikes...\n",
"requested -1 cpus\n",
"using 8 cpus\n",
"number of seeds: 1912\n",
"seeds/job: 240\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.\n",
"[Parallel(n_jobs=8)]: Done 3 out of 8 | elapsed: 3.1s remaining: 5.1s\n",
"[Parallel(n_jobs=8)]: Done 8 out of 8 | elapsed: 3.3s finished\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of estimated units: 190\n"
]
}
],
"source": [
"t0 = time.perf_counter()\n",
"Clusters = hs.HSClustering(det)\n",
"Clusters.ShapePCA()\n",
"Clusters.CombinedClustering(\n",
" alpha=parameters[\"clustering_alpha\"],\n",
" bandwidth=parameters[\"clustering_bandwidth\"],\n",
" bin_seeding=parameters[\"clustering_bin_seeding\"],\n",
" min_bin_freq=parameters[\"clustering_min_bin_freq\"],\n",
" cluster_subset=parameters[\"clustering_subset\"],\n",
" n_jobs=-1)\n",
"t1 = time.perf_counter()\n",
"times['herdingspikes sort'] = t1-t0\n",
"times['herdingspikes all'] = times['herdingspikes detect']+times['herdingspikes sort']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Kilosort4"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_179427/2912753937.py:7: DeprecationWarning: `output_folder` is deprecated and will be removed in version 0.103.0 Please use folder instead\n",
" ks_sorting = si.run_sorter(\n",
"INFO:kilosort.io:========================================\n",
"INFO:kilosort.io:Loading recording with SpikeInterface...\n",
"INFO:kilosort.io:number of samples: 18000000\n",
"INFO:kilosort.io:number of channels: 384\n",
"INFO:kilosort.io:numbef of segments: 1\n",
"INFO:kilosort.io:sampling rate: 30000.0\n",
"INFO:kilosort.io:dtype: float32\n",
"INFO:kilosort.io:========================================\n",
"INFO:kilosort.run_kilosort: \n",
"INFO:kilosort.run_kilosort:Computing preprocessing variables.\n",
"INFO:kilosort.run_kilosort:----------------------------------------\n",
"INFO:kilosort.run_kilosort:N samples: 18000000\n",
"INFO:kilosort.run_kilosort:N seconds: 600.0\n",
"INFO:kilosort.run_kilosort:N batches: 300\n",
"INFO:kilosort.run_kilosort:Preprocessing filters computed in 2.11s; total 2.11s\n",
"INFO:kilosort.run_kilosort: \n",
"INFO:kilosort.run_kilosort:Computing drift correction.\n",
"INFO:kilosort.run_kilosort:----------------------------------------\n",
"INFO:kilosort.datashift:nblocks = 0, skipping drift correction\n",
"INFO:kilosort.run_kilosort:drift computed in 0.00s; total 2.17s\n",
"INFO:kilosort.run_kilosort: \n",
"INFO:kilosort.run_kilosort:Extracting spikes using templates\n",
"INFO:kilosort.run_kilosort:----------------------------------------\n",
"INFO:kilosort.spikedetect:Re-computing universal templates from data.\n",
"100%|██████████| 300/300 [05:51<00:00, 1.17s/it]\n",
"INFO:kilosort.run_kilosort:217572 spikes extracted in 354.78s; total 356.95s\n",
"INFO:kilosort.run_kilosort: \n",
"INFO:kilosort.run_kilosort:First clustering\n",
"INFO:kilosort.run_kilosort:----------------------------------------\n",
"100%|██████████| 96/96 [00:32<00:00, 3.00it/s]\n",
"INFO:kilosort.run_kilosort:320 clusters found, in 32.55s; total 389.50s\n",
"INFO:kilosort.run_kilosort: \n",
"INFO:kilosort.run_kilosort:Extracting spikes using cluster waveforms\n",
"INFO:kilosort.run_kilosort:----------------------------------------\n",
"100%|██████████| 300/300 [00:46<00:00, 6.39it/s]\n",
"INFO:kilosort.run_kilosort:235646 spikes extracted in 47.07s; total 436.57s\n",
"INFO:kilosort.run_kilosort: \n",
"INFO:kilosort.run_kilosort:Final clustering\n",
"INFO:kilosort.run_kilosort:----------------------------------------\n",
"100%|██████████| 96/96 [00:26<00:00, 3.66it/s]\n",
"INFO:kilosort.run_kilosort:187 clusters found, in 26.22s; total 462.79s\n",
"INFO:kilosort.run_kilosort: \n",
"INFO:kilosort.run_kilosort:Merging clusters\n",
"INFO:kilosort.run_kilosort:----------------------------------------\n",
"INFO:kilosort.run_kilosort:177 units found, in 0.23s; total 463.03s\n",
"INFO:kilosort.run_kilosort: \n",
"INFO:kilosort.run_kilosort:Saving to phy and computing refractory periods\n",
"INFO:kilosort.run_kilosort:----------------------------------------\n",
"INFO:kilosort.run_kilosort:170 units found with good refractory periods\n",
"INFO:kilosort.run_kilosort:Total runtime: 464.79s = 00:07:45 h:m:s\n",
"INFO:kilosort.run_kilosort:Sorting output saved in: /disk/scratch/mhennig/spikeinterface/HS2/testing/KS4_output/sorter_output.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"kilosort4 run time 465.66s\n"
]
}
],
"source": [
"if 'rec_in_mem' in globals():\n",
" del(rec_in_mem)\n",
"params = si.Kilosort4Sorter.default_params()\n",
"params.update({'nblocks': 0})\n",
"rec_in_mem = si.load_extractor(recording_name);\n",
"t0 = time.perf_counter()\n",
"ks_sorting = si.run_sorter(\n",
" sorter_name=\"kilosort4\",\n",
" recording=rec_in_mem,\n",
" output_folder=\"KS4_output\",\n",
" remove_existing_folder=True,\n",
" singularity_image=False,\n",
" verbose=True,\n",
" delete_container_files=False,\n",
" delete_output_folder=True,\n",
" **params\n",
")\n",
"t1 = time.perf_counter()\n",
"times['Kilosort4 all'] = t1-t0"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'SI detect+COM': 26.24065605085343,\n",
" 'herdingspikes detect': 15.178689972963184,\n",
" 'herdingspikes sort': 5.46234157984145,\n",
" 'herdingspikes all': 20.641031552804634,\n",
" 'Kilosort4 all': 468.0466902530752}"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"times"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"sns.barplot(data=times);\n",
"plt.ylabel(\"Execution time (sec/600s data)\");\n",
"plt.xticks(rotation=30);"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:si]",
"language": "python",
"name": "conda-env-si-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment