Skip to content

Instantly share code, notes, and snippets.

@mhhennig
Last active August 1, 2024 09:06
Show Gist options
  • Save mhhennig/4c391b2e2c8b338d573fc951b1950b5d to your computer and use it in GitHub Desktop.
Save mhhennig/4c391b2e2c8b338d573fc951b1950b5d to your computer and use it in GitHub Desktop.
Comparison of HS detection methods
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Development setup: prepend local source trees (presumably checkouts of\n",
"# herdingspikes and spikeinterface) to sys.path so they take priority over\n",
"# installed packages. Uncomment the autoreload lines to pick up edits live.\n",
"# %load_ext autoreload\n",
"# %autoreload 2\n",
"import os\n",
"import sys\n",
"\n",
"for local_src in (\"../\", \"../../spikeinterface/src\", \"../../spikeinterface/spikeinterface/src\"):\n",
"    sys.path.insert(0, local_src)\n",
"\n",
"# Ask the OpenMP runtime to report its version and settings (this is the stderr\n",
"# dump shown under the Lightning detection cell); set early, before the imports.\n",
"os.environ.update(OMP_DISPLAY_ENV=\"true\", KMP_VERSION=\"true\")\n",
"\n",
"import spikeinterface.full as si\n",
"import herdingspikes as hs\n",
"import numpy as np\n",
"from matplotlib import pyplot as plt\n",
"\n",
"# spikeinterface job defaults: all cores, no progress bar, 1 s chunks\n",
"si.set_global_job_kwargs(n_jobs=-1, progress_bar=False, chunk_duration=\"1s\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Synthetic test data: a 100 s, 64-channel (8 x 8 grid) recording with known\n",
"# ground truth. The recording AND its ground-truth sorting are cached on disk,\n",
"# so `rec` and `gt_sorting` are defined the same way on every Run All.\n",
"rec_name = \"simulated_short_recording\"\n",
"gt_name = f\"{rec_name}_gt\"\n",
"\n",
"if not (os.path.exists(rec_name) and os.path.exists(gt_name)):\n",
"    num_units = 20\n",
"    # Per the spikeinterface docs this returns (static recording, drifting\n",
"    # recording, ground-truth sorting); the static, motion-free one is used.\n",
"    rec, _, gt_sorting = si.generate_drifting_recording(\n",
"        num_units=num_units,\n",
"        duration=100.0,\n",
"        sampling_frequency=20000.0,\n",
"        generate_probe_kwargs=dict(\n",
"            num_columns=8,\n",
"            num_contact_per_column=[8] * 8,\n",
"            xpitch=42,\n",
"            ypitch=42,\n",
"            contact_shapes=\"square\",\n",
"            contact_shape_params={\"width\": 42},\n",
"        ),\n",
"        generate_templates_kwargs=dict(\n",
"            ms_before=1.5,\n",
"            ms_after=3.0,\n",
"            mode=\"ellipsoid\",\n",
"            unit_params=dict(\n",
"                alpha=(150.0, 500.0),\n",
"                spatial_decay=(10, 45),\n",
"            ),\n",
"        ),\n",
"        generate_unit_locations_kwargs=dict(\n",
"            margin_um=10.0,\n",
"            minimum_z=6.0,\n",
"            maximum_z=25.0,\n",
"            minimum_distance=12.0,\n",
"            max_iteration=100,\n",
"            distance_strict=False,\n",
"        ),\n",
"        generate_sorting_kwargs=dict(\n",
"            # firing_rates=(1.0, 4.0),  # lower-rate alternative\n",
"            firing_rates=(1.0, 10.0),\n",
"            refractory_period_ms=4.0,\n",
"        ),\n",
"        generate_noise_kwargs=dict(noise_levels=(4.0, 10.0), spatial_decay=25.0),\n",
"        seed=42,  # fixed seed so regeneration should reproduce the same data\n",
"    )\n",
"    rec.save_to_folder(folder=rec_name, overwrite=True)\n",
"    gt_sorting.save_to_folder(folder=gt_name, overwrite=True)\n",
"\n",
"rec = si.load_extractor(rec_name)\n",
"gt_sorting = si.load_extractor(gt_name)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Parameters for the legacy detector (`hs.HSDetection`)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Settings for the legacy (hs.HSDetection) pipeline, read by the legacy\n",
"# detection cell. The clustering_*, pca_* and filter_duplicates entries are\n",
"# not used anywhere in this notebook.\n",
"params_legacy = dict(\n",
"    # --- core ---\n",
"    clustering_bandwidth=5.5,\n",
"    clustering_alpha=5.5,\n",
"    clustering_n_jobs=-1,\n",
"    clustering_bin_seeding=True,\n",
"    clustering_min_bin_freq=16,\n",
"    clustering_subset=None,\n",
"    left_cutout_time=0.3,\n",
"    right_cutout_time=1.8,\n",
"    detect_threshold=8,\n",
"    # --- probe ---\n",
"    probe_masked_channels=[],\n",
"    probe_inner_radius=70,\n",
"    probe_neighbor_radius=90,\n",
"    probe_event_length=0.26,\n",
"    probe_peak_jitter=0.25,\n",
"    # --- detection ---\n",
"    t_inc=100_000,\n",
"    num_com_centers=1,\n",
"    maa=4.0,\n",
"    ahpthr=0.0,  # NOTE: reported as not working correctly\n",
"    out_file_name=\"HS2_detected\",\n",
"    decay_filtering=False,\n",
"    save_all=False,\n",
"    amp_evaluation_time=0.4,\n",
"    spk_evaluation_time=1.0,\n",
"    # --- PCA ---\n",
"    pca_ncomponents=2,\n",
"    pca_whiten=True,\n",
"    # --- duplicate removal (based on spk_evaluation_time) ---\n",
"    filter_duplicates=True,\n",
")"
]
},
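{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Equivalent parameters for Lightning (`hs.HSDetectionLightning`)\n",
"\n",
"The detection settings below mirror the legacy ones above (e.g. `min_avg_amp` = `maa`/4). The clustering and PCA settings are *not* equivalent (e.g. bandwidth 4.0 vs 5.5), but no clustering is run in this notebook."
]
},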
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Lightning settings chosen to match the legacy detection settings; the\n",
"# comments name the legacy counterpart wherever the values correspond.\n",
"params_lightning = dict(\n",
"    chunk_size=500_000,  # frames per chunk (see the detection log)\n",
"    rescale=True,\n",
"    lowpass=False,\n",
"    rescale_value=-1280.0,\n",
"    common_reference=\"average\",\n",
"    spike_duration=1.0,  # legacy spk_evaluation_time\n",
"    amp_avg_duration=0.4,  # legacy amp_evaluation_time\n",
"    threshold=8.0,  # legacy detect_threshold\n",
"    min_avg_amp=1.0,  # = legacy maa / 4\n",
"    AHP_thr=0.0,  # legacy ahpthr\n",
"    neighbor_radius=90.0,  # legacy probe_neighbor_radius\n",
"    inner_radius=70.0,  # legacy probe_inner_radius\n",
"    peak_jitter=0.25,  # legacy probe_peak_jitter\n",
"    rise_duration=0.26,  # legacy probe_event_length\n",
"    decay_filtering=False,\n",
"    decay_ratio=1.0,\n",
"    localize=True,\n",
"    save_shape=True,\n",
"    out_file=\"HS2_detected\",\n",
"    left_cutout_time=0.3,\n",
"    right_cutout_time=1.8,\n",
"    verbose=True,\n",
"    # clustering / PCA: NOT matched to the legacy values; no clustering is run here\n",
"    clustering_bandwidth=4.0,\n",
"    clustering_alpha=4.5,\n",
"    clustering_n_jobs=-1,\n",
"    clustering_bin_seeding=True,\n",
"    clustering_min_bin_freq=4,\n",
"    clustering_subset=None,\n",
"    pca_ncomponents=2,\n",
"    pca_whiten=True,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"# Generating new position and neighbor files from data file\n",
"# Not Masking any Channels\n",
"# Sampling rate: 20000\n",
"# Localization On\n",
"# Number of recorded channels: 64\n",
"# Analysing frames: 2000000; Seconds: 100.0\n",
"# Frames before spike in cutout: 6\n",
"# Frames after spike in cutout: 36\n",
"# tcuts: 26 56\n",
"# tInc: 100000\n",
"# Detection completed, time taken: 0:00:02.856515\n",
"# Time per frame: 0:00:00.001428\n",
"# Time per sample: 0:00:00.000022\n",
"Loaded 10692 spikes.\n"
]
}
],
"source": [
"# Legacy pipeline: quantile-normalise the recording, wrap it in a herdingspikes\n",
"# probe object, then run the legacy detector and load the detected spikes.\n",
"recf = si.normalize_by_quantile(\n",
"    recording=rec, scale=20.0, median=0.0, q1=0.025, q2=0.975\n",
")\n",
"probe = hs.probe.RecordingExtractor(\n",
"    recf,\n",
"    masked_channels=params_legacy[\"probe_masked_channels\"],\n",
"    inner_radius=params_legacy[\"probe_inner_radius\"],\n",
"    neighbor_radius=params_legacy[\"probe_neighbor_radius\"],\n",
"    event_length=params_legacy[\"probe_event_length\"],\n",
"    peak_jitter=params_legacy[\"probe_peak_jitter\"],\n",
")\n",
"# probe.show(figwidth=2)  # optional: plot the probe\n",
"sorter_output_folder = \"HS2_results_legacy\"\n",
"det = hs.HSDetection(\n",
"    probe,\n",
"    file_directory_name=str(sorter_output_folder),\n",
"    left_cutout_time=params_legacy[\"left_cutout_time\"],\n",
"    right_cutout_time=params_legacy[\"right_cutout_time\"],\n",
"    threshold=params_legacy[\"detect_threshold\"],\n",
"    to_localize=True,\n",
"    num_com_centers=params_legacy[\"num_com_centers\"],\n",
"    maa=params_legacy[\"maa\"],\n",
"    ahpthr=params_legacy[\"ahpthr\"],\n",
"    out_file_name=params_legacy[\"out_file_name\"],\n",
"    decay_filtering=params_legacy[\"decay_filtering\"],\n",
"    save_all=params_legacy[\"save_all\"],\n",
"    amp_evaluation_time=params_legacy[\"amp_evaluation_time\"],\n",
"    spk_evaluation_time=params_legacy[\"spk_evaluation_time\"],\n",
")\n",
"\n",
"# Pass t_inc explicitly (as the spikeinterface herdingspikes wrapper does) so the\n",
"# value in params_legacy actually takes effect; it is reported as `tInc` in the log.\n",
"det.DetectFromRaw(load=True, tInc=int(params_legacy[\"t_inc\"]))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"HSDetection: Analysing segment 0, frames from 0 to 500000 (0.0%)\n",
"HSDetection: Analysing segment 0, frames from 500000 to 1000000 (25.0%)\n",
"HSDetection: Analysing segment 0, frames from 1000000 to 1500000 (50.0%)\n",
"HSDetection: Analysing segment 0, frames from 1500000 to 2000000 (75.0%)\n",
"writing spikes to HS2_detected.hdf5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"LLVM OMP version: 5.0.20140926\n",
"LLVM OMP library type: performance\n",
"LLVM OMP link type: dynamic\n",
"LLVM OMP build time: no_timestamp\n",
"LLVM OMP build compiler: Clang 12.0\n",
"LLVM OMP alternative compiler support: yes\n",
"LLVM OMP API version: 5.0 (201611)\n",
"LLVM OMP dynamic error checking: no\n",
"LLVM OMP thread affinity support: no\n",
"\n",
"OPENMP DISPLAY ENVIRONMENT BEGIN\n",
" _OPENMP='201611'\n",
" [host] OMP_AFFINITY_FORMAT='OMP: pid %P tid %i thread %n bound to OS proc set {%A}'\n",
" [host] OMP_ALLOCATOR='omp_default_mem_alloc'\n",
" [host] OMP_CANCELLATION='FALSE'\n",
" [host] OMP_DEFAULT_DEVICE='0'\n",
" [host] OMP_DISPLAY_AFFINITY='FALSE'\n",
" [host] OMP_DISPLAY_ENV='TRUE'\n",
" [host] OMP_DYNAMIC='FALSE'\n",
" [host] OMP_MAX_ACTIVE_LEVELS='1'\n",
" [host] OMP_MAX_TASK_PRIORITY='0'\n",
" [host] OMP_NESTED: deprecated; max-active-levels-var=1\n",
" [host] OMP_NUM_TEAMS='0'\n",
" [host] OMP_NUM_THREADS: value is not defined\n",
" [host] OMP_PROC_BIND='false'\n",
" [host] OMP_SCHEDULE='static'\n",
" [host] OMP_STACKSIZE='8176k'\n",
" [host] OMP_TARGET_OFFLOAD=DEFAULT\n",
" [host] OMP_TEAMS_THREAD_LIMIT='0'\n",
" [host] OMP_THREAD_LIMIT='2147483647'\n",
" [host] OMP_TOOL='enabled'\n",
" [host] OMP_TOOL_LIBRARIES: value is not defined\n",
" [host] OMP_TOOL_VERBOSE_INIT: value is not defined\n",
" [host] OMP_WAIT_POLICY='PASSIVE'\n",
"OPENMP DISPLAY ENVIRONMENT END\n",
"\n",
"\n"
]
}
],
"source": [
"# Lightning detector, run on `rec` directly (the legacy cell uses the\n",
"# quantile-normalised `recf`); its settings come from params_lightning.\n",
"# The stderr block in the output is the OpenMP runtime report enabled by\n",
"# OMP_DISPLAY_ENV / KMP_VERSION in the first cell.\n",
"det_lightning = hs.HSDetectionLightning(rec, params_lightning)\n",
"# Detected spikes are also stored on det_lightning.spikes, which the next cell reads.\n",
"spikes = det_lightning.DetectFromRaw()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of spikes found by legacy:10692\n",
"Number of spikes found by lightning:10092\n"
]
}
],
"source": [
"# Compare how many spikes each detector found.\n",
"for label, detector in ((\"legacy\", det), (\"lightning\", det_lightning)):\n",
"    n_spikes = detector.spikes.t.values.shape[0]\n",
"    print(f\"Number of spikes found by {label}:{n_spikes}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1000x400 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Side-by-side view of each detector's output, one panel per method.\n",
"fig, axes = plt.subplots(1, 2, figsize=(10, 4))\n",
"for ax, detector, label in zip(axes, (det, det_lightning), (\"legacy\", \"lightning\")):\n",
"    detector.PlotAll(s=1, ax=ax)\n",
"    ax.axis(\"equal\")  # same scale on both axes so the panels are comparable\n",
"    ax.set_title(label)\n",
"# End with plt.show() so no stray Text(...) repr is emitted as cell output.\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "hs2",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment