Created May 4, 2023 10:18 by h-mayorquin (gist b6ff883d66e9148d31467f7865eb83ff)
Recall for peak detection
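In formula form, the recall computed by `calculate_peaks_spike_recall` is the fraction of ground-truth spikes that have at least one detected peak strictly within the tolerance. Here $s_i$ are the $N$ ground-truth spike sample indices, $p_j$ the detected peak sample indices, $\tau$ the tolerance in milliseconds and $f_s$ the sampling frequency in Hz (these symbols are introduced only for this summary):

```latex
\text{recall} = \frac{\mathrm{TP}}{\mathrm{TP} + \mathrm{FN}}
             = \frac{1}{N} \sum_{i=1}^{N}
               \mathbf{1}\!\left[\, \min_{j} \lvert s_i - p_j \rvert < \frac{\tau f_s}{1000} \right]
```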
import numpy as np


def calculate_peaks_spike_recall(peaks, sorting_ground_truth, tolerance_ms=0.4):
    """
    Calculate the spike recall of the peaks (the output of a peak detection method)
    against the ground-truth spike trains. This function is used to test the quality
    of a peak detection method in the context of a specific sorting.

    A recall close to 1 means that every spike in the spike train has a detected peak
    nearby, whereas a recall close to 0 means that almost no spike in the spike train
    has a detected peak nearby.

    More technically, this is the number of true positives (spikes with at least one
    peak within the tolerance) divided by the total number of positive examples
    (i.e. the number of spikes).

    The tolerance is given in milliseconds: a spike counts as detected when some peak
    time is strictly less than the tolerance away from the ground-truth spike time.
    """
    sample_indices = peaks["sample_index"]
    # get_all_spike_trains() returns one (spike_times, spike_labels) pair per segment;
    # [0][0] takes the spike times (in samples, all units pooled) of the first segment.
    spike_trains = sorting_ground_truth.get_all_spike_trains()[0][0]
    sampling_frequency = sorting_ground_truth.get_sampling_frequency()
    # Convert the tolerance from milliseconds to a number of samples.
    tolerance_number_samples = tolerance_ms * sampling_frequency / 1_000.0

    # Broadcasting builds a (num_spikes, num_peaks) distance matrix; a spike is
    # detected if any peak lies within the tolerance.
    distances = np.abs(sample_indices - spike_trains[:, np.newaxis])
    are_spikes_close_any_peaks = np.any(distances < tolerance_number_samples, axis=1)
    return are_spikes_close_any_peaks.mean()
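# Test with this (peaks_by_channel_np and sorting_ground_truth are assumed to come
# from the surrounding test setup)
import pytest

peaks = peaks_by_channel_np
method_recall = calculate_peaks_spike_recall(peaks, sorting_ground_truth)
# pytest.approx only builds a comparison object, so it must be used inside an assert.
assert method_recall == pytest.approx(1.0)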
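The function above materializes a (num_spikes, num_peaks) distance matrix, which can become large for long recordings. The sketch below is a hedged, self-contained illustration on plain NumPy arrays (no spikeinterface objects). Both function names are hypothetical. The first reproduces the same broadcasting logic; the second swaps in a different technique, a `np.searchsorted` lookup of the nearest peak on each side of every spike, which gives the same recall while using memory proportional to the number of spikes only.

```python
import numpy as np


def recall_from_arrays(peak_samples, spike_samples, sampling_frequency, tolerance_ms=0.4):
    """Hypothetical helper: same broadcasting logic as calculate_peaks_spike_recall."""
    peak_samples = np.asarray(peak_samples, dtype=np.int64)
    spike_samples = np.asarray(spike_samples, dtype=np.int64)
    tolerance_samples = tolerance_ms * sampling_frequency / 1_000.0
    # (num_spikes, num_peaks) matrix of absolute sample distances.
    distances = np.abs(peak_samples - spike_samples[:, np.newaxis])
    return np.any(distances < tolerance_samples, axis=1).mean()


def recall_from_arrays_sorted(peak_samples, spike_samples, sampling_frequency, tolerance_ms=0.4):
    """Hypothetical helper: nearest-peak lookup with searchsorted instead of a full matrix."""
    peak_samples = np.sort(np.asarray(peak_samples, dtype=np.int64))
    spike_samples = np.asarray(spike_samples, dtype=np.int64)
    tolerance_samples = tolerance_ms * sampling_frequency / 1_000.0
    num_peaks = peak_samples.size

    # Index of the first peak >= each spike; the peak just before it is the left neighbour.
    right_idx = np.searchsorted(peak_samples, spike_samples, side="left")
    left_idx = right_idx - 1

    # Distance to the right neighbour (inf when there is no peak at or after the spike).
    dist_right = np.full(spike_samples.shape, np.inf)
    has_right = right_idx < num_peaks
    dist_right[has_right] = peak_samples[right_idx[has_right]] - spike_samples[has_right]

    # Distance to the left neighbour (inf when there is no peak before the spike).
    dist_left = np.full(spike_samples.shape, np.inf)
    has_left = left_idx >= 0
    dist_left[has_left] = spike_samples[has_left] - peak_samples[left_idx[has_left]]

    # The nearest peak is always one of the two neighbours in sorted order.
    nearest = np.minimum(dist_left, dist_right)
    return np.mean(nearest < tolerance_samples)


# At 30 kHz, 0.4 ms is 12 samples: spikes at 100 and 500 are matched (5 and 10 samples
# away), the spike at 1000 is not (13 samples), and the spike at 2000 is not.
spikes = [100, 500, 1000, 2000]
peaks = [105, 510, 1013, 5000]
print(recall_from_arrays(peaks, spikes, sampling_frequency=30_000.0))  # 0.5
print(recall_from_arrays_sorted(peaks, spikes, sampling_frequency=30_000.0))  # 0.5
```

Both helpers use the same strict `<` comparison as the original, so they agree exactly; the sorted variant trades the quadratic matrix for an O(N log M) search.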