Last active
September 9, 2024 17:34
-
-
Save mhhennig/03fbccce6251f453bf0c53332a94ad45 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"/disk/scratch/mhennig/spikeinterface/HS2/testing/../herdingspikes/__init__.py 0.4.004+git.1a5d42e47ffc\n" | |
] | |
} | |
], | |
"source": [ | |
"%load_ext autoreload\n", | |
"%autoreload 2\n", | |
"\n", | |
"import os, sys\n", | |
"sys.path.insert(0,\"../\")\n", | |
"sys.path.insert(0,\"../../spikeinterface_dev/src\")\n", | |
"\n", | |
"import spikeinterface.full as si\n", | |
"\n", | |
"from spikeinterface.sortingcomponents.peak_detection import detect_peaks, detect_peak_methods\n", | |
"from spikeinterface.sortingcomponents.peak_selection import select_peaks\n", | |
"from spikeinterface.sortingcomponents.peak_localization import localize_peaks, localize_peak_methods\n", | |
"from spikeinterface.sortingcomponents.motion import estimate_motion, InterpolateMotionRecording\n", | |
"from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_node_pipeline\n", | |
"\n", | |
"import herdingspikes as hs\n", | |
"import numpy as np\n", | |
"from matplotlib import pyplot as plt\n", | |
"si.set_global_job_kwargs(n_jobs=-1, progress_bar=False, chunk_duration=\"1s\")\n", | |
"\n", | |
"import timeit\n", | |
"import pandas as pd\n", | |
"import seaborn as sns\n", | |
"\n", | |
"import time\n", | |
"\n", | |
"print(hs.__file__, hs.__version__)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Generate (or load) the simulated recording" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Set to True to (re)generate the simulated recording and save it under `recording_name`;\n", | |
"# otherwise the previously saved copy is loaded from that folder by the cells below.\n", | |
"GENERATE_RECORDING = False\n", | |
"recording_name = \"simulated_neuropixels_recording\"\n", | |
"times = dict()  # run time in seconds per method, filled in by the timing cells below" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Optionally simulate a 600 s, 384-channel (4 x 96) probe recording with ground-truth units\n", | |
"# and save it to disk; all later cells reload it from `recording_name`.\n", | |
"num_units = 200\n", | |
"if GENERATE_RECORDING:\n", | |
"    # generate_drifting_recording returns three objects; only the first (`rec`) is saved,\n", | |
"    # `rec_drift` is unused -- presumably `rec` is the drift-free variant (TODO confirm).\n", | |
"    rec, rec_drift, gt_sorting = si.generate_drifting_recording(\n", | |
"        num_units=num_units,\n", | |
"        duration=600.,\n", | |
"        sampling_frequency=30000.0,\n", | |
"        generate_probe_kwargs=dict(\n", | |
"            num_columns=4,\n", | |
"            num_contact_per_column=[96] * 4,\n", | |
"            xpitch=16,\n", | |
"            ypitch=40,\n", | |
"            y_shift_per_column=[20, 0, 20, 0],\n", | |
"            contact_shapes=\"square\",\n", | |
"            contact_shape_params={\"width\": 12},\n", | |
"        ),\n", | |
"        generate_templates_kwargs=dict(\n", | |
"            ms_before=1.5,\n", | |
"            ms_after=3.0,\n", | |
"            mode=\"ellipsoid\",\n", | |
"            unit_params=dict(\n", | |
"                alpha=(150.0, 500.0),\n", | |
"                spatial_decay=(10, 45),\n", | |
"            ),\n", | |
"        ),\n", | |
"        generate_unit_locations_kwargs=dict(\n", | |
"            margin_um=10.0,\n", | |
"            minimum_z=6.0,\n", | |
"            maximum_z=25.0,\n", | |
"            minimum_distance=12.0,\n", | |
"            max_iteration=100,\n", | |
"            distance_strict=False,\n", | |
"        ),\n", | |
"        generate_sorting_kwargs=dict(\n", | |
"            firing_rates=(0.1, 4.0),\n", | |
"            refractory_period_ms=4.0\n", | |
"        ),\n", | |
"        generate_noise_kwargs=dict(\n", | |
"            noise_levels=(5.0, 10.0),\n", | |
"            spatial_decay=25.0\n", | |
"        ),\n", | |
"        seed=42\n", | |
"    )\n", | |
"    rec.save_to_folder(\n", | |
"        recording_name, folder=recording_name, overwrite=True\n", | |
"    )\n", | |
"    # A bare `rec` inside an `if` block is never auto-displayed by IPython\n", | |
"    # (only a top-level final expression is), so show the summary explicitly.\n", | |
"    display(rec)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## SpikeInterface: peak detection + center-of-mass (COM) localization" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def run_node(recording=None, job_kwargs=None, detect_kwargs=None, localize_peaks_kwargs=None):\n", | |
"    \"\"\"Detect and localize peaks with a SpikeInterface node pipeline.\n", | |
"\n", | |
"    The pipeline has three nodes: peak detection (default method\n", | |
"    \"locally_exclusive\"), dense waveform extraction (0.3 ms before and\n", | |
"    1.8 ms after each peak) and peak localization (default method\n", | |
"    \"center_of_mass\"), which takes both previous nodes as parents.\n", | |
"    Results are gathered in memory.\n", | |
"\n", | |
"    Parameters\n", | |
"    ----------\n", | |
"    recording : SpikeInterface recording, optional\n", | |
"        Recording to process. If None, the notebook-global ``rec_in_mem`` is\n", | |
"        used (looked up at call time), matching the original zero-argument call.\n", | |
"    job_kwargs : dict, optional\n", | |
"        Passed to ``run_node_pipeline``; empty by default.\n", | |
"    detect_kwargs : dict, optional\n", | |
"        Options for the detection node. An optional \"method\" key selects the\n", | |
"        detector from ``detect_peak_methods``.\n", | |
"    localize_peaks_kwargs : dict, optional\n", | |
"        Options for the localization node. An optional \"method\" key selects\n", | |
"        the localizer from ``localize_peak_methods``.\n", | |
"\n", | |
"    Returns\n", | |
"    -------\n", | |
"    tuple\n", | |
"        ``(peaks, peak_locations)`` as returned by ``run_node_pipeline``.\n", | |
"    \"\"\"\n", | |
"    if recording is None:\n", | |
"        recording = rec_in_mem\n", | |
"    # Copy the option dicts: \"method\" is popped below and must not be\n", | |
"    # removed from a dict the caller still owns.\n", | |
"    job_kwargs = {} if job_kwargs is None else dict(job_kwargs)\n", | |
"    detect_kwargs = {} if detect_kwargs is None else dict(detect_kwargs)\n", | |
"    localize_peaks_kwargs = {} if localize_peaks_kwargs is None else dict(localize_peaks_kwargs)\n", | |
"\n", | |
"    # maybe do this directly in the folder when not None, but might be slow on external storage\n", | |
"    gather_mode = \"memory\"\n", | |
"\n", | |
"    # node 0: peak detection\n", | |
"    method = detect_kwargs.pop(\"method\", \"locally_exclusive\")\n", | |
"    node0 = detect_peak_methods[method](recording, **detect_kwargs)\n", | |
"\n", | |
"    # node 1: dense waveform snippets around each detected peak\n", | |
"    node1 = ExtractDenseWaveforms(recording, parents=[node0], ms_before=0.3, ms_after=1.8)\n", | |
"\n", | |
"    # node 2: localization from the peaks and their waveforms\n", | |
"    method = localize_peaks_kwargs.pop(\"method\", \"center_of_mass\")\n", | |
"    node2 = localize_peak_methods[method](\n", | |
"        recording, parents=[node0, node1], return_output=True, **localize_peaks_kwargs\n", | |
"    )\n", | |
"    pipeline_nodes = [node0, node1, node2]\n", | |
"\n", | |
"    peaks, peak_locations = run_node_pipeline(\n", | |
"        recording,\n", | |
"        pipeline_nodes,\n", | |
"        job_kwargs,\n", | |
"        job_name=\"detect and localize\",\n", | |
"        gather_mode=gather_mode,\n", | |
"        gather_kwargs=None,\n", | |
"        squeeze_output=False,\n", | |
"        folder=None,\n", | |
"        names=None,\n", | |
"    )\n", | |
"    return peaks, peak_locations\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Drop any recording object left over from an earlier cell, then reload it from disk.\n", | |
"globals().pop('rec_in_mem', None)\n", | |
"rec_in_mem = si.load_extractor(recording_name)\n", | |
"\n", | |
"# Timed section: SpikeInterface detection + localization pipeline only.\n", | |
"start = time.perf_counter()\n", | |
"run_node()\n", | |
"times['SI detect+COM'] = time.perf_counter() - start" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Herdingspikes" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Herdingspikes settings: passed whole to HSDetectionLightning below; most\n", | |
"# clustering_* entries are also read by the clustering cell.\n", | |
"parameters = dict(\n", | |
"    # --- detection / preprocessing ---\n", | |
"    chunk_size=32000,\n", | |
"    lowpass=True,\n", | |
"    rescale=True,\n", | |
"    rescale_value=-1280.0,\n", | |
"    common_reference=\"average\",  # alternative: \"median\"\n", | |
"    spike_duration=1.0,\n", | |
"    amp_avg_duration=0.4,\n", | |
"    threshold=8.0,\n", | |
"    min_avg_amp=1.0,\n", | |
"    AHP_thr=0.0,\n", | |
"    neighbor_radius=90.0,\n", | |
"    inner_radius=70.0,\n", | |
"    peak_jitter=0.25,\n", | |
"    rise_duration=0.26,\n", | |
"    decay_filtering=False,\n", | |
"    decay_ratio=1.0,\n", | |
"    localize=True,\n", | |
"    save_shape=True,\n", | |
"    out_file=\"HS2_detected\",\n", | |
"    # --- waveform cut-out ---\n", | |
"    left_cutout_time=0.3,\n", | |
"    right_cutout_time=1.8,\n", | |
"    # --- logging ---\n", | |
"    verbose=False,\n", | |
"    # --- clustering ---\n", | |
"    clustering_bandwidth=4.0,\n", | |
"    clustering_alpha=4.5,\n", | |
"    clustering_n_jobs=-1,\n", | |
"    clustering_bin_seeding=True,\n", | |
"    clustering_min_bin_freq=4,\n", | |
"    clustering_subset=None,\n", | |
"    # --- PCA ---\n", | |
"    pca_ncomponents=2,\n", | |
"    pca_whiten=True,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Start from a freshly loaded recording object for this run.\n", | |
"globals().pop('rec_in_mem', None)\n", | |
"rec_in_mem = si.load_extractor(recording_name)\n", | |
"\n", | |
"# Timed section: constructing the detector plus detection on the raw data.\n", | |
"start = time.perf_counter()\n", | |
"det = hs.HSDetectionLightning(rec_in_mem, parameters)\n", | |
"det.DetectFromRaw()\n", | |
"times['herdingspikes detect'] = time.perf_counter() - start" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Reading spikes from detection\n", | |
"Fitting dimensionality reduction using all spikes...\n", | |
"...projecting...\n", | |
"...done\n", | |
"Clustering...\n", | |
"Clustering 211677 spikes...\n", | |
"requested -1 cpus\n", | |
"using 8 cpus\n", | |
"number of seeds: 1912\n", | |
"seeds/job: 240\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.\n", | |
"[Parallel(n_jobs=8)]: Done 3 out of 8 | elapsed: 3.1s remaining: 5.1s\n", | |
"[Parallel(n_jobs=8)]: Done 8 out of 8 | elapsed: 3.3s finished\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Number of estimated units: 190\n" | |
] | |
} | |
], | |
"source": [ | |
"# Timed section: shape PCA + combined clustering of the detected spikes.\n", | |
"t0 = time.perf_counter()\n", | |
"Clusters = hs.HSClustering(det)\n", | |
"Clusters.ShapePCA()\n", | |
"Clusters.CombinedClustering(\n", | |
"    alpha=parameters[\"clustering_alpha\"],\n", | |
"    bandwidth=parameters[\"clustering_bandwidth\"],\n", | |
"    bin_seeding=parameters[\"clustering_bin_seeding\"],\n", | |
"    min_bin_freq=parameters[\"clustering_min_bin_freq\"],\n", | |
"    cluster_subset=parameters[\"clustering_subset\"],\n", | |
"    # take the worker count from the parameter dict rather than hard-coding -1\n", | |
"    n_jobs=parameters[\"clustering_n_jobs\"],\n", | |
")\n", | |
"t1 = time.perf_counter()\n", | |
"times['herdingspikes sort'] = t1-t0\n", | |
"times['herdingspikes all'] = times['herdingspikes detect']+times['herdingspikes sort']" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Kilosort4" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/tmp/ipykernel_179427/2912753937.py:7: DeprecationWarning: `output_folder` is deprecated and will be removed in version 0.103.0 Please use folder instead\n", | |
" ks_sorting = si.run_sorter(\n", | |
"INFO:kilosort.io:========================================\n", | |
"INFO:kilosort.io:Loading recording with SpikeInterface...\n", | |
"INFO:kilosort.io:number of samples: 18000000\n", | |
"INFO:kilosort.io:number of channels: 384\n", | |
"INFO:kilosort.io:numbef of segments: 1\n", | |
"INFO:kilosort.io:sampling rate: 30000.0\n", | |
"INFO:kilosort.io:dtype: float32\n", | |
"INFO:kilosort.io:========================================\n", | |
"INFO:kilosort.run_kilosort: \n", | |
"INFO:kilosort.run_kilosort:Computing preprocessing variables.\n", | |
"INFO:kilosort.run_kilosort:----------------------------------------\n", | |
"INFO:kilosort.run_kilosort:N samples: 18000000\n", | |
"INFO:kilosort.run_kilosort:N seconds: 600.0\n", | |
"INFO:kilosort.run_kilosort:N batches: 300\n", | |
"INFO:kilosort.run_kilosort:Preprocessing filters computed in 2.11s; total 2.11s\n", | |
"INFO:kilosort.run_kilosort: \n", | |
"INFO:kilosort.run_kilosort:Computing drift correction.\n", | |
"INFO:kilosort.run_kilosort:----------------------------------------\n", | |
"INFO:kilosort.datashift:nblocks = 0, skipping drift correction\n", | |
"INFO:kilosort.run_kilosort:drift computed in 0.00s; total 2.17s\n", | |
"INFO:kilosort.run_kilosort: \n", | |
"INFO:kilosort.run_kilosort:Extracting spikes using templates\n", | |
"INFO:kilosort.run_kilosort:----------------------------------------\n", | |
"INFO:kilosort.spikedetect:Re-computing universal templates from data.\n", | |
"100%|██████████| 300/300 [05:51<00:00, 1.17s/it]\n", | |
"INFO:kilosort.run_kilosort:217572 spikes extracted in 354.78s; total 356.95s\n", | |
"INFO:kilosort.run_kilosort: \n", | |
"INFO:kilosort.run_kilosort:First clustering\n", | |
"INFO:kilosort.run_kilosort:----------------------------------------\n", | |
"100%|██████████| 96/96 [00:32<00:00, 3.00it/s]\n", | |
"INFO:kilosort.run_kilosort:320 clusters found, in 32.55s; total 389.50s\n", | |
"INFO:kilosort.run_kilosort: \n", | |
"INFO:kilosort.run_kilosort:Extracting spikes using cluster waveforms\n", | |
"INFO:kilosort.run_kilosort:----------------------------------------\n", | |
"100%|██████████| 300/300 [00:46<00:00, 6.39it/s]\n", | |
"INFO:kilosort.run_kilosort:235646 spikes extracted in 47.07s; total 436.57s\n", | |
"INFO:kilosort.run_kilosort: \n", | |
"INFO:kilosort.run_kilosort:Final clustering\n", | |
"INFO:kilosort.run_kilosort:----------------------------------------\n", | |
"100%|██████████| 96/96 [00:26<00:00, 3.66it/s]\n", | |
"INFO:kilosort.run_kilosort:187 clusters found, in 26.22s; total 462.79s\n", | |
"INFO:kilosort.run_kilosort: \n", | |
"INFO:kilosort.run_kilosort:Merging clusters\n", | |
"INFO:kilosort.run_kilosort:----------------------------------------\n", | |
"INFO:kilosort.run_kilosort:177 units found, in 0.23s; total 463.03s\n", | |
"INFO:kilosort.run_kilosort: \n", | |
"INFO:kilosort.run_kilosort:Saving to phy and computing refractory periods\n", | |
"INFO:kilosort.run_kilosort:----------------------------------------\n", | |
"INFO:kilosort.run_kilosort:170 units found with good refractory periods\n", | |
"INFO:kilosort.run_kilosort:Total runtime: 464.79s = 00:07:45 h:m:s\n", | |
"INFO:kilosort.run_kilosort:Sorting output saved in: /disk/scratch/mhennig/spikeinterface/HS2/testing/KS4_output/sorter_output.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"kilosort4 run time 465.66s\n" | |
] | |
} | |
], | |
"source": [ | |
"if 'rec_in_mem' in globals():\n", | |
"    del rec_in_mem\n", | |
"\n", | |
"params = si.Kilosort4Sorter.default_params()\n", | |
"# nblocks=0 disables Kilosort4's drift correction (logged as \"skipping drift correction\")\n", | |
"params.update({'nblocks': 0})\n", | |
"rec_in_mem = si.load_extractor(recording_name)\n", | |
"\n", | |
"# Timed section: the complete Kilosort4 run via SpikeInterface (recording load excluded).\n", | |
"t0 = time.perf_counter()\n", | |
"ks_sorting = si.run_sorter(\n", | |
"    sorter_name=\"kilosort4\",\n", | |
"    recording=rec_in_mem,\n", | |
"    # `folder` replaces `output_folder`, which is deprecated and slated for removal in 0.103.0\n", | |
"    folder=\"KS4_output\",\n", | |
"    remove_existing_folder=True,\n", | |
"    singularity_image=False,\n", | |
"    verbose=True,\n", | |
"    delete_container_files=False,\n", | |
"    delete_output_folder=True,\n", | |
"    **params\n", | |
")\n", | |
"t1 = time.perf_counter()\n", | |
"times['Kilosort4 all'] = t1-t0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'SI detect+COM': 26.24065605085343,\n", | |
" 'herdingspikes detect': 15.178689972963184,\n", | |
" 'herdingspikes sort': 5.46234157984145,\n", | |
" 'herdingspikes all': 20.641031552804634,\n", | |
" 'Kilosort4 all': 468.0466902530752}" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Tabulate the run times (seconds), rounded for readability; `times` itself is unchanged.\n", | |
"pd.Series(times, name=\"run time (s)\").round(2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "", | |
"text/plain": [ | |
"<Figure size 640x480 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Bar chart of the wall-clock run times collected in `times`.\n", | |
"fig, ax = plt.subplots()\n", | |
"sns.barplot(data=times, ax=ax)\n", | |
"ax.set(\n", | |
"    title=\"Run time on 600 s of simulated data\",\n", | |
"    ylabel=\"Execution time (sec/600s data)\",\n", | |
")\n", | |
"ax.tick_params(axis=\"x\", labelrotation=30);" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python [conda env:si]", | |
"language": "python", | |
"name": "conda-env-si-py" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.12.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment