import copy
import itertools
import operator
import os
import pickle
from collections import namedtuple

# gspread and ServiceAccountCredentials are required by get_google_spreadsheet below
import gspread
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import spikeinterface as si
from oauth2client.service_account import ServiceAccountCredentials
from spyglass.spikesorting import (
    SortingviewWorkspace,
    SpikeSorting,
    SpikeSortingRecording,
)
from spyglass.spikesorting.spikesorting_curation import (
    AutomaticCuration,
    AutomaticCurationParameters,
    AutomaticCurationSelection,
    CuratedSpikeSorting,
    CuratedSpikeSortingSelection,
    Curation,
    MetricParameters,
    MetricSelection,
    QualityMetrics,
    WaveformParameters,
    Waveforms,
    WaveformSelection,
)
def overlap(x, y):
    z = list(x) + list(y)
    if not all(np.isfinite(z)):
        raise ValueError("All elements must be finite")
    if not all(np.asarray(z) >= 0):
        raise ValueError("All elements must be nonnegative")
    return (
        2 * np.sum(np.min(np.vstack((x, y)), axis=0)) / (np.sum(x) + np.sum(y))
    )
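# Illustrative example (not part of the original gist): the overlap of a
# nonnegative vector with itself is 1, and it shrinks as the two vectors
# diverge. The values below are made up for demonstration.
def _example_overlap():
    a = np.array([0.2, 0.5, 0.3])
    b = np.array([0.1, 0.6, 0.3])
    return overlap(a, a), overlap(a, b)  # (1.0, 0.9)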
def add_colorbar( | |
img, fig, ax, cbar_location="left", size="5%", pad_factor=0.05 | |
): | |
from mpl_toolkits.axes_grid1 import make_axes_locatable | |
divider = make_axes_locatable(ax) | |
pad_factor *= fig.get_size_inches()[0] # scale pad based on figure width | |
cax = divider.append_axes(cbar_location, size=size, pad=pad_factor) | |
cbar = fig.colorbar(img, ax=[ax], cax=cax) | |
return cbar | |
def make_param_name(param_values): | |
return "_1".join([str(x) for x in param_values]) | |
def format_nwb_file_name(nwb_file_name): | |
return nwb_file_name.split("_")[0] | |
def get_google_spreadsheet( | |
service_account_dir, | |
service_account_json, | |
spreadsheet_key, | |
spreadsheet_tab_name, | |
): | |
scope = ["https://spreadsheets.google.com/feeds"] | |
# Change to directory with service account credentials json | |
os.chdir(service_account_dir) | |
# Get service account credentials from json file | |
service_account_credentials = ( | |
ServiceAccountCredentials.from_json_keyfile_name( | |
service_account_json, scope | |
) | |
) | |
# Get spreadsheet | |
client_obj = gspread.authorize(service_account_credentials) | |
spreadsheet_obj = client_obj.open_by_key(spreadsheet_key) | |
worksheet = spreadsheet_obj.worksheet(spreadsheet_tab_name) | |
return worksheet.get_all_values() | |
def check_one_none(x, list_element_names=""): | |
num_none = len([x_i for x_i in x if x_i is None]) | |
if num_none != 1: | |
raise Exception( | |
f"Need exactly one None in passed arguments {list_element_names}" | |
f" but got {num_none} Nones" | |
) | |
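# Illustrative example (not part of the original gist): check_one_none passes
# when exactly one of a pair of mutually exclusive arguments is None (as with
# sort_group_ids / target_region below) and raises otherwise. "CA1" is a
# hypothetical placeholder value.
def _example_check_one_none():
    check_one_none([None, "CA1"])  # OK: exactly one None
    try:
        check_one_none([None, None], "sort_group_ids, target_region")
    except Exception as e:
        print(e)  # "Need exactly one None ..."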
def get_curation_spreadsheet(subject_id, date, tolerate_no_notes=True): | |
# Get data from google spreadsheet | |
service_account_dir = "" | |
service_account_json = "" | |
spreadsheet_key = "" | |
spreadsheet_tab_name = f"{subject_id}{date}" | |
column_names = np.asarray( | |
[ | |
"sort_group_id", | |
"unit_id_1", | |
"unit_id_2", | |
"merge_type", | |
"notes", | |
"label", | |
"potential action items", | |
] | |
) | |
try: | |
table = np.asarray( | |
get_google_spreadsheet( | |
service_account_dir, | |
service_account_json, | |
spreadsheet_key, | |
spreadsheet_tab_name, | |
) | |
) | |
    except Exception:
failure_message = f"Could not get google spreadsheet with curation notes for {subject_id}{date}" | |
if tolerate_no_notes: | |
print(failure_message) | |
return pd.DataFrame(columns=column_names) | |
else: | |
raise Exception(failure_message) | |
# Get labels as df | |
# ...Get index of row where labels start: the one after column names | |
row_idx = ( | |
unpack_single_element( | |
np.where(np.prod(table == column_names, axis=1))[0] | |
) | |
+ 1 | |
) | |
# ...Convert to dataframe. Here, convert datatype from string as appropriate | |
def _convert_curation_spreadsheet_dtype(row, column_name): | |
int_column_names = ["sort_group_id", "unit_id_1", "unit_id_2"] | |
bool_column_names = ["label"] | |
bool_map = {"yes": True, "no": False, "unsure": None} | |
if column_name in int_column_names: | |
return row.astype(int) | |
elif column_name in bool_column_names: | |
return [ | |
bool_map[x.strip()] for x in row | |
] # strip whitespace and convert to bool variable | |
return row | |
return pd.DataFrame.from_dict( | |
{ | |
column_name: _convert_curation_spreadsheet_dtype(row, column_name) | |
for column_name, row in zip(column_names, table[row_idx:].T) | |
} | |
) | |
def get_cluster_data_file_name( | |
nwb_file_name, | |
sort_interval_name, | |
sorter, | |
preproc_params_name, | |
curation_id, | |
sort_group_id=None, | |
target_region=None, | |
): | |
# Check inputs | |
check_one_none([sort_group_id, target_region]) | |
electrode_text = target_region | |
if sort_group_id is not None: | |
electrode_text = sort_group_id | |
return make_param_name( | |
[ | |
format_nwb_file_name(nwb_file_name), | |
sort_interval_name, | |
sorter, | |
preproc_params_name, | |
curation_id, | |
electrode_text, | |
] | |
) | |
def load_curation_data( | |
save_dir, | |
nwb_file_name, | |
sort_interval_name, | |
sorter="mountainsort4", | |
preproc_params_name="franklab_tetrode_hippocampus", | |
sort_group_ids=None, | |
target_region=None, | |
curation_id=1, | |
verbose=True, | |
overwrite_quantities=True, | |
): | |
# Check inputs | |
check_one_none([sort_group_ids, target_region]) | |
if verbose: | |
print(f"Loading curation data for {nwb_file_name}...") | |
cd_make_if_nonexistent(save_dir) | |
if target_region is not None: | |
file_name_save = get_cluster_data_file_name( | |
nwb_file_name, | |
sort_interval_name, | |
sorter, | |
preproc_params_name, | |
curation_id, | |
target_region=target_region, | |
) | |
return pickle.load(open(file_name_save, "rb")) | |
sort_groups_data = dict() | |
for sort_group_id in sort_group_ids: | |
file_name_save = get_cluster_data_file_name( | |
nwb_file_name, | |
sort_interval_name, | |
sorter, | |
preproc_params_name, | |
curation_id, | |
sort_group_id=sort_group_id, | |
) | |
        # Continue if the file doesn't exist
if not os.path.exists(file_name_save): | |
continue | |
# Store sort group data | |
sort_groups_data[sort_group_id] = pickle.load( | |
open(file_name_save, "rb") | |
) | |
# TODO: overwrite files with version with this calculation, and then delete this from here | |
default_params = get_correlogram_default_params() | |
correlogram_max_dt, correlogram_min_dt = ( | |
default_params["max_dt"], | |
default_params["min_dt"], | |
) | |
for data in sort_groups_data.values(): | |
data["correlogram_isi_violation_ratios"] = ( | |
get_correlogram_isi_violation_ratios( | |
data, max_dt=correlogram_max_dt, min_dt=correlogram_min_dt | |
) | |
) | |
# !!! TEMPORARY UNTIL MOVE THIS TO make_curation_data | |
if overwrite_quantities: | |
# Convert sort group ID from string to int | |
sort_groups_data = { | |
int(sort_group_id): sort_group_data | |
for sort_group_id, sort_group_data in sort_groups_data.items() | |
} | |
# Add unit names to sort group data | |
for sort_group_id, sort_group_data in sort_groups_data.items(): | |
if "unit_ids" not in sort_group_data: | |
sort_group_data["unit_ids"] = list( | |
sort_group_data["spike_times"].keys() | |
) | |
# Get correlogram quantities | |
print("Loading correlogram quantities...") | |
for sort_group_id, sort_group_data in sort_groups_data.items(): | |
sort_group_data["correlogram_asymmetries"] = ( | |
get_correlogram_asymmetries(sort_group_data["correlograms"]) | |
) | |
sort_group_data["correlogram_asymmetry_directions"] = ( | |
get_correlogram_asymmetry_directions( | |
sort_group_data["correlograms"] | |
) | |
) | |
sort_group_data["correlogram_counts"] = get_correlogram_counts( | |
sort_group_data["correlograms"] | |
) | |
# Overwrite amplitude overlaps | |
print("Loading amplitude overlaps...") | |
for sort_group_id, sort_group_data in sort_groups_data.items(): | |
sort_group_data["amplitude_overlaps"] = get_amplitude_overlaps( | |
sort_group_data | |
) | |
# Get amplitude size comparison | |
for sort_group_id, sort_group_data in sort_groups_data.items(): | |
sort_group_data["amplitude_size_comparisons"] = ( | |
get_amplitude_size_comparisons(sort_group_data) | |
) | |
# Get burst pair amplitude correlogram asymmetry metric | |
for sort_group_id, sort_group_data in sort_groups_data.items(): | |
sort_group_data["burst_pair_amplitude_timing_bools"] = ( | |
get_burst_pair_amplitude_timing_bools(sort_group_data) | |
) | |
# ISI violation percent | |
print("Loading ISI violation percent...") | |
for sort_group_id, sort_group_data in sort_groups_data.items(): | |
sort_group_data["unit_pair_percent_isi_violations"] = ( | |
get_unit_pair_percent_isi_violations(sort_group_data) | |
) | |
# Valid lower amplitude fractions | |
print("Loading valid lower amplitude fractions...") | |
for sort_group_id, sort_group_data in sort_groups_data.items(): | |
print(f"on sort group {sort_group_id}...") | |
sort_group_data["valid_lower_amplitude_fractions"] = ( | |
get_valid_lower_amplitude_fractions(sort_group_data) | |
) | |
sort_group_data["unit_merge_valid_lower_amplitude_fractions"] = ( | |
get_unit_merge_valid_lower_amplitude_fractions(sort_group_data) | |
) | |
# !!!!!!!! | |
return { | |
"nwb_file_name": nwb_file_name, | |
"sort_interval_name": sort_interval_name, | |
"sorter": sorter, | |
"preproc_params_name": preproc_params_name, | |
"n_sort_groups": len(sort_group_ids), | |
"sort_groups": sort_groups_data, | |
} | |
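# Illustrative usage sketch (not part of the original gist) showing how
# load_curation_data might be called. The save directory, nwb file name, and
# sort group IDs below are hypothetical placeholders; exactly one of
# sort_group_ids / target_region must be None.
def _example_load_curation_data():
    return load_curation_data(
        save_dir="/path/to/curation_data",  # hypothetical directory
        nwb_file_name="subject20220101_.nwb",  # hypothetical nwb file
        sort_interval_name="raw data valid times no premaze no home",
        sort_group_ids=[0, 1],  # hypothetical sort groups
        target_region=None,
    )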
def make_curation_data( | |
save_dir, | |
nwb_file_name, | |
sort_interval_name, | |
sorter="mountainsort4", | |
preproc_params_name="franklab_tetrode_hippocampus", | |
sort_group_ids=None, | |
curation_id=1, | |
get_workspace_url=True, | |
ignore_invalid_sort_group_ids=False, | |
overwrite_existing=False, | |
verbose=True, | |
): | |
# Check that key specific enough (each sort group represented no more than once) | |
key = { | |
"nwb_file_name": nwb_file_name, | |
"sort_interval_name": sort_interval_name, | |
"sorter": sorter, | |
"preproc_params_name": preproc_params_name, | |
"curation_id": curation_id, | |
} | |
valid_sort_group_ids = [ | |
x for x in (SpikeSorting & key).fetch("sort_group_id") | |
] | |
check_all_unique(valid_sort_group_ids) | |
# Define sort group ids if not passed | |
if sort_group_ids is None: | |
sort_group_ids = valid_sort_group_ids | |
# Check that passed sort group ids are valid | |
if not ignore_invalid_sort_group_ids and not set(sort_group_ids).issubset( | |
set(valid_sort_group_ids) | |
): | |
raise ValueError(f"List of sort groups includes invalid sort group IDs") | |
# Get correlogram default params | |
default_params = get_correlogram_default_params() | |
correlogram_max_dt, correlogram_min_dt = ( | |
default_params["max_dt"], | |
default_params["min_dt"], | |
) | |
    # Loop through sort groups and make cluster data if it does not exist or if overwriting is requested
for sort_group_id in sort_group_ids: | |
# Continue if sort group invalid and want to tolerate this | |
if ( | |
ignore_invalid_sort_group_ids | |
and sort_group_id not in valid_sort_group_ids | |
): | |
continue | |
        # Continue if the file already exists and we don't want to overwrite it
file_name_save = get_cluster_data_file_name( | |
nwb_file_name, | |
sort_interval_name, | |
sorter, | |
preproc_params_name, | |
curation_id, | |
sort_group_id, | |
) | |
if ( | |
os.path.exists(os.path.join(save_dir, file_name_save)) | |
and not overwrite_existing | |
): | |
print( | |
f"Cluster data exists for {nwb_file_name}, sort group {sort_group_id}; continuing" | |
) | |
continue | |
# Otherwise, make cluster data | |
if verbose: | |
print( | |
f"Making cluster data for {nwb_file_name}, sort group {sort_group_id}" | |
) | |
        # Get key from CuratedSpikeSorting since we will need all fields (some of which were
        # not defined by the user, e.g. team_name) to populate other tables
sort_group_key = ( | |
CuratedSpikeSorting & {**key, **{"sort_group_id": sort_group_id}} | |
).fetch1("KEY") | |
data = dict() # for cluster data | |
# Make key for getting whitened waveforms. Note that here we use a curation id of ZERO, regardless | |
# of what curation_id was passed | |
waveforms_key = copy.deepcopy(sort_group_key) | |
# waveforms_key.update({'curation_id': 0, | |
# 'waveform_params_name': 'RSN_whitened_float'}) | |
waveforms_key.update( | |
{"curation_id": 0, "waveform_params_name": "5k_whitened_float_1"} | |
) | |
# Populate waveforms tables if no entry | |
if not (Waveforms & waveforms_key): | |
if verbose: | |
print(f"Populating Waveforms table with key {waveforms_key}...") | |
WaveformSelection.insert1(waveforms_key, skip_duplicates=True) | |
Waveforms.populate([(WaveformSelection & waveforms_key).proj()]) | |
# Get workspace URL if indicated | |
if get_workspace_url: | |
data["workspace_url"] = SortingviewWorkspace().url(sort_group_key) | |
# Get timestamps | |
if verbose: | |
print(f"Getting timestamps...") | |
recording_path = (SpikeSortingRecording & sort_group_key).fetch1( | |
"recording_path" | |
) | |
recording = si.load_extractor(recording_path) | |
timestamps_raw = SpikeSortingRecording._get_recording_timestamps( | |
recording | |
) | |
# Get total recording duration in seconds | |
data["recording_duration"] = recording.get_total_duration() | |
# Get spikes data | |
if verbose: | |
print(f"Getting spikes data...") | |
# ...First get valid unit IDs, for the passed curation_id and metric restrictions. | |
        # If there are no unit IDs for the given sort group, continue
css_entry = unpack_single_element( | |
( | |
CuratedSpikeSorting | |
& {**sort_group_key, **{"curation_id": curation_id}} | |
).fetch_nwb() | |
) | |
if "units" not in css_entry: | |
continue | |
units_df = css_entry["units"] | |
valid_unit_ids = units_df.index | |
# ...Get unit metrics | |
metric_names = [ | |
"snr", | |
"isi_violation", | |
"nn_isolation", | |
"nn_noise_overlap", | |
] # desired metrics | |
for metric_name in metric_names: | |
# Continue if metric name not in curated spike sorting entry | |
if metric_name not in units_df: | |
continue | |
data[metric_name] = units_df[metric_name].to_dict() | |
# ...Get waveform extractor, which will be used to get other quantities | |
we = (Waveforms & waveforms_key).load_waveforms( | |
waveforms_key | |
) # waveform extractor | |
data["sampling_frequency"] = we.sorting.get_sampling_frequency() | |
data["unit_ids"] = valid_unit_ids | |
data["n_clusters"] = len(valid_unit_ids) | |
data["n_channels"] = len(we.recording.get_channel_ids()) | |
data["waveform_window"] = np.arange(-we.nbefore, we.nafter) | |
# IMPORTANT NOTE: WAVEFORMS AND SPIKE TIMES ARE SUBSAMPLED (SEEMS MAX IS AT 20000). Happens in line below. | |
waveform_data = { | |
unit_id: we.get_waveforms(unit_id, with_index=True) | |
for unit_id in valid_unit_ids | |
} | |
spike_samples = { | |
unit_id: we.sorting.get_unit_spike_train(unit_id=unit_id) | |
for unit_id in valid_unit_ids | |
} | |
# TODO: understand the line below | |
data["waveforms"] = { | |
unit_id: np.swapaxes(wv[0], 0, 2) | |
for unit_id, wv in waveform_data.items() | |
} | |
data["waveform_indices"] = { | |
unit_id: np.array(list(zip(*wv[1]))[0]).astype(int) | |
for unit_id, wv in waveform_data.items() | |
} | |
# ...Get spike times | |
data["spike_times"] = { | |
unit_id: timestamps_raw[samples[data["waveform_indices"][unit_id]]] | |
for unit_id, samples in spike_samples.items() | |
} | |
# ...Get average waveforms | |
data["average_waveforms"] = get_average_waveforms(data["waveforms"]) | |
# ...Get peak channels | |
data["peak_channels"] = get_peak_channels(data["average_waveforms"]) | |
# ...Get waveform amplitudes | |
data["amplitudes"] = get_waveform_amplitudes(data["waveforms"]) | |
# ...Get amplitude size comparison | |
data["amplitude_size_comparisons"] = get_amplitude_size_comparisons( | |
data | |
) | |
# Get cosine similarity | |
if verbose: | |
print(f"Getting cosine similarities...") | |
data["cosine_similarities"] = get_cosine_similarities( | |
data["average_waveforms"] | |
) | |
# Get correlogram quantities | |
if verbose: | |
print(f"Getting correlograms...") | |
data["correlograms"] = get_correlograms( | |
data["spike_times"], | |
max_dt=correlogram_max_dt, | |
min_dt=correlogram_min_dt, | |
) | |
data["correlogram_isi_violation_ratios"] = ( | |
get_correlogram_isi_violation_ratios( | |
data, max_dt=correlogram_max_dt, min_dt=correlogram_min_dt | |
) | |
) | |
data["correlogram_asymmetries"] = get_correlogram_asymmetries( | |
data["correlograms"] | |
) | |
data["correlogram_asymmetry_directions"] = ( | |
get_correlogram_asymmetry_directions(data["correlograms"]) | |
) | |
data["correlogram_counts"] = get_correlogram_counts( | |
data["correlograms"] | |
) | |
data["correlogram_min_dt"] = correlogram_max_dt | |
data["correlogram_max_dt"] = correlogram_max_dt | |
# Get amplitude overlap | |
if verbose: | |
print(f"Getting amplitude overlaps...") | |
data["amplitude_overlaps"] = get_amplitude_overlaps(data) | |
# Get burst pair amplitude correlogram asymmetry metric | |
data["burst_pair_amplitude_timing_bools"] = ( | |
get_burst_pair_amplitude_timing_bools(data) | |
) | |
# Get ISI violation percent for merged unit pairs | |
data["unit_pair_percent_isi_violations"] = ( | |
get_unit_pair_percent_isi_violations(data) | |
) | |
# Get amplitude decrement metrics | |
if verbose: | |
print(f"Getting amplitude decrement quantities...") | |
for max_dt in [0.015, 0.4]: | |
data[f"amplitude_decrements_{max_dt}"] = ( | |
get_unit_amplitude_decrements(data, max_dt) | |
) | |
data[f"unit_merge_amplitude_decrements_{max_dt}"] = ( | |
get_unit_merge_amplitude_decrements(data, max_dt) | |
) | |
data[f"amplitude_decrement_changes_{max_dt}"] = ( | |
get_amplitude_decrement_changes(data, max_dt) | |
) | |
# Valid lower amplitude fractions | |
if verbose: | |
print(f"Getting valid lower amplitude fractions...") | |
data["valid_lower_amplitude_fractions"] = ( | |
get_valid_lower_amplitude_fractions(data) | |
) | |
data["unit_merge_valid_lower_amplitude_fractions"] = ( | |
get_unit_merge_valid_lower_amplitude_fractions(data) | |
) | |
# Save data | |
cd_make_if_nonexistent(save_dir) | |
if verbose: | |
print(f"Saving {file_name_save} in {save_dir}...") | |
pickle.dump(data, open(file_name_save, "wb")) # save data | |
# need to update this for each user | |
def get_curation_data_save_dir(subject_id): | |
return f"/cumulus/mcoulter/curation_data/{subject_id}" | |
def make_curation_data_wrapper( | |
subject_ids, | |
dates, | |
sort_interval_name="raw data valid times no premaze no home", | |
preproc_params_name="franklab_tetrode_hippocampus", | |
sorter="mountainsort4", | |
sort_group_ids=None, | |
get_workspace_url=False, | |
curation_id=1, | |
ignore_invalid_sort_group_ids=False, | |
overwrite_existing=False, | |
verbose=True, | |
): | |
# Make curation data | |
for subject_id, date in zip(subject_ids, dates): | |
# Get nwb file name | |
nwb_file_name = nwbf_name_from_subject_id_date(subject_id, date) | |
# Define directory to save data in | |
save_dir = get_curation_data_save_dir(subject_id) | |
# Make curation data | |
make_curation_data( | |
save_dir=save_dir, | |
nwb_file_name=nwb_file_name, | |
sort_interval_name=sort_interval_name, | |
sorter=sorter, | |
preproc_params_name=preproc_params_name, | |
sort_group_ids=sort_group_ids, | |
curation_id=curation_id, | |
get_workspace_url=get_workspace_url, | |
ignore_invalid_sort_group_ids=ignore_invalid_sort_group_ids, | |
overwrite_existing=overwrite_existing, | |
verbose=verbose, | |
) | |
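# Illustrative usage sketch (not part of the original gist) showing how
# make_curation_data_wrapper might be called. The subject IDs and dates below
# are hypothetical placeholders (one date per subject).
def _example_make_curation_data_wrapper():
    make_curation_data_wrapper(
        subject_ids=["subject1", "subject2"],  # hypothetical subjects
        dates=["20220101", "20220102"],  # hypothetical dates
        sort_group_ids=None,  # use all valid sort groups
        overwrite_existing=False,
    )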
def _compute_cluster_data(func_name, data_in): | |
data_out = {cluster: None for cluster in data_in.keys()} | |
for cluster, data in data_in.items(): | |
data_out[cluster] = func_name(data) | |
return data_out | |
def _compute_pairwise_cluster_data( | |
func_name, data_in, nested_dict=False, kwargs=None | |
): | |
# Get inputs if not passed | |
if kwargs is None: | |
kwargs = {} | |
# Initialize output dictionary | |
data_out = { | |
cluster_1: {cluster_2: None for cluster_2 in data_in.keys()} | |
for cluster_1 in data_in.keys() | |
} | |
if nested_dict: | |
for cluster_1 in data_in.keys(): | |
for cluster_2 in data_in[cluster_1].keys(): | |
data_out[cluster_1][cluster_2] = func_name( | |
data_in[cluster_1][cluster_2], **kwargs | |
) | |
else: | |
for cluster_1, data_1 in data_in.items(): | |
for cluster_2, data_2 in data_in.items(): | |
data_out[cluster_1][cluster_2] = func_name( | |
data_1, data_2, **kwargs | |
) | |
return data_out | |
def _compute_average_waveform(wv): | |
wv_avg = np.mean(wv, axis=2) | |
return wv_avg | |
def _compute_peak_channel(wv_avg): | |
idx = np.argmax(_compute_waveform_amplitude(wv_avg)) | |
return idx | |
def _compute_waveform_amplitude(wv): | |
amp = np.max(wv, axis=1) - np.min(wv, axis=1) | |
return amp | |
def _compute_amplitude_overlaps(data, unit_1, unit_2, bin_width=0.1): | |
# Find peak amplitude channel for each unit | |
unit_ids = [unit_1, unit_2] | |
peak_channels = np.unique( | |
[data["peak_channels"][unit_id] for unit_id in unit_ids] | |
) | |
# For each unique peak amplitude channel, find overlap of normalized histograms of | |
# amplitude distribution across units | |
overlaps = [] # overlap across peak channels | |
for peak_channel in peak_channels: # peak channels | |
# Get amplitudes for units | |
unit_amplitudes = np.asarray( | |
[ | |
unpack_single_element( | |
_compute_waveform_amplitude( | |
data["waveforms"][unit_id][[peak_channel], :, :] | |
) | |
) | |
for unit_id in unit_ids | |
] | |
) | |
# Use minimum and maximum amplitude seen across units to form histogram bins | |
concatenated_unit_amplitudes = np.concatenate(unit_amplitudes) | |
bin_edges = np.arange( | |
np.min(concatenated_unit_amplitudes), | |
np.max(concatenated_unit_amplitudes) + bin_width, | |
bin_width, | |
) | |
# Find overlap between normalized histograms | |
overlaps.append( | |
overlap( | |
*[ | |
np.histogram(amplitudes, bin_edges, density=True)[0] | |
for amplitudes in unit_amplitudes | |
] | |
) | |
) | |
# Take average of overlaps across unit peak amplitude channels | |
return np.mean(overlaps) | |
def _compare_amplitude_size(data, unit_1, unit_2): | |
unit_1_mean = np.mean( | |
data["amplitudes"][unit_1][data["peak_channels"][unit_1]] | |
) | |
unit_2_mean = np.mean( | |
data["amplitudes"][unit_2][data["peak_channels"][unit_1]] | |
) | |
if unit_1_mean < unit_2_mean: | |
return -1 | |
if unit_1_mean == unit_2_mean: | |
return 0 | |
if unit_1_mean > unit_2_mean: | |
return 1 | |
def _compute_cosine_similarity(wv_avg_1, wv_avg_2): | |
wv_avg_1, wv_avg_2 = (np.ravel(wv_avg) for wv_avg in (wv_avg_1, wv_avg_2)) | |
wv_avg_nrm_1, wv_avg_nrm_2 = ( | |
wv_avg / np.linalg.norm(wv_avg, axis=0) | |
for wv_avg in (wv_avg_1, wv_avg_2) | |
) | |
sim = np.dot(wv_avg_nrm_1, wv_avg_nrm_2) | |
return sim | |
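# Illustrative example (not part of the original gist): cosine similarity is 1.0
# for average waveforms that differ only by a positive scale factor, and falls
# toward 0 as their shapes diverge. The arrays below are made up for demonstration.
def _example_cosine_similarity():
    wv_avg_1 = np.array([[0.0, 1.0, 0.0], [0.0, 2.0, 0.0]])
    return (
        _compute_cosine_similarity(wv_avg_1, 3 * wv_avg_1),  # 1.0
        # 0.0 here: the circularly shifted waveform shares no nonzero samples
        _compute_cosine_similarity(wv_avg_1, np.roll(wv_avg_1, 1, axis=1)),
    )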
def get_correlogram_default_params(): | |
return {"max_dt": 0.5, "min_dt": 0} | |
def _compute_correlogram(spk_times_1, spk_times_2, max_dt=None, min_dt=None): | |
# Get inputs if not passed | |
if max_dt is None: | |
max_dt = get_correlogram_default_params()["max_dt"] | |
if min_dt is None: | |
min_dt = get_correlogram_default_params()["min_dt"] | |
time_diff = ( | |
np.tile(spk_times_1, (spk_times_2.size, 1)) - spk_times_2[:, np.newaxis] | |
) | |
ind = np.logical_and( | |
np.abs(time_diff) > min_dt, np.abs(time_diff) <= max_dt | |
) | |
time_diff = np.sort(time_diff[ind]) | |
return time_diff | |
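# Illustrative example (not part of the original gist): the correlogram is the
# sorted set of pairwise spike-time differences with 0 < |dt| <= max_dt. The
# spike times below are made up for demonstration.
def _example_compute_correlogram():
    spk_times_1 = np.array([0.0, 0.1, 0.2])
    spk_times_2 = np.array([0.05, 0.25])
    # With the default window (0 < |dt| <= 0.5 s) this returns
    # [-0.25, -0.15, -0.05, -0.05, 0.05, 0.15]
    return _compute_correlogram(spk_times_1, spk_times_2)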
def _compute_correlogram_count( | |
time_diff, min_dt=-200 / 1000, max_dt=200 / 1000 | |
): | |
return np.sum(np.logical_and(time_diff > min_dt, time_diff < max_dt)) | |
def _compute_correlogram_asymmetry_direction( | |
time_diff, min_dt=-200 / 1000, max_dt=200 / 1000 | |
): | |
neg_count = np.sum(np.logical_and(time_diff > min_dt, time_diff < 0)) | |
pos_count = np.sum(np.logical_and(time_diff > 0, time_diff < max_dt)) | |
if neg_count > pos_count: | |
return -1 | |
if neg_count == pos_count: | |
return 0 | |
if pos_count > neg_count: | |
return 1 | |
def _compute_correlogram_asymmetry( | |
time_diff, min_dt=-200 / 1000, max_dt=200 / 1000 | |
): | |
zero_count = np.sum(time_diff == 0) | |
neg_count = np.sum(np.logical_and(time_diff > min_dt, time_diff < 0)) | |
pos_count = np.sum(np.logical_and(time_diff > 0, time_diff < max_dt)) | |
asym = (np.max([neg_count, pos_count]) + zero_count / 2) / ( | |
zero_count + neg_count + pos_count | |
) | |
return asym | |
def percent_isi_violations(spike_train, isi_threshold): | |
isis = np.diff(spike_train) | |
num_isi_violations = np.sum(isis < isi_threshold) | |
return 100 * num_isi_violations / len(isis) | |
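# Illustrative example (not part of the original gist): with a 1.5 ms threshold,
# one of the three inter-spike intervals below is a violation, giving ~33.3%.
def _example_percent_isi_violations():
    spike_train = np.array([0.000, 0.001, 0.003, 0.006])  # made-up spike times (s)
    return percent_isi_violations(spike_train, isi_threshold=0.0015)  # ~33.3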
def _compute_correlogram_isi_violation_ratio( | |
correlogram, correlogram_window_width, isi_threshold=None | |
): | |
# Get inputs if not passed | |
if isi_threshold is None: | |
isi_threshold = 0.0015 | |
# Find violations in correlogram | |
invalid_bool = abs(correlogram) < isi_threshold | |
# Compute fraction of correlogram that is violations | |
correlogram_isi_violation = np.sum(invalid_bool) / len(invalid_bool) | |
# Calculate expected violation ratio if correlogram uniform | |
    uniform_violation = (isi_threshold * 2) / correlogram_window_width  # expected fraction if uniform
# Return ratio of actual violation ratio to ratio expected if uniform correlogram | |
return correlogram_isi_violation / uniform_violation | |
def _burst_pair_amplitude_timing_bool(data, unit_1, unit_2): | |
return ( | |
data["amplitude_size_comparisons"][unit_1][unit_2] | |
* data["correlogram_asymmetry_directions"][unit_1][unit_2] | |
< 0 | |
) | |
# Amplitude decrement | |
def _channel_amplitudes(data, unit_id, channel): | |
return unpack_single_element( | |
_compute_waveform_amplitude(data["waveforms"][unit_id][[channel], :, :]) | |
) | |
def _time_diff(spike_times, min_dt, max_dt): | |
time_diff = ( | |
np.tile(spike_times, (spike_times.size, 1)) - spike_times[:, np.newaxis] | |
) | |
ind = np.logical_and( | |
np.abs(time_diff) > min_dt, np.abs(time_diff) <= max_dt | |
) | |
return time_diff, ind | |
# slow step??
# this takes about 10 secs - seems like it has to run 4x for each cluster pair
# slow steps 0->1 and 1->2. note: this is much slower for tetrodes with many spikes
# if you reduce spikes from 20k to 5k, the speedup is 10-16x!
def _compute_amplitude_decrement(spike_times, amplitudes, max_dt=None): | |
# print('_compute_unit_decrement',datetime.datetime.now()) | |
# Get inputs if not passed | |
if max_dt is None: | |
max_dt = 0.015 | |
# print('_compute_unit_decrement 0',datetime.datetime.now(),len(spike_times)) | |
time_diff, ind = _time_diff(spike_times, 0, max_dt) | |
# print('_compute_unit_decrement 1',datetime.datetime.now()) | |
amplitude_diff = ( | |
np.tile(amplitudes, (amplitudes.size, 1)) - amplitudes[:, np.newaxis] | |
) | |
# print('_compute_unit_decrement 2',datetime.datetime.now()) | |
valid_time_diff = time_diff[ind] | |
# print('_compute_unit_decrement 3',datetime.datetime.now()) | |
valid_amplitude_diff = amplitude_diff[ind] | |
# print('_compute_unit_decrement 4',datetime.datetime.now()) | |
# Return nan if fewer than two valid samples, since in this case cannot calculate correlation | |
if len(valid_time_diff) < 2: | |
return np.nan | |
return sp.stats.pearsonr(valid_time_diff, valid_amplitude_diff)[0] | |
def _compute_unit_amplitude_decrement(data, unit_id, max_dt=None): | |
return _compute_amplitude_decrement( | |
spike_times=data["spike_times"][unit_id], | |
amplitudes=_channel_amplitudes( | |
data, unit_id, data["peak_channels"][unit_id] | |
), | |
max_dt=max_dt, | |
) | |
# this is the slow step: runs computation 4x per cluster pair | |
# maybe this step could be parallelized because the same computation is run 4 times | |
def _compute_unit_merge_amplitude_decrement(data, unit_1, unit_2, max_dt=None): | |
unit_ids = [unit_1, unit_2] | |
# print('_compute_unit_merge_amplitude_decrement',datetime.datetime.now(),'units',unit_1,unit_2) | |
return np.mean( | |
[ | |
_compute_amplitude_decrement( | |
spike_times=np.concatenate( | |
[data["spike_times"][unit_id] for unit_id in unit_ids] | |
), | |
amplitudes=np.concatenate( | |
[ | |
_channel_amplitudes( | |
data, | |
unit_id, | |
data["peak_channels"][peak_channel_unit_id], | |
) | |
for unit_id in unit_ids | |
] | |
), | |
max_dt=max_dt, | |
) | |
for peak_channel_unit_id in unit_ids | |
] | |
) | |
def _compute_amplitude_decrement_change(data, unit_1, unit_2, max_dt): | |
# print('_compute_amplitude_decrement_change',datetime.datetime.now()) | |
# Get average amplitude decrement across the two units | |
unit_amplitude_decrement = np.mean( | |
[ | |
data[f"amplitude_decrements_{max_dt}"][unit_id] | |
for unit_id in [unit_1, unit_2] | |
] | |
) | |
# Get amplitude decrement metric for merged case | |
unit_merge_amplitude_decrement = data[ | |
f"unit_merge_amplitude_decrements_{max_dt}" | |
][unit_1][unit_2] | |
return unit_merge_amplitude_decrement - unit_amplitude_decrement | |
# Valid lower amplitude fraction | |
def _compute_valid_lower_amplitude_fraction( | |
spike_times, amplitudes, percentile=None, valid_window=None | |
): | |
# Get inputs if not passed | |
if percentile is None: | |
percentile = 5 | |
if valid_window is None: | |
valid_window = 0.5 | |
# Get data value at passed percentile | |
threshold = np.percentile(amplitudes, percentile) | |
# Threshold data | |
below_threshold_spike_times = spike_times[amplitudes < threshold] | |
above_threshold_spike_times = spike_times[amplitudes >= threshold] | |
# Find fraction of lower amplitude spikes that have an upper amplitude spike within some amount of time | |
below_threshold_tile = np.tile( | |
below_threshold_spike_times, (len(above_threshold_spike_times), 1) | |
) | |
above_threshold_tile = np.tile( | |
above_threshold_spike_times, (len(below_threshold_spike_times), 1) | |
).T | |
spike_time_differences = above_threshold_tile - below_threshold_tile | |
valid_bool = np.sum(abs(spike_time_differences) < valid_window, axis=0) > 0 | |
# Return quantities | |
fraction_lower_amplitude_valid = np.sum(valid_bool) / len(valid_bool) | |
valid_lower_amplitude_spike_times = below_threshold_spike_times[valid_bool] | |
valid_lower_amplitudes = amplitudes[amplitudes < threshold][valid_bool] | |
return ( | |
fraction_lower_amplitude_valid, | |
valid_lower_amplitude_spike_times, | |
valid_lower_amplitudes, | |
) | |
def _compute_unit_merge_valid_lower_amplitude_fraction( | |
data, unit_1, unit_2, percentile=None, valid_window=None | |
): | |
unit_ids = [unit_1, unit_2] | |
return np.mean( | |
[ | |
_compute_valid_lower_amplitude_fraction( | |
spike_times=np.concatenate( | |
[data["spike_times"][unit_id] for unit_id in unit_ids] | |
), | |
amplitudes=np.concatenate( | |
[ | |
_channel_amplitudes( | |
data, | |
unit_id, | |
data["peak_channels"][peak_channel_unit_id], | |
) | |
for unit_id in unit_ids | |
] | |
), | |
percentile=percentile, | |
valid_window=valid_window, | |
)[0] | |
for peak_channel_unit_id in unit_ids | |
] | |
) | |
# Get quantities | |
def get_average_waveforms(waveforms): | |
return _compute_cluster_data(_compute_average_waveform, waveforms) | |
def get_peak_channels(average_waveforms): | |
return _compute_cluster_data(_compute_peak_channel, average_waveforms) | |
def get_waveform_amplitudes(waveforms): | |
return _compute_cluster_data(_compute_waveform_amplitude, waveforms) | |
def get_amplitude_overlaps(data): | |
unit_ids = data["unit_ids"] | |
return { | |
unit_1: { | |
unit_2: _compute_amplitude_overlaps(data, unit_1, unit_2) | |
for unit_2 in unit_ids | |
} | |
for unit_1 in unit_ids | |
} | |
def get_amplitude_size_comparisons(data): | |
unit_ids = data["unit_ids"] | |
return { | |
unit_1: { | |
unit_2: _compare_amplitude_size(data, unit_1, unit_2) | |
for unit_2 in unit_ids | |
} | |
for unit_1 in unit_ids | |
} | |
def get_cosine_similarities(average_waveforms): | |
return _compute_pairwise_cluster_data( | |
_compute_cosine_similarity, average_waveforms | |
) | |
def get_correlograms(spike_times, max_dt=None, min_dt=None): | |
return _compute_pairwise_cluster_data( | |
_compute_correlogram, | |
spike_times, | |
kwargs={"max_dt": max_dt, "min_dt": min_dt}, | |
) | |
def get_correlogram_counts(spike_time_differences, kwargs=None): | |
return _compute_pairwise_cluster_data( | |
_compute_correlogram_count, | |
spike_time_differences, | |
nested_dict=True, | |
kwargs=kwargs, | |
) | |
def get_correlogram_asymmetries(spike_time_differences, kwargs=None): | |
return _compute_pairwise_cluster_data( | |
_compute_correlogram_asymmetry, | |
spike_time_differences, | |
nested_dict=True, | |
kwargs=kwargs, | |
) | |
def get_correlogram_asymmetry_directions(spike_time_differences, kwargs=None): | |
return _compute_pairwise_cluster_data( | |
_compute_correlogram_asymmetry_direction, | |
spike_time_differences, | |
nested_dict=True, | |
kwargs=kwargs, | |
) | |
def get_burst_pair_amplitude_timing_bools(data): | |
unit_ids = data["unit_ids"] | |
return { | |
unit_1: { | |
unit_2: _burst_pair_amplitude_timing_bool(data, unit_1, unit_2) | |
for unit_2 in unit_ids | |
} | |
for unit_1 in unit_ids | |
} | |
def merge_spike_times(data, unit_1, unit_2): | |
return np.sort( | |
np.concatenate( | |
(data["spike_times"][unit_1], data["spike_times"][unit_2]) | |
) | |
) | |
def get_unit_pair_percent_isi_violations(data, isi_threshold=0.0015): | |
unit_ids = data["unit_ids"] | |
return { | |
unit_1: { | |
unit_2: percent_isi_violations( | |
merge_spike_times(data, unit_1, unit_2), isi_threshold | |
) | |
for unit_2 in unit_ids | |
} | |
for unit_1 in unit_ids | |
} | |
def get_correlogram_isi_violation_ratios( | |
data, max_dt, min_dt, isi_threshold=None | |
): | |
unit_ids = data["unit_ids"] | |
correlogram_window_width = max_dt * 2 - min_dt * 2 | |
return { | |
unit_1: { | |
unit_2: _compute_correlogram_isi_violation_ratio( | |
correlogram=data["correlograms"][unit_1][unit_2], | |
correlogram_window_width=correlogram_window_width, | |
isi_threshold=isi_threshold, | |
) | |
for unit_2 in unit_ids | |
} | |
for unit_1 in unit_ids | |
} | |
def get_unit_amplitude_decrements(data, max_dt=None): | |
return { | |
unit_id: _compute_unit_amplitude_decrement(data, unit_id, max_dt) | |
for unit_id in data["unit_ids"] | |
} | |
def get_unit_merge_amplitude_decrements(data, max_dt=None): | |
unit_ids = data["unit_ids"] | |
# print('cluster pair',unit_1,unit_2) | |
return { | |
unit_1: { | |
unit_2: _compute_unit_merge_amplitude_decrement( | |
data, unit_1, unit_2, max_dt | |
) | |
for unit_2 in unit_ids | |
} | |
for unit_1 in unit_ids | |
} | |
def get_amplitude_decrement_changes(data, max_dt): | |
unit_ids = data["unit_ids"] | |
return { | |
unit_1: { | |
unit_2: _compute_amplitude_decrement_change( | |
data, unit_1, unit_2, max_dt | |
) | |
for unit_2 in unit_ids | |
} | |
for unit_1 in unit_ids | |
} | |
def get_valid_lower_amplitude_fractions( | |
data, percentile=None, valid_window=None | |
): | |
return { | |
unit_id: _compute_valid_lower_amplitude_fraction( | |
spike_times=data["spike_times"][unit_id], | |
amplitudes=data["amplitudes"][unit_id][ | |
data["peak_channels"][unit_id], : | |
], | |
percentile=percentile, | |
valid_window=valid_window, | |
)[0] | |
for unit_id in data["unit_ids"] | |
} | |
def get_unit_merge_valid_lower_amplitude_fractions( | |
data, percentile=None, valid_window=None | |
): | |
unit_ids = data["unit_ids"] | |
return { | |
unit_1: { | |
unit_2: _compute_unit_merge_valid_lower_amplitude_fraction( | |
data, | |
unit_1, | |
unit_2, | |
percentile=percentile, | |
valid_window=valid_window, | |
) | |
for unit_2 in unit_ids | |
} | |
for unit_1 in unit_ids | |
} | |
# Analysis | |
def get_merge_candidates(cluster_data, threshold_sets, sort_group_ids=None): | |
# Get inputs if not passed | |
if sort_group_ids is None: | |
sort_group_ids = list(cluster_data["sort_groups"].keys()) | |
# Loop through sort group IDs and apply thresholds to get merge candidates | |
merge_candidates_map = { | |
threshold_set_name: [] for threshold_set_name in threshold_sets.keys() | |
} | |
for sort_group_id in sort_group_ids: | |
data = cluster_data["sort_groups"][sort_group_id] | |
# Apply threshold sets | |
valid_bool_map = get_above_threshold_matrix_indices( | |
cluster_data, sort_group_id, threshold_sets | |
) | |
for ( | |
threshold_set_name, | |
valid_bool, | |
) in valid_bool_map.items(): # threshold sets | |
# Find indices in array corresponding to merge candidates | |
merge_candidate_idxs = list(zip(*np.where(valid_bool))) | |
# Convert merge candidate indices in array to unit IDs | |
merge_candidates_map[threshold_set_name] += [ | |
tuple( | |
[sort_group_id] | |
+ list(np.asarray(data["unit_ids"])[np.asarray(idxs)]) | |
) | |
for idxs in merge_candidate_idxs | |
] | |
return merge_candidates_map | |
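# Illustrative sketch (an assumption, not part of the original gist): the
# threshold_sets argument used by get_merge_candidates,
# get_above_threshold_matrix_indices, and _apply_metric_matrix_thresholds is
# expected to map a name to an object with .thresholds (each threshold having
# .metric_name, .threshold_direction, .threshold_value), .color, and .lw.
# Namedtuples like the ones below would satisfy that interface; the metric
# names refer to keys computed above, and the cutoff values are made up.
def _example_threshold_sets():
    Threshold = namedtuple(
        "Threshold", ["metric_name", "threshold_direction", "threshold_value"]
    )
    ThresholdSet = namedtuple("ThresholdSet", ["thresholds", "color", "lw"])
    return {
        "liberal": ThresholdSet(
            thresholds=[
                Threshold("cosine_similarities", operator.ge, 0.9),
                Threshold("amplitude_overlaps", operator.ge, 0.5),
            ],
            color="crimson",
            lw=2,
        )
    }
    # Usage (assuming cluster_data from load_curation_data):
    # get_merge_candidates(cluster_data, _example_threshold_sets())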
def merge_plots_wrapper( | |
cluster_data, | |
threshold_sets, | |
fig_scale=0.8, | |
subplot_width=4, | |
subplot_height=3, | |
plot_merge_candidates=None, | |
): | |
# Define plot parameters | |
num_rows = 2 | |
num_columns = 4 | |
gridspec_kw = {"width_ratios": [1, 1, 4, 4]} | |
for sort_group_id, data in cluster_data["sort_groups"].items(): | |
# Apply threshold sets | |
valid_bool_map = get_above_threshold_matrix_indices( | |
cluster_data, sort_group_id, threshold_sets | |
) | |
        # Continue if none of the passed merge candidates are in the current sort group
if plot_merge_candidates is not None: | |
if sort_group_id not in [x[0] for x in plot_merge_candidates]: | |
continue | |
# Plot matrices with pairwise metrics relevant for merging | |
plot_merge_matrices( | |
cluster_data, | |
sort_group_id, | |
valid_bool_map, | |
threshold_sets, | |
fig_scale=fig_scale, | |
) | |
# For threshold sets, plot metrics for merge candidates | |
for ( | |
threshold_name, | |
valid_bool, | |
) in valid_bool_map.items(): # threshold sets | |
# Find indices in array corresponding to merge candidates | |
merge_candidate_idxs = list(zip(*np.where(valid_bool))) | |
# Convert merge candidate indices in array to unit IDs | |
merge_candidates = [ | |
tuple(np.asarray(data["unit_ids"])[np.asarray(idxs)]) | |
for idxs in merge_candidate_idxs | |
] | |
# Loop through merge candidates and plot metrics | |
unit_colors = ["crimson", "#2196F3"] | |
for unit_1, unit_2 in merge_candidates: # units | |
if plot_merge_candidates is not None: | |
if ( | |
sort_group_id, | |
unit_1, | |
unit_2, | |
) not in plot_merge_candidates: | |
continue | |
# Initialize figure | |
fig, axes = plt.subplots( | |
num_rows, | |
num_columns, | |
figsize=( | |
num_columns * subplot_width, | |
num_rows * subplot_height, | |
), | |
gridspec_kw=gridspec_kw, | |
) | |
# Use peak channel of first unit to display data from both units | |
peak_ch = data["peak_channels"][unit_1] | |
# Leftmost subplots: average waveforms | |
for unit_id_idx, unit_id in enumerate([unit_1, unit_2]): | |
title = f"{sort_group_id}_{unit_id}" | |
if unit_id_idx == 1: | |
cosine_similarity = data["cosine_similarities"][unit_1][ | |
unit_2 | |
] | |
title = f"cosine similarity: {cosine_similarity: .2f}\n{title}" | |
gs = axes[0, unit_id_idx].get_gridspec() | |
# Remove underlying axis | |
for row_num in np.arange(0, num_rows): | |
axes[row_num, unit_id_idx].remove() | |
ax = fig.add_subplot(gs[:, unit_id_idx]) | |
# ax = axes[0, unit_id_idx] | |
plot_average_waveforms( | |
cluster_data, | |
sort_group_id, | |
unit_id, | |
title=title, | |
color=unit_colors[unit_id_idx], | |
ax=ax, | |
) | |
# Second subplot: amplitude distributions | |
ax = axes[0, 2] | |
for unit_id_idx, unit_id in enumerate([unit_1, unit_2]): | |
title = f"amplitude overlap: {data['amplitude_overlaps'][unit_1][unit_2]: .3f}" | |
plot_amplitude_distribution( | |
cluster_data, | |
sort_group_id, | |
unit_id, | |
ch=peak_ch, | |
max_amplitude=None, | |
amplitude_bin_size=2, | |
histtype="step", | |
density=True, | |
label=f"{sort_group_id}_{unit_id}", | |
color=unit_colors[unit_id_idx], | |
title=title, | |
ax=ax, | |
) | |
# Third subplot: correlograms | |
ax = axes[0, 3] | |
plot_correlogram( | |
cluster_data, | |
sort_group_id, | |
unit_1, | |
unit_2, | |
max_time_difference=200 / 1000, | |
color="gray", | |
ax=ax, | |
) | |
# Fourth subplot: amplitudes over time | |
# Use peak channel from first unit to plot amplitudes for both units | |
gs = axes[1, 2].get_gridspec() | |
# Remove underlying axes | |
for ax in axes[1, 2:]: | |
ax.remove() | |
ax = fig.add_subplot(gs[1, 2:]) | |
# Plot amplitudes over time for each unit | |
for unit_id_idx, unit_id in enumerate([unit_1, unit_2]): | |
plot_amplitude( | |
cluster_data, | |
sort_group_id, | |
unit_id, | |
peak_ch, | |
color=unit_colors[unit_id_idx], | |
ax=ax, | |
) | |
# Global title | |
fig.suptitle( | |
f"{sort_group_id}_{unit_1} vs. {sort_group_id}_{unit_2}\n{threshold_name}", | |
fontsize=20, | |
) | |
fig.tight_layout() | |
# VISUALIZATION | |
def plot_amplitude( | |
cluster_data, sort_group_id, unit_id, ch, color="black", ax=None | |
): | |
# Get inputs if not passed | |
if ax is None: | |
_, ax = plt.subplots() | |
# Plot amplitudes over time | |
data = cluster_data["sort_groups"][sort_group_id] | |
ax.scatter( | |
data["spike_times"][unit_id], | |
data["amplitudes"][unit_id][ch, :], | |
s=1, | |
color=color, | |
) | |
def _matrix_grid(ax, n_clusters, fig_scale): | |
for ndx in range(n_clusters - 1): | |
ax.axvline(x=ndx + 1, color="#FFFFFF", linewidth=fig_scale * 0.5) | |
ax.axhline(y=ndx + 1, color="#FFFFFF", linewidth=fig_scale * 0.5) | |
def _format_matrix_ax(ax, ticks, ticklabels, fig_scale, title): | |
ax.set_xticks(ticks) | |
ax.set_yticks(ticks) | |
ax.set_xticklabels(ticklabels) | |
ax.set_yticklabels(ticklabels) | |
ax.tick_params(length=0) | |
for spine in ax.spines.values(): | |
spine.set_visible(False) | |
ax.set_title(title, fontsize=fig_scale * 12) | |
def get_above_threshold_matrix_indices( | |
cluster_data, sort_group_id, threshold_sets | |
): | |
return { | |
threshold_name: _apply_metric_matrix_thresholds( | |
cluster_data, sort_group_id, threshold_set.thresholds | |
) | |
for threshold_name, threshold_set in threshold_sets.items() | |
} | |
def _highlight_matrix_indices(valid_bool_map, threshold_sets, ax): | |
for threshold_name, valid_bool in valid_bool_map.items(): | |
ii, jj = np.where(valid_bool) | |
for ndx in range(np.sum(valid_bool)): | |
ax.add_patch( | |
matplotlib.patches.Rectangle( | |
(jj[ndx], ii[ndx]), | |
1, | |
1, | |
edgecolor=threshold_sets[threshold_name].color, | |
fill=False, | |
lw=threshold_sets[threshold_name].lw, | |
zorder=2 * len(valid_bool) ** 2, | |
clip_on=False, | |
) | |
) | |
def _get_metric_matrix( | |
cluster_data, | |
sort_group_id, | |
metric_name, | |
apply_upper_diagonal_mask=False, | |
mask_value=np.nan, | |
): | |
data = cluster_data["sort_groups"][sort_group_id] | |
metric_dict = data[metric_name] | |
index = metric_dict.keys() | |
    matrix = np.array(
        [[metric_dict[ii][jj] for jj in index] for ii in index]
    ).astype(float)  # float so can mask with nan (np.float was removed from numpy)
# Mask upper diagonal if indicated | |
if apply_upper_diagonal_mask: | |
matrix = mask_upper_diagonal(matrix, mask_value=mask_value) | |
return pd.DataFrame(matrix, index=index, columns=index) | |
def _apply_metric_matrix_thresholds( | |
cluster_data, sort_group_id, threshold_objs | |
): | |
return np.prod( | |
[ | |
threshold_obj.threshold_direction( | |
_get_metric_matrix( | |
cluster_data, | |
sort_group_id, | |
threshold_obj.metric_name, | |
apply_upper_diagonal_mask=True, | |
mask_value=np.nan, | |
), | |
threshold_obj.threshold_value, | |
) | |
for threshold_obj in threshold_objs | |
], | |
axis=0, | |
) | |
def plot_amplitude_overlap_matrix( | |
cluster_data, | |
sort_group_id, | |
fig_scale=1, | |
fig_ax_list=None, | |
plot_color_bar=True, | |
): | |
data = cluster_data["sort_groups"][sort_group_id] | |
n_clusters = data["n_clusters"] | |
# Get amplitude overlap matrix | |
ao_matrix = _get_metric_matrix( | |
cluster_data, | |
sort_group_id, | |
"amplitude_overlaps", | |
apply_upper_diagonal_mask=True, | |
mask_value=0, | |
) | |
# Unpack figure and axis if passed | |
if fig_ax_list is not None: | |
fig, ax = fig_ax_list | |
# Otherwise make these | |
else: | |
        fig = plt.figure(
            figsize=(fig_scale * n_clusters / 2, fig_scale * n_clusters / 2)
        )
gs = fig.add_gridspec(1, 1) | |
ax = fig.add_subplot(gs[0]) | |
pcm = plt.pcolormesh(ao_matrix, cmap="inferno", vmin=0, vmax=1) | |
_matrix_grid(ax, n_clusters, fig_scale) | |
label = "".join( | |
( | |
cluster_data["nwb_file_name"], | |
"\n", | |
"interval: ", | |
cluster_data["sort_interval_name"], | |
"\n", | |
f"sort group: {sort_group_id}", | |
"\n" "amplitude overlap", | |
) | |
) | |
_format_matrix_ax( | |
ax, | |
ticks=np.arange(0.5, n_clusters + 0.5), | |
ticklabels=ao_matrix.index, | |
fig_scale=fig_scale, | |
title=label, | |
) | |
# Color bar | |
if plot_color_bar: | |
add_colorbar(pcm, fig, ax) | |
return fig, ax | |
def plot_merge_matrices( | |
cluster_data, | |
sort_group_id, | |
valid_bool_map, | |
threshold_sets, | |
fig_scale=1, | |
plot_color_bar=True, | |
): | |
data = cluster_data["sort_groups"][sort_group_id] | |
n_clusters = data["n_clusters"] | |
# Get cosine similarity matrix | |
cs_matrix = _get_metric_matrix( | |
cluster_data, | |
sort_group_id, | |
"cosine_similarities", | |
apply_upper_diagonal_mask=True, | |
mask_value=0, | |
) | |
# Get correlogram asymmetry matrix | |
ca_matrix = _get_metric_matrix( | |
cluster_data, | |
sort_group_id, | |
"correlogram_asymmetries", | |
apply_upper_diagonal_mask=True, | |
mask_value=0, | |
) | |
# Initialize figure | |
num_columns = 3 | |
fig = plt.figure( | |
figsize=( | |
fig_scale * (num_columns * n_clusters / 2 + 2), | |
fig_scale * (n_clusters / 2), | |
) | |
) | |
width_ratios = [n_clusters / 2] * 3 | |
gs = fig.add_gridspec(1, num_columns, wspace=0.2, width_ratios=width_ratios) | |
# Ticks across plots | |
ticks = np.arange(0.5, n_clusters + 0.5) | |
# First subplot: cosine similarity | |
ax = fig.add_subplot(gs[0]) | |
pcm = plt.pcolormesh(cs_matrix, cmap="inferno", vmin=0, vmax=1) | |
# Grid | |
_matrix_grid(ax, n_clusters, fig_scale) | |
# Highlight indices crossing metric thresholds | |
_highlight_matrix_indices(valid_bool_map, threshold_sets, ax) | |
# Axis | |
_format_matrix_ax( | |
ax=ax, | |
ticks=ticks, | |
ticklabels=cs_matrix.index, | |
fig_scale=fig_scale, | |
title="cosine similarity", | |
) | |
# Color bar | |
if plot_color_bar: | |
add_colorbar(pcm, fig, ax) | |
# Second subplot: correlogram asymmetry | |
ax = fig.add_subplot(gs[1]) | |
pcm = plt.pcolormesh(ca_matrix, cmap="inferno", vmin=0.5, vmax=1) | |
# Grid | |
_matrix_grid(ax, n_clusters, fig_scale) | |
# Highlight indices crossing metric thresholds | |
_highlight_matrix_indices(valid_bool_map, threshold_sets, ax) | |
# Axis | |
_format_matrix_ax( | |
ax=ax, | |
ticks=ticks, | |
ticklabels=ca_matrix.index, | |
fig_scale=fig_scale, | |
title="correlogram asymmetry", | |
) | |
# Color bar | |
if plot_color_bar: | |
add_colorbar(pcm, fig, ax) | |
# Third subplot: amplitude overlap | |
ax = fig.add_subplot(gs[2]) | |
fig, ax = plot_amplitude_overlap_matrix( | |
cluster_data, | |
sort_group_id, | |
fig_scale=fig_scale, | |
fig_ax_list=[fig, ax], | |
plot_color_bar=plot_color_bar, | |
) | |
# Highlight indices crossing metric thresholds | |
_highlight_matrix_indices(valid_bool_map, threshold_sets, ax) | |
plt.show() | |
def plot_average_waveforms( | |
cluster_data, | |
sort_group_id, | |
unit_id, | |
amplitude_range=80, | |
trace_offset=40, | |
color="#2196F3", | |
title=None, | |
ax=None, | |
): | |
# Get inputs if not passed | |
if ax is None: | |
_, ax = plt.subplots() | |
if title is None: | |
title = f"{sort_group_id}_{unit_id}" | |
data = cluster_data["sort_groups"][sort_group_id] | |
n_channels = data["n_channels"] | |
n_points = data["waveform_window"].size | |
ax.axvline(x=n_points / 2, color="#9E9E9E", linewidth=1) | |
offset = np.tile(-np.arange(n_channels) * trace_offset, (n_points, 1)) | |
wv_avg = data["average_waveforms"][unit_id] | |
trace = wv_avg.T + offset | |
peak_ind = np.full(n_channels, False) | |
peak_ind[data["peak_channels"][unit_id]] = True | |
ax.plot(trace[:, ~peak_ind], color=color, linewidth=1, clip_on=False) | |
ax.plot(trace[:, peak_ind], color=color, linewidth=2.5, clip_on=False) | |
ax.set_xlim([0, n_points]) | |
ax.set_ylim( | |
[ | |
-2 * amplitude_range / 3 - (n_channels - 1) * trace_offset, | |
amplitude_range / 3, | |
] | |
) | |
ax.get_xaxis().set_visible(False) | |
ax.get_yaxis().set_visible(False) | |
ax.set_title(title, fontsize=12) | |
def plot_average_waveforms_wrapper( | |
cluster_data, amplitude_range=80, trace_offset=40 | |
): | |
for sort_group_id, data in cluster_data["sort_groups"].items(): | |
n_clusters = data["n_clusters"] | |
fig = plt.figure(figsize=(n_clusters + 2, 1)) | |
width_ratios = np.ones(n_clusters + 1) | |
width_ratios[0] = 2 | |
gs = fig.add_gridspec( | |
1, n_clusters + 1, wspace=0.1, width_ratios=width_ratios | |
) | |
ax = fig.add_subplot(gs[0]) | |
label = "".join( | |
( | |
cluster_data["nwb_file_name"], | |
"\n", | |
"interval: ", | |
cluster_data["sort_interval_name"], | |
"\n", | |
f"sort group: {sort_group_id}", | |
) | |
) | |
ax.text(-0.3, 0.3, label, multialignment="left", fontsize=12) | |
ax.get_xaxis().set_visible(False) | |
ax.get_yaxis().set_visible(False) | |
for spine in ax.spines.values(): | |
spine.set_visible(False) | |
        for ndx, unit_id in enumerate(data["unit_ids"]):
ax = fig.add_subplot(gs[ndx + 1]) | |
plot_average_waveforms( | |
cluster_data, | |
sort_group_id, | |
unit_id, | |
amplitude_range=amplitude_range, | |
trace_offset=trace_offset, | |
ax=ax, | |
) | |
plt.show() | |
def plot_amplitude_distribution( | |
cluster_data, | |
sort_group_id, | |
unit_id, | |
ch=None, | |
max_amplitude=None, | |
amplitude_bin_size=2, | |
density=False, | |
histtype=None, | |
color="#2196F3", | |
label=None, | |
title=None, | |
remove_axes=False, | |
ax=None, | |
): | |
data = cluster_data["sort_groups"][sort_group_id] | |
# Define channel if not passed | |
if ch is None: | |
ch = data["peak_channels"][unit_id] | |
amp = data["amplitudes"][unit_id][ch, :] | |
# Get inputs if not passed | |
if ax is None: | |
_, ax = plt.subplots() | |
if max_amplitude is None: | |
max_amplitude = np.max(amp) | |
if title is None: | |
title = f"{sort_group_id}_{unit_id}" | |
bin_edges = np.arange( | |
0, max_amplitude + amplitude_bin_size, amplitude_bin_size | |
) | |
ax.hist( | |
amp, | |
bin_edges, | |
density=density, | |
histtype=histtype, | |
color=color, | |
label=label, | |
linewidth=4, | |
alpha=0.9, | |
) | |
# ax.set_xlim([bin_edges[0], bin_edges[-1]]) | |
if remove_axes: | |
ax.get_xaxis().set_visible(False) | |
ax.get_yaxis().set_visible(False) | |
ax.set_title(title, fontsize=12) | |
def plot_amplitude_distributions( | |
cluster_data, max_amplitude=50, amplitude_bin_size=2 | |
): | |
for sort_group_id, data in cluster_data["sort_groups"].items(): | |
n_clusters = data["n_clusters"] | |
fig = plt.figure(figsize=(n_clusters + 2, 1)) | |
width_ratios = np.ones(n_clusters + 1) | |
width_ratios[0] = 2 | |
gs = fig.add_gridspec( | |
1, n_clusters + 1, wspace=0.1, width_ratios=width_ratios | |
) | |
ax = fig.add_subplot(gs[0]) | |
label = "".join( | |
( | |
cluster_data["nwb_file_name"], | |
"\n", | |
"interval: ", | |
cluster_data["sort_interval_name"], | |
"\n", | |
f"sort group: {sort_group_id}", | |
) | |
) | |
ax.text(-0.3, 0.3, label, multialignment="center", fontsize=12) | |
ax.get_xaxis().set_visible(False) | |
ax.get_yaxis().set_visible(False) | |
for spine in ax.spines.values(): | |
spine.set_visible(False) | |
        for ndx, unit_id in enumerate(data["amplitudes"].keys()):
ax = fig.add_subplot(gs[ndx + 1]) | |
plot_amplitude_distribution( | |
cluster_data, | |
sort_group_id, | |
unit_id, | |
max_amplitude=max_amplitude, | |
amplitude_bin_size=amplitude_bin_size, | |
ax=ax, | |
) | |
plt.show() | |
def plot_correlogram( | |
cluster_data, | |
sort_group_id, | |
cluster_1, | |
cluster_2, | |
max_time_difference=20 / 1000, | |
time_bin_size=1 / 1000, | |
color="#2196F3", | |
remove_axes=False, | |
ax=None, | |
): | |
# Get inputs if not passed | |
if ax is None: | |
_, ax = plt.subplots() | |
data = cluster_data["sort_groups"][sort_group_id] | |
time_diff = data["correlograms"][cluster_1][cluster_2] | |
bin_edges = np.arange( | |
-max_time_difference, max_time_difference + time_bin_size, time_bin_size | |
) | |
ax.hist(time_diff, bin_edges, color=color) | |
ax.set_xlim([bin_edges[0], bin_edges[-1]]) | |
ax.set_ylim([0, np.max(np.histogram(time_diff, bin_edges)[0])]) | |
if remove_axes: | |
ax.get_xaxis().set_visible(False) | |
ax.get_yaxis().set_visible(False) | |
burst_pair_amplitude_timing_bool = data[ | |
"burst_pair_amplitude_timing_bools" | |
][cluster_1][cluster_2] | |
correlogram_asymmetry = data["correlogram_asymmetries"][cluster_1][ | |
cluster_2 | |
] | |
isi_violation = data["unit_pair_percent_isi_violations"][cluster_1][ | |
cluster_2 | |
] | |
correlogram_count = int(data["correlogram_counts"][cluster_1][cluster_2]) | |
ax.set_title( | |
f"{sort_group_id}_{cluster_1} vs {sort_group_id}_{cluster_2}" | |
f"\ncount: {correlogram_count: .2f}" | |
f"\nasymmetry: {correlogram_asymmetry: .2f}" | |
f"\nISI violation: {isi_violation:.5f}" | |
f"\nburst_pair_amplitude_timing_bool: {burst_pair_amplitude_timing_bool}", | |
fontsize=12, | |
) | |
def plot_autocorrelograms( | |
cluster_data, max_time_difference=20 / 1000, time_bin_size=1 / 1000 | |
): | |
for sort_group_id, data in cluster_data["sort_groups"].items(): | |
n_clusters = data["n_clusters"] | |
fig = plt.figure(figsize=(n_clusters + 2, 1)) | |
width_ratios = np.ones(n_clusters + 1) | |
width_ratios[0] = 2 | |
gs = fig.add_gridspec( | |
1, n_clusters + 1, wspace=0.1, width_ratios=width_ratios | |
) | |
ax = fig.add_subplot(gs[0]) | |
label = "".join( | |
( | |
cluster_data["nwb_file_name"], | |
"\n", | |
"interval: ", | |
cluster_data["sort_interval_name"], | |
"\n", | |
f"sort group: {sort_group_id}", | |
) | |
) | |
ax.text(-0.3, 0.3, label, multialignment="center", fontsize=12) | |
ax.get_xaxis().set_visible(False) | |
ax.get_yaxis().set_visible(False) | |
for spine in ax.spines.values(): | |
spine.set_visible(False) | |
for ndx, cluster_num in enumerate(data["correlograms"].keys()): | |
ax = fig.add_subplot(gs[ndx + 1]) | |
plot_correlogram( | |
cluster_data, | |
sort_group_id, | |
cluster_1=cluster_num, | |
cluster_2=cluster_num, | |
max_time_difference=max_time_difference, | |
time_bin_size=time_bin_size, | |
ax=ax, | |
) | |
plt.show() | |
def plot_cosine_similarity_distribution(cluster_data, fig_scale=1): | |
n_clusters = np.array( | |
[data["n_clusters"] for data in cluster_data["sort_groups"].values()] | |
) | |
ind = [np.triu(np.full((count, count), True), 1) for count in n_clusters] | |
cs_list = [ | |
np.array( | |
[ | |
[ | |
data["cosine_similarities"][ii][jj] | |
for jj in data["cosine_similarities"].keys() | |
] | |
for ii in data["cosine_similarities"].keys() | |
] | |
) | |
for data in cluster_data["sort_groups"].values() | |
] | |
cs = [list(np.ravel(cs[ind[ndx]])[:]) for ndx, cs in enumerate(cs_list)] | |
cs = np.array(list(itertools.chain(*cs))) | |
fig = plt.figure(figsize=(fig_scale * 12, fig_scale * 3)) | |
gs = fig.add_gridspec(1, 1) | |
ax = fig.add_subplot(gs[0]) | |
label = "".join( | |
( | |
cluster_data["nwb_file_name"], | |
"\n", | |
"interval: ", | |
cluster_data["sort_interval_name"], | |
) | |
) | |
plt.hist(cs, 240, color="#2196F3") | |
ax.axvline(x=0, color="#424242", linewidth=fig_scale * 1.0) | |
# ax.axvline(x=0.9, color='#424242', linewidth=2) | |
ax.set_xlim([-1, 1]) | |
ax.set_xticks([-1, -0.5, 0, 0.5, 1]) | |
ax.set_xticklabels([-1, -0.5, 0, 0.5, 1], fontsize=fig_scale * 10) | |
ax.set_yticks([0, 20, 40, 60]) | |
ax.set_yticklabels([0, 20, 40, 60], fontsize=fig_scale * 10) | |
ax.set_xlabel("Cosine Similarity", fontsize=fig_scale * 12) | |
ax.set_ylabel("Count", fontsize=fig_scale * 12) | |
ax.set_title(label, fontsize=fig_scale * 12) | |
plt.show() | |
def plot_correlogram_asymmetry_distribution(cluster_data, fig_scale=1): | |
n_clusters = np.array( | |
[data["n_clusters"] for data in cluster_data["sort_groups"].values()] | |
) | |
ind = [np.triu(np.full((count, count), True), 1) for count in n_clusters] | |
ca_list = [ | |
np.array( | |
[ | |
[ | |
data["correlogram_asymmetries"][ii][jj] | |
for jj in data["correlogram_asymmetries"].keys() | |
] | |
for ii in data["correlogram_asymmetries"].keys() | |
] | |
) | |
for data in cluster_data["sort_groups"].values() | |
] | |
ca = [list(np.ravel(ca[ind[ndx]])[:]) for ndx, ca in enumerate(ca_list)] | |
ca = np.array(list(itertools.chain(*ca))) | |
fig = plt.figure(figsize=(fig_scale * 12, fig_scale * 3)) | |
gs = fig.add_gridspec(1, 1) | |
ax = fig.add_subplot(gs[0]) | |
label = "".join( | |
( | |
cluster_data["nwb_file_name"], | |
"\n", | |
"interval: ", | |
cluster_data["sort_interval_name"], | |
) | |
) | |
plt.hist(ca, 240, color="#2196F3") | |
ax.set_yscale("log") | |
ax.set_xlim([0.5, 1]) | |
ax.set_xticks([0.5, 0.6, 0.7, 0.8, 0.9, 1]) | |
ax.set_xticklabels([0.5, 0.6, 0.7, 0.8, 0.9, 1], fontsize=fig_scale * 10) | |
ax.set_yticks([1, 10, 100, 1000]) | |
ax.set_yticklabels([1, 10, 100, 1000], fontsize=fig_scale * 10) | |
ax.set_xlabel("Correlogram Asymmetry", fontsize=fig_scale * 12) | |
ax.set_ylabel("Count", fontsize=fig_scale * 12) | |
ax.set_title(label, fontsize=fig_scale * 12) | |
plt.show() | |
def check_all_unique(x): | |
if len(np.unique(x)) != len(x): | |
raise Exception(f"Not all elements unique") | |
def strip(x, strip_character, strip_start=False, strip_end=True): | |
if strip_start: | |
if x[0] == strip_character: | |
x = x[1:] | |
if strip_end: | |
if x[-1] == strip_character: | |
x = x[:-1] | |
return x | |
def df_filter_columns(df, key, column_and=True): | |
if column_and: | |
return df[ | |
np.asarray([df[k] == v for k, v in key.items()]).sum(axis=0) | |
== len(key) | |
] | |
else: | |
return df[ | |
np.asarray([df[k] == v for k, v in key.items()]).sum(axis=0) > 0 | |
] | |
def df_filter1_columns(df, key, tolerate_no_entry=False): | |
df_subset = df_filter_columns(df, key) | |
if np.logical_or( | |
len(df_subset) > 1, not tolerate_no_entry and len(df_subset) == 0 | |
): | |
raise Exception( | |
f"Should have found exactly one entry in df for key, but found {len(df_subset)}" | |
) | |
return df_subset | |
def df_pop(df, key, column, tolerate_no_entry=False): | |
df_subset = df_filter1_columns(df, key, tolerate_no_entry) | |
if len(df_subset) == 0: # empty df | |
return df_subset | |
return df_subset.iloc[0][column] | |
def df_filter_columns_isin(df, key): | |
if len(key) == 0: # if empty key | |
return df | |
return df[ | |
np.sum(np.asarray([df[k].isin(v) for k, v in key.items()]), axis=0) | |
== len(key) | |
] | |
# Alternate code: df[df[list(df_filter)].isin(df_filter).all(axis=1)] | |
def zip_df_columns(df, column_names=None): | |
if column_names is None: | |
column_names = df.columns | |
return zip(*[df[column_name] for column_name in column_names]) | |
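# Minimal sketch (illustrative only, not part of the analysis): how the
# dataframe helpers above compose. The toy dataframe and values are made up.
_demo_df = pd.DataFrame(
    {
        "sort_group_id": [0, 0, 1],
        "unit_id_1": [1, 2, 1],
        "label": [True, False, True],
    }
)
assert len(df_filter_columns(_demo_df, {"sort_group_id": 0})) == 2
assert df_pop(_demo_df, {"sort_group_id": 1, "unit_id_1": 1}, "label")
assert list(zip_df_columns(_demo_df, ["sort_group_id", "unit_id_1"])) == [
    (0, 1),
    (0, 2),
    (1, 1),
]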
def nwbf_name_from_subject_id_date(subject_id, date): | |
return f"{subject_id}{date}_.nwb" | |
def subject_id_date_from_nwbf_name(nwb_file_name): | |
len_date = 8 | |
subject_id_date = nwb_file_name.split("_.nwb")[0] | |
subject_id = subject_id_date[:-len_date] | |
date = subject_id_date[-len_date:] | |
return subject_id, date | |
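# Quick check of the nwb file naming convention handled by the two helpers
# above (the subject/date values are just examples from this dataset):
assert nwbf_name_from_subject_id_date("tonks", "20211107") == "tonks20211107_.nwb"
assert subject_id_date_from_nwbf_name("tonks20211107_.nwb") == ("tonks", "20211107")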
def unpack_single_element(x, tolerate_no_entry=False, return_no_entry=None): | |
if tolerate_no_entry: | |
if len(x) == 0: | |
return return_no_entry | |
return unpack_single_element(x, tolerate_no_entry=False) | |
if len(x) != 1: | |
raise Exception(f"len should be one") | |
return x[0] | |
def mask_upper_diagonal(arr, mask_value=0):
    """Zero out the lower triangle (including the diagonal) of arr in place,
    keeping only strictly upper-triangular entries."""
    mask = np.zeros_like(arr, dtype=bool)  # np.bool is deprecated; use builtin bool
    mask[np.tril_indices_from(mask)] = True
    arr[mask] = mask_value
    return arr
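# Illustrative check: only strictly upper-triangular entries survive.
_m = np.arange(9, dtype=float).reshape(3, 3)
assert np.array_equal(
    mask_upper_diagonal(_m), np.array([[0, 1, 2], [0, 0, 5], [0, 0, 0]])
)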
def cd_make_if_nonexistent(directory): | |
""" | |
Change to a directory if it exists. If it does not, make it then change to it. | |
:param directory: string. Directory to change to. | |
""" | |
if not os.path.exists(directory): | |
print(f"Making directory: {directory}") | |
os.mkdir(directory) | |
print(f"Changing to directory: {directory}") | |
os.chdir(directory) | |
def single_axis(axes): | |
return hasattr(axes, "plot") | |
def get_ax_for_left_right_layout(axes, plot_num): | |
""" | |
Return ax from axes if arranging plots left to right, top to bottom | |
:param axes: array with axis objects | |
:param plot_num: plot number | |
:return: current axis given plot number | |
""" | |
if single_axis(axes): # single axis object | |
return axes | |
if len(np.shape(axes)) == 1: # one row or one column of subplots | |
return axes[plot_num] | |
elif len(np.shape(axes)) == 2: # 2D panel of subplots | |
num_columns = np.shape(axes)[1] | |
row, col = divmod( | |
plot_num, num_columns | |
) # find row/column for current plot | |
return axes[row, col] # get axis for current plot | |
else: | |
raise Exception(f"axes do not conform to expected cases") | |
def _load_notes( | |
subject_id, | |
spreadsheet_name, | |
recording_spreadsheet_path=None, | |
header=0, | |
tolerate_no_notes=False, | |
): | |
""" | |
Load spreadsheet for a given subject | |
""" | |
# Get inputs if not passed | |
if recording_spreadsheet_path is None: | |
recording_spreadsheet_path = get_recording_spreadsheet_path(subject_id) | |
# Get file path | |
file_path = os.path.join(recording_spreadsheet_path, spreadsheet_name) | |
# If tolerating no notes and no notes, return empty df | |
if tolerate_no_notes and not os.path.exists(file_path): | |
return pd.DataFrame() | |
# Return recording spreadsheet as pandas dataframe | |
return pd.read_csv(file_path, header=header) | |
def load_curation_merge_notes( | |
subject_id, date, recording_spreadsheet_path=None, tolerate_no_notes=False | |
): | |
""" | |
Load recording spreadsheet from saved file for a given subject | |
""" | |
return _load_notes( | |
subject_id, | |
spreadsheet_name=f"curation_merge - {subject_id}{date}_summary.csv", | |
recording_spreadsheet_path=recording_spreadsheet_path, | |
header=[0, 1], | |
tolerate_no_notes=tolerate_no_notes, | |
) | |
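# Example usage (hypothetical subject/date; the CSV must exist under the path
# returned by get_recording_spreadsheet_path, otherwise an empty df is returned
# when tolerate_no_notes=True):
# merge_notes = load_curation_merge_notes("tonks", "20211107", tolerate_no_notes=True)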
# - | |
# ## calculate metrics for all cell pairs | |
# + | |
# Define dataset - need to re-do this spike sorting | |
subject_ids = ["tonks"] * 6  # one entry per date below
dates = [ | |
"20211107", | |
"20211108", | |
"20211109", | |
"20211110", | |
"20211111", | |
"20211112", | |
] | |
# for nwb_file_name in ['tonks20211107_.nwb','tonks20211108_.nwb','tonks20211109_.nwb', | |
# 'tonks20211110_.nwb','tonks20211111_.nwb','tonks20211112_.nwb',]: | |
# subject_ids = ["ginny","ginny","ginny","ginny","ginny","ginny","ginny", | |
# "ginny","ginny","ginny","ginny","ginny",] | |
# dates = ["20211025","20211026","20211027","20211028","20211029","20211030","20211031","20211101", | |
# "20211102","20211103","20211104","20211105",] | |
# Make curation data | |
sort_interval_name = "r2_r3" | |
preproc_params_name = "franklab_tetrode_hippocampus" | |
sort_group_ids = all_tet_list  # [0] # , 1, 2] # None  (all_tet_list is defined in the cell below; run that cell first)
# this should be the curation_id corresponding to the automatic curation
curation_id = 1 | |
overwrite_existing = False | |
verbose = True | |
# set make_data to True to generate new metrics
make_data = True | |
print("start", datetime.datetime.now()) | |
if make_data: | |
make_curation_data_wrapper( | |
subject_ids, | |
dates, | |
sort_interval_name=sort_interval_name, | |
preproc_params_name=preproc_params_name, | |
sort_group_ids=sort_group_ids, | |
curation_id=curation_id, | |
overwrite_existing=overwrite_existing, | |
verbose=verbose, | |
) | |
print("end", datetime.datetime.now()) | |
# - | |
# ## load calculated metrics and find candidates | |
# + | |
# try to combine all curation loading steps into one cell | |
# ginny | |
# Tetrode numbers (1-indexed); shifted by -1 to match sort_group_id indexing
all_tet_list = (
    np.array(
        [
            1, 2, 4, 5, 7, 8, 11, 12, 13, 14, 15, 16, 17, 20, 21, 22,
            25, 26, 27, 28, 31, 33, 34, 35, 36, 37, 39, 41, 42, 43, 44,
            45, 47, 49, 51, 52, 54, 56, 57, 59, 61, 62, 63, 64,
        ]
    )
    - 1
)
# Load curation data | |
subject_ids = ["ginny"] * 21  # one entry per date below
dates = [ | |
"20211023", | |
"20211024", | |
"20211025", | |
"20211026", | |
"20211027", | |
"20211028", | |
"20211029", | |
"20211030", | |
"20211031", | |
"20211101", | |
"20211102", | |
"20211103", | |
"20211104", | |
"20211105", | |
"20211106", | |
"20211108", | |
"20211109", | |
"20211110", | |
"20211111", | |
"20211112", | |
"20211113", | |
] | |
# Process one day at a time (rebinding `dates` inside the loop is safe because
# Python keeps iterating over the original list object)
for date in dates:
    subject_ids = ["ginny"]
    dates = [date]

    # set file name: this overrides the module-level make_param_name defined above,
    # so parameter values are joined with "_2" instead of "_1" in saved file names
    def make_param_name(param_values):
        return "_2".join([str(x) for x in param_values])
# Make curation data | |
sort_interval_name = "r2_r3" | |
preproc_params_name = "franklab_tetrode_hippocampus" | |
sort_group_ids = all_tet_list # [0] # , 1, 2] # None | |
# this should be the curation_id corresponding to the automatic curation
curation_id = 1 | |
overwrite_quantities = False | |
use_old = False | |
target_region = None  # only use if you have saved out data with a target region name, using the code commented out below
use_target_region = False | |
verbose = True | |
file_path_base = "/cumulus/mcoulter/curation_data/" # PUT PATH WHERE YOU WANT TO SAVE FILES HERE | |
cluster_data_container = dict() | |
for subject_id, date in zip(subject_ids, dates): | |
# Define directory to save data in | |
save_dir = f"{file_path_base}/{subject_id}/" + "old" * use_old | |
# Get nwb file name | |
nwb_file_name = nwbf_name_from_subject_id_date(subject_id, date) | |
cluster_data_container[nwb_file_name] = load_curation_data( | |
save_dir=save_dir, | |
nwb_file_name=nwb_file_name, | |
sort_interval_name=sort_interval_name, | |
preproc_params_name=preproc_params_name, | |
sort_group_ids=sort_group_ids, | |
target_region=target_region, | |
curation_id=curation_id, | |
overwrite_quantities=overwrite_quantities, | |
verbose=verbose, | |
) | |
# Get curation merge notes | |
curation_merge_notes_map = { | |
nwbf_name_from_subject_id_date( | |
subject_id, date | |
): get_curation_spreadsheet(subject_id, date) | |
for subject_id, date in zip(subject_ids, dates) | |
} | |
# Make a df with metrics and a merge label for all unit pairs
# Note that this is done separately for each merge type (e.g. burst pair and split cell),
# so that each unit pair falls into only one category: "true" merge pair, "false" merge pair, or unlabeled
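# For reference, each row of the resulting merge df is expected to look like
# (values below are hypothetical):
#   nwb_file_name       metric_name          label  merge_tuple  metric_value
#   tonks20211107_.nwb  cosine_similarities  True   (4, 2, 7)    0.91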
# Get metrics for these merge pairs | |
merge_types = ["burst_pair", "split_cell"] | |
merge_pair_identifiers = ["sort_group_id", "unit_id_1", "unit_id_2"] | |
labels = [True, False] | |
merge_type_metric_names_map = { | |
"burst_pair": [ | |
"cosine_similarities", | |
"correlogram_asymmetries", | |
"unit_pair_percent_isi_violations", | |
"correlogram_isi_violation_ratios", | |
"correlogram_counts", | |
"burst_pair_amplitude_timing_bools", | |
"unit_merge_valid_lower_amplitude_fractions", | |
"unit_merge_amplitude_decrements_0.015", | |
"unit_merge_amplitude_decrements_0.4", | |
"amplitude_decrement_changes_0.015", | |
"amplitude_decrement_changes_0.4", | |
], | |
"split_cell": [ | |
"cosine_similarities", | |
"amplitude_overlaps", | |
"unit_pair_percent_isi_violations", | |
"correlogram_isi_violation_ratios", | |
"correlogram_counts", | |
], | |
} | |
column_names = [ | |
"nwb_file_name", | |
"metric_name", | |
"label", | |
"merge_tuple", | |
"metric_value", | |
] | |
# Get metrics for these nwb file names | |
nwb_file_names = list(cluster_data_container.keys()) | |
# Check that nwb file names are well defined (redundant here since nwb_file_names comes from cluster_data_container, but kept as a safeguard)
if not all([x in cluster_data_container.keys() for x in nwb_file_names]): | |
raise Exception( | |
f"Can only make metrics for nwb file names in cluster_data_container, which are: {cluster_data_container.keys()}" | |
) | |
# Get units with "true" and "false" human labels | |
merge_df_map = { | |
merge_type: pd.DataFrame(columns=column_names) | |
for merge_type in merge_types | |
} # initialize | |
for merge_type_idx, merge_type in enumerate(merge_types): | |
metric_names = merge_type_metric_names_map[merge_type] | |
data_list = [] | |
for nwb_file_name in nwb_file_names: | |
cluster_data = cluster_data_container[nwb_file_name] # cluster data | |
curation_merge_notes = curation_merge_notes_map[ | |
nwb_file_name | |
] # notes with "true" and "false" human labels | |
# Continue if curation_merge_notes empty | |
if len(curation_merge_notes) == 0: | |
continue | |
# Otherwise, get metrics for units labeled as "true" or "false" merge pairs from curation_merge_notes | |
for metric_name in metric_names: | |
for label in labels: | |
# Define merge tuples as those present in curation notes | |
merge_tuples = list( | |
zip_df_columns( | |
df_filter_columns( | |
curation_merge_notes_map[nwb_file_name], | |
{"merge_type": merge_type, "label": label}, | |
), | |
merge_pair_identifiers, | |
) | |
) | |
# Remove merge tuples with sort group not in cluster data | |
merge_tuples = [ | |
x | |
for x in merge_tuples | |
if x[0] in cluster_data["sort_groups"] | |
] | |
for merge_tuple in merge_tuples: | |
sort_group_id, unit_1, unit_2 = merge_tuple | |
# Continue if metric not calculated for current sort group | |
if ( | |
metric_name | |
not in cluster_data["sort_groups"][sort_group_id] | |
): | |
continue | |
metric_value = cluster_data["sort_groups"][ | |
sort_group_id | |
][metric_name][unit_1][unit_2] | |
data_list.append( | |
( | |
nwb_file_name, | |
metric_name, | |
label, | |
merge_tuple, | |
metric_value, | |
) | |
) | |
# Only update if data, otherwise overwrites column names with nothing | |
if len(data_list) > 0: | |
merge_df_map[merge_type] = pd.DataFrame.from_dict( | |
{k: v for k, v in zip(column_names, list(zip(*data_list)))} | |
) | |
# Get unlabeled units | |
for merge_type, merge_df in merge_df_map.items(): | |
metric_names = merge_type_metric_names_map[merge_type] | |
data_list = [] | |
for nwb_file_name in nwb_file_names:
    cluster_data = cluster_data_container[
        nwb_file_name
    ]  # without this, cluster_data would be left over from the loop in the cell above
    # ...Get all unit pairs for this nwb file (limited to cluster data files that have been created)
sort_group_id_unit_pair_map = { | |
sort_group_id: list( | |
itertools.combinations(data["unit_ids"], r=2) | |
) | |
for sort_group_id, data in cluster_data["sort_groups"].items() | |
} | |
merge_tuples = [ | |
(sort_group_id, unit_1, unit_2) | |
for sort_group_id, unit_pairs in sort_group_id_unit_pair_map.items() | |
for (unit_1, unit_2) in unit_pairs | |
] | |
# ...Get labeled units for this nwb file | |
merge_df_subset = df_filter_columns( | |
merge_df, {"nwb_file_name": nwb_file_name} | |
) | |
labeled_merge_tuples = merge_df_subset["merge_tuple"] | |
merge_tuples = set(merge_tuples) - set( | |
labeled_merge_tuples | |
) # unit pairs that were not labeled as true or false merge pairs | |
for merge_tuple in merge_tuples: | |
sort_group_id, unit_1, unit_2 = merge_tuple | |
for metric_name in metric_names: | |
# Continue if metric not calculated for current sort group | |
if ( | |
metric_name | |
not in cluster_data["sort_groups"][sort_group_id] | |
): | |
continue | |
metric_value = cluster_data["sort_groups"][sort_group_id][ | |
metric_name | |
][unit_1][unit_2] | |
data_list.append( | |
( | |
nwb_file_name, | |
metric_name, | |
"none", | |
merge_tuple, | |
metric_value, | |
) | |
) | |
# Update merge df | |
other_df = pd.DataFrame.from_dict( | |
{k: v for k, v in zip(column_names, list(zip(*data_list)))} | |
) | |
merge_df_map[merge_type] = pd.concat((merge_df, other_df)) | |
# re-run with new thresholds | |
# decided to raise the amplitude-overlap threshold to 0.5 after checking a few days
Threshold = namedtuple( | |
"threshold", "metric_name threshold_value threshold_direction" | |
) | |
Thresholds = namedtuple("thresholds", "name thresholds color lw") | |
threshold_sets = [ | |
( | |
"burst_pair", | |
[ | |
("cosine_similarities", 0.7, operator.gt), | |
("correlogram_asymmetries", 0.6, operator.gt), | |
("correlogram_counts", 100, operator.gt), | |
("unit_pair_percent_isi_violations", 0.25, operator.lt), | |
("burst_pair_amplitude_timing_bools", 0, operator.gt), | |
], | |
"#2196F3", | |
9, | |
), | |
( | |
"split_cell", | |
[ | |
("cosine_similarities", 0.5, operator.gt), | |
("correlogram_counts", 100, operator.gt), | |
("unit_pair_percent_isi_violations", 0.25, operator.lt), | |
("amplitude_overlaps", 0.5, operator.gt), | |
], | |
"limegreen", | |
3, | |
), | |
# ("test", | |
# [("cosine_similarities", 0, operator.gt)], | |
# "orange", | |
# 5) | |
] | |
threshold_sets = { | |
name: Thresholds(name, [Threshold(*x) for x in thresholds], color, lw) | |
for name, thresholds, color, lw in threshold_sets | |
} | |
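# get_above_threshold_matrix_indices is defined earlier in this file. As a
# rough, hypothetical sketch of the intended logic (an assumption, not the
# actual implementation), a threshold set reduces to AND-ing per-metric
# comparisons across the pairwise metric matrices of one sort group:
def _threshold_set_sketch(metric_matrices, thresholds_obj):
    """metric_matrices: {metric_name: 2D array of unit-pair metric values}."""
    valid = None
    for thr in thresholds_obj.thresholds:
        passed = thr.threshold_direction(
            metric_matrices[thr.metric_name], thr.threshold_value
        )
        valid = passed if valid is None else np.logical_and(valid, passed)
    return valid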
# print merge candidates
merge_count = 0
cluster_data = cluster_data_container[nwb_file_name]  # nwb_file_name is the single file handled in this per-date iteration
tuple_list = [] | |
for sort_group_id, data in cluster_data["sort_groups"].items(): | |
# print('merge candidate',sort_group_id,'_',unit_1,'_vs._', | |
# sort_group_id,'_',unit_2,'_') | |
valid_bool_map = get_above_threshold_matrix_indices( | |
cluster_data, sort_group_id, threshold_sets | |
) | |
for ( | |
threshold_name, | |
valid_bool, | |
) in valid_bool_map.items(): # threshold sets | |
# Find indices in array corresponding to merge candidates | |
merge_candidate_idxs = list(zip(*np.where(valid_bool))) | |
# Convert merge candidate indices in array to unit IDs | |
merge_candidates = [ | |
tuple(np.asarray(data["unit_ids"])[np.asarray(idxs)]) | |
for idxs in merge_candidate_idxs | |
] | |
# Loop through merge candidates and plot metrics | |
for unit_1, unit_2 in merge_candidates: # units | |
# print(nwb_file_name,': merge candidate for tetrode',sort_group_id,', clusters',unit_1,'and', | |
# unit_2,'are a',threshold_name, '(',sort_group_id,unit_1,unit_2,')') | |
# print(tuple([sort_group_id,unit_1,unit_2])) | |
# create list of merge candidates | |
tuple_list.append(tuple([sort_group_id, unit_1, unit_2])) | |
merge_count += 1 | |
print("merge count", merge_count) | |
print(dates)
print(len(tuple_list)) | |
print(tuple_list) | |
# - |