Created
August 5, 2024 17:22
-
-
Save CBroz1/ee34516885ac8d3ea6cc00043479030c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""V1""" | |
from functools import reduce | |
from typing import List, Union | |
import numpy as np | |
import spikeinterface as si | |
from spikeinterface.core.job_tools import ChunkRecordingExecutor, ensure_n_jobs | |
from spyglass.common.common_interval import ( | |
_union_concat, | |
interval_from_inds, | |
interval_list_complement, | |
) | |
from spyglass.spikesorting.utils import ( | |
_check_artifact_thresholds, | |
_compute_artifact_chunk, | |
_init_artifact_worker, | |
) | |
from spyglass.utils import logger | |
def _get_artifact_times(
    recording: si.BaseRecording,
    sort_interval_valid_times: List[List],
    zscore_thresh: Union[float, None] = None,
    amplitude_thresh_uV: Union[float, None] = None,
    proportion_above_thresh: float = 1.0,
    removal_window_ms: float = 1.0,
    verbose: bool = False,
    **job_kwargs,
):
    """Detect artifact intervals and the valid times left once they are removed.

    Frames are flagged in chunks by ``_compute_artifact_chunk`` (run through
    ``ChunkRecordingExecutor``). Flagged frame indices are grouped into
    intervals with ``interval_from_inds``. Each interval is widened by half of
    ``removal_window_ms`` on both sides, with the edges snapped onto recording
    timestamps. Overlapping intervals are merged with ``_union_concat``, and
    the result is subtracted from ``sort_interval_valid_times``.

    Parameters
    ----------
    recording : si.BaseRecording
        Recording to scan. ``recording.get_times()`` provides the timestamps
        used to convert frame indices to seconds. Assumed sorted, as required
        by ``np.searchsorted``.
    sort_interval_valid_times : List[List]
        ``[start, stop]`` intervals from which the artifact intervals are
        removed.
    zscore_thresh : float or None
        Z-score threshold, validated by ``_check_artifact_thresholds`` and
        passed to the artifact worker.
    amplitude_thresh_uV : float or None
        Amplitude threshold, validated and passed to the worker like
        ``zscore_thresh``. If both thresholds are None, detection is skipped.
    proportion_above_thresh : float
        Validated and passed to the worker. Presumably the fraction of
        channels that must exceed a threshold for a frame to count as an
        artifact (TODO confirm in the worker).
    removal_window_ms : float
        Total width, in ms, of the window removed around each artifact. Half
        is removed on each side.
    verbose : bool
        Forwarded to ``ChunkRecordingExecutor``.
    **job_kwargs
        Forwarded to ``ChunkRecordingExecutor``. ``n_jobs`` (default 1) also
        decides whether workers get the recording object or
        ``recording.to_dict()``.

    Returns
    -------
    artifact_removed_valid_times : np.ndarray
        Valid times with artifacts removed:

        * Detection skipped: the 1-D array ``[first_ts, last_ts]``.
        * No artifacts detected: the 2-D array ``[[first_ts, last_ts]]``.
        * Nothing left after removal: an empty ``(0, 2)`` array.
        * Otherwise: the complement of the artifact intervals within
          ``sort_interval_valid_times``, merged by ``_union_concat``.

        NOTE(review): the first two cases span the whole recording rather
        than ``sort_interval_valid_times``. They also differ in rank (1-D vs
        2-D). Both are kept as-is for caller compatibility; confirm whether
        that is intended.
    artifact_intervals_s : np.ndarray
        Padded, merged artifact intervals in seconds. Empty when detection is
        skipped or no artifacts are found.
    """
    valid_timestamps = recording.get_times()

    # If both thresholds are None, we skip artifact detection.
    if amplitude_thresh_uV is None and zscore_thresh is None:
        logger.info(
            "Amplitude and zscore thresholds are both None, "
            "skipping artifact detection"
        )
        return np.asarray(
            [valid_timestamps[0], valid_timestamps[-1]]
        ), np.asarray([])

    # Verify threshold parameters.
    (
        amplitude_thresh_uV,
        zscore_thresh,
        proportion_above_thresh,
    ) = _check_artifact_thresholds(
        amplitude_thresh=amplitude_thresh_uV,
        zscore_thresh=zscore_thresh,
        proportion_above_thresh=proportion_above_thresh,
    )

    # Detect frames that are above threshold, possibly in parallel.
    n_jobs = ensure_n_jobs(recording, n_jobs=job_kwargs.get("n_jobs", 1))
    logger.info(f"Using {n_jobs} jobs...")

    # A single worker gets the recording object itself. Multiple workers get
    # recording.to_dict(), presumably so each worker process can rebuild the
    # recording.
    recording_arg = recording if n_jobs == 1 else recording.to_dict()
    init_args = (
        recording_arg,
        zscore_thresh,
        amplitude_thresh_uV,
        proportion_above_thresh,
    )

    executor = ChunkRecordingExecutor(
        recording=recording,
        func=_compute_artifact_chunk,
        init_func=_init_artifact_worker,
        init_args=init_args,
        verbose=verbose,
        handle_returns=True,
        job_name="detect_artifact_frames",
        **job_kwargs,
    )
    artifact_frames = np.concatenate(executor.run())

    # Convert the total removal window (ms) into the seconds removed on each
    # side of every detected artifact.
    half_removal_window_s = removal_window_ms / 2 / 1000

    if len(artifact_frames) == 0:
        recording_interval = np.asarray(
            [[valid_timestamps[0], valid_timestamps[-1]]]
        )
        artifact_times_empty = np.asarray([])
        logger.warning("No artifacts detected.")
        return recording_interval, artifact_times_empty

    # Group flagged frame indices into [first, last] index pairs.
    artifact_intervals = np.asarray(
        interval_from_inds(artifact_frames)
    ).reshape(-1, 2)

    # Pad each interval by half the removal window and snap the padded edges
    # onto recording timestamps. searchsorted returns len(valid_timestamps)
    # when a padded end runs past the last sample, so clip to the last index
    # instead of indexing out of bounds. The start edge never needs clipping:
    # its target is <= an existing timestamp.
    last_ind = len(valid_timestamps) - 1
    start_inds = np.searchsorted(
        valid_timestamps,
        valid_timestamps[artifact_intervals[:, 0]] - half_removal_window_s,
    )
    stop_inds = np.minimum(
        np.searchsorted(
            valid_timestamps,
            valid_timestamps[artifact_intervals[:, 1]] + half_removal_window_s,
        ),
        last_ind,
    )
    artifact_intervals_s = np.zeros(
        (len(artifact_intervals), 2), dtype=np.float64
    )
    artifact_intervals_s[:, 0] = valid_timestamps[start_inds]
    artifact_intervals_s[:, 1] = valid_timestamps[stop_inds]

    # Make the artifact intervals disjoint. _union_concat folds each interval
    # into the running list, merging it with the last one when they overlap.
    if len(artifact_intervals_s) > 1:
        artifact_intervals_s = reduce(_union_concat, artifact_intervals_s)

    # Find non-artifact intervals within the sort interval.
    # min_length=1 drops short leftovers, presumably in the intervals' own
    # units (seconds).
    artifact_removed_valid_times = interval_list_complement(
        sort_interval_valid_times, artifact_intervals_s, min_length=1
    )
    if len(artifact_removed_valid_times) == 0:
        # reduce() with no initial value raises TypeError on an empty
        # sequence, so report "no valid times" explicitly instead.
        logger.warning("No valid times remain after artifact removal.")
        return np.zeros((0, 2), dtype=np.float64), artifact_intervals_s

    artifact_removed_valid_times = reduce(
        _union_concat, artifact_removed_valid_times
    )
    return artifact_removed_valid_times, artifact_intervals_s
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment