Created
March 12, 2024 23:45
-
-
Save matham/22fd8088dede335135f0cd8b23f943ac to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import h5py | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import re | |
from elephant.statistics import time_histogram | |
import quantities | |
from neo import SpikeTrain | |
def arr2str(arr):
    """Decode a numeric character-code array (as stored in the .mat file) to a str.

    The array may be zero-terminated; everything from the first 0 onward is
    dropped, and surrounding whitespace is stripped.
    """
    values = np.asarray(arr).squeeze()
    zero_positions, = np.nonzero(values == 0)
    # Truncate at the first NUL terminator if one exists.
    end = zero_positions[0] if len(zero_positions) else len(values)
    return "".join(chr(c) for c in values[:end]).strip()
class MatChannel:
    """Base wrapper for one channel group of a Spike2 .mat (HDF5) export.

    Parses the channel number from the group name (``..._ChN``) and decodes
    the channel's comment and title strings.
    """

    channel_num: int = 0  # the N in the group name suffix "_ChN"
    comment: str = ""
    title: str = ""

    def __init__(self, channel: h5py.Group):
        super().__init__()
        match = re.match(r".+_Ch(\d+)$", channel.name)
        assert match is not None, "Pattern name_ChN not found"
        self.channel_num = int(match.group(1))
        self.comment = arr2str(channel["comment"])
        self.title = arr2str(channel["title"])

    def __str__(self):
        return f"<#{self.channel_num} {self.title} {self.__class__}@{id(self)}>"

    def __repr__(self):
        return self.__str__()
class MatMarker(MatChannel):
    """A marker channel: event timestamps with one integer code per event."""

    times: np.ndarray = None   # 1-D array of event times
    codes: np.ndarray = None   # marker code for each event
    length: int = 0            # declared number of events
    resolution: float = 0      # channel timing resolution

    def __init__(self, channel: h5py.Group):
        super().__init__(channel=channel)
        self.resolution = channel['resolution'][0, 0]
        self.length = int(channel['length'][0, 0])
        self.codes = np.asarray(channel["codes"][0, :])
        self.times = np.asarray(channel['times']).squeeze()
        # The declared length must match both data arrays.
        assert len(self.times) == self.length
        assert len(self.codes) == self.length
class MatTextMarker(MatMarker):
    """A marker channel whose events each carry a decoded text label."""

    items: int = 0          # characters stored per label
    text: list[str] = None  # decoded label for each event

    def __init__(self, channel: h5py.Group):
        super().__init__(channel=channel)
        self.items = int(channel['items'][0, 0])
        raw = np.asarray(channel['text'])
        # Each column of the dataset holds the character codes of one label.
        self.text = [arr2str(column) for column in raw.T]
        assert len(self.text) == self.length
class MatAnalogDataBase(MatChannel):
    """Shared fields for sampled channels (waveforms and spike trains).

    The ``times``/``values`` datasets are kept as open HDF5 handles and only
    materialized into numpy arrays (then cached) on first access.
    """

    interval: float = 0   # sampling interval
    length: int = 0       # declared number of samples/events
    offset: float = 0     # offset applied to raw values
    scale: float = 0      # scale applied to raw values
    units: str = ""       # physical units of the values
    times_raw: h5py.Dataset = None   # unread HDF5 timestamps dataset
    values_raw: h5py.Dataset = None  # unread HDF5 values dataset

    _times: np.ndarray = None  # lazy cache backing `times`

    @property
    def times(self) -> np.ndarray:
        # Materialize and cache the timestamps on first access.
        if self._times is None:
            self._times = np.asarray(self.times_raw).squeeze()
        return self._times

    _values: np.ndarray = None  # lazy cache backing `values`

    @property
    def values(self) -> np.ndarray:
        # Materialize and cache the values on first access.
        if self._values is None:
            self._values = np.asarray(self.values_raw)
        return self._values

    def __init__(self, channel: h5py.Group):
        super().__init__(channel=channel)
        self.interval = channel['interval'][0, 0]
        self.length = int(channel['length'][0, 0])
        self.offset = channel['offset'][0, 0]
        self.scale = channel['scale'][0, 0]
        self.units = arr2str(channel["units"])
        # Deliberately left as datasets; the properties above load on demand.
        self.times_raw = channel['times']
        self.values_raw = channel["values"]
        assert self.times_raw.shape[1] == self.length
        assert self.values_raw.shape[1] == self.length
class MatAnalogData(MatAnalogDataBase):
    """A continuously sampled (waveform) channel."""

    start: float = 0  # time of the first sample

    def __init__(self, channel: h5py.Group):
        super().__init__(channel=channel)
        self.start = channel['start'][0, 0]
class MatSpikeTrain(MatAnalogDataBase):
    """A spike (wavemark) channel: spike times plus a per-spike cell code."""

    items: int = 0        # samples per spike waveform
    resolution: float = 0
    traces: int = 0       # interleaved traces per spike
    trigger: float = 0
    codes: np.ndarray = None  # sorted-cell code of each spike

    def __init__(self, channel: h5py.Group):
        super().__init__(channel=channel)
        self.codes = np.asarray(channel["codes"][0, :])
        self.items = int(channel['items'][0, 0])
        self.resolution = channel['resolution'][0, 0]
        self.traces = int(channel['traces'][0, 0])
        self.trigger = channel['trigger'][0, 0]
        assert len(self.codes) == self.length
        assert self.values_raw.shape[0] == self.items

    def get_times_of_cell(self, cell_or_cells: list[int] | int) -> np.ndarray:
        """Return spike times for one cell code, or the union over several."""
        codes = self.codes
        try:
            mask = np.zeros(len(codes), dtype=np.bool_)
            for cell in cell_or_cells:
                mask |= codes == cell
        except TypeError:
            # A single non-iterable cell code was given.
            mask = codes == cell_or_cells
        return self.times[mask]

    def get_cell_ids(self) -> set[int]:
        """All distinct cell codes present on this channel."""
        return set(self.codes)
class Spike2Mat:
    """Parses a Spike2 recording exported as .mat (HDF5) and sorts its channels
    into waveform data, spike trains, markers, and text markers.
    """

    analog_data: list[MatAnalogData] = None
    trains: list[MatSpikeTrain] = None
    markers: list[MatMarker] = None
    text_markers: list[MatTextMarker] = None

    def __init__(self):
        self.analog_data = []
        self.trains = []
        self.markers = []
        self.text_markers = []

    def parse_file(self, filename):
        """Open the .mat file and wrap every channel group in the right class.

        NOTE: the file handle is intentionally never closed here — analog and
        spike-train channels keep live h5py datasets and read them lazily.
        """
        f = h5py.File(filename, 'r')
        for chan_name in f.keys():
            if chan_name == "file":
                continue
            channel = f[chan_name]
            # Classify by which datasets the group contains: sampled channels
            # have "interval"; of those, spike trains also have "traces".
            if "interval" in channel:
                if "traces" in channel:
                    self.trains.append(MatSpikeTrain(channel=channel))
                else:
                    self.analog_data.append(MatAnalogData(channel=channel))
            else:
                # Event channels: text markers have a "text" dataset.
                if "text" in channel:
                    self.text_markers.append(MatTextMarker(channel=channel))
                else:
                    self.markers.append(MatMarker(channel=channel))

    def get_train_cell_times(
            self, channel_and_cells: dict[int, int | None | list[int | list[int]]]
    ) -> list[tuple[MatSpikeTrain, int | list[int], np.ndarray]]:
        """Collect spike times per (channel, cell) selection.

        ``channel_and_cells`` maps channel number -> a single cell code, None
        (meaning every cell on that channel), or a list whose items are cell
        codes or lists of codes to be merged into one train.

        Returns a list of (train, cell-or-cells, times) tuples.
        """
        trains = []
        for train in self.trains:
            num = train.channel_num
            if num not in channel_and_cells:
                continue
            cells = channel_and_cells[num]
            # Normalize: a bare int selects one cell; None selects all cells.
            if isinstance(cells, int):
                cells = [cells]
            elif cells is None:
                cells = train.get_cell_ids()
            for cell_or_cells in cells:
                trains.append((train, cell_or_cells, train.get_times_of_cell(cell_or_cells)))
        return trains

    def get_event_interval_times(
            self, pre_duration, post_duration, marker_channel_num=32, label_channel_num=30, start_code=79,
            end_code=111
    ) -> list[tuple[str, float, float]]:
        """Pair trial start/end marker times with their text labels.

        Trials start at markers with ``start_code`` and end at ``end_code`` on
        the marker channel; labels come from the text-marker channel. Each
        interval is widened by ``pre_duration``/``post_duration`` seconds.

        Returns a list of (label, start_time, end_time) tuples.

        Raises ValueError if either channel is missing, the start/end counts
        differ, or the label count doesn't match the trial count.
        """
        mark_channel = None
        label_channel = None
        for c in self.markers:
            if c.channel_num == marker_channel_num:
                mark_channel = c
                break
        for c in self.text_markers:
            if c.channel_num == label_channel_num:
                label_channel = c
                break
        if mark_channel is None:
            raise ValueError(f"Didn't find channel {marker_channel_num}")
        if label_channel is None:
            raise ValueError(f"Didn't find channel {label_channel_num}")
        start = mark_channel.codes == start_code
        end = mark_channel.codes == end_code
        if sum(start) != sum(end):
            raise ValueError(f"Number of trial start/stop codes don't match ({sum(start)}, {sum(end)})")
        if sum(start) != label_channel.length:
            # NOTE(review): the message reports mark_channel.length // 4 while
            # the comparison uses sum(start) — verify these are meant to agree.
            raise ValueError(
                f"Time and label marker channels are different lengths ({mark_channel.length // 4}, "
                f"{label_channel.length})")
        start_ts = mark_channel.times[start] - pre_duration
        end_ts = mark_channel.times[end] + post_duration
        return list(zip(label_channel.text, start_ts, end_ts))

    def dump_spike_bins_csv(
            self, filename: str, channel_and_cells: dict[int, int | None | list[int | list[int]]], pre_duration: float,
            post_duration: float, bin_duration: float, marker_channel_num=32, label_channel_num=30, start_code=79,
            end_code=111
    ):
        """Write per-trial binned firing rates to ``filename`` as CSV.

        For every selected (channel, cell) train and every labeled trial
        interval, computes a rate histogram with ``bin_duration``-second bins
        (via elephant) and writes one CSV row per (trial, train) pair. Rows
        shorter than the longest histogram are zero-padded.
        """
        trains = self.get_train_cell_times(channel_and_cells)
        trials = self.get_event_interval_times(
            pre_duration, post_duration, marker_channel_num, label_channel_num, start_code, end_code)
        # The neo SpikeTrain must span every spike and every trial interval.
        ts_min = min(
            min(times[0] for _, _, times in trains),
            min(ts for _, ts, _ in trials)
        ) * quantities.s
        ts_max = max(
            max(times[-1] for _, _, times in trains),
            max(te for _, _, te in trials)
        ) * quantities.s
        results = []
        for train, cell_or_cells, times in trains:
            neo_train = SpikeTrain(times=times, t_stop=ts_max, units=quantities.s, t_start=ts_min)
            for name, ts, te in trials:
                hist = time_histogram(
                    [neo_train], bin_size=bin_duration * quantities.s, t_start=ts * quantities.s,
                    t_stop=te * quantities.s, output='rate'
                )
                results.append((name, ts, te, train, cell_or_cells, hist))
        # Use the longest histogram to define the CSV's bin-time header,
        # re-zeroed so the header starts at 0.
        largest_bins = max([item[-1] for item in results], key=lambda x: x.shape[0])
        bin_times = largest_bins.times.rescale(quantities.s).magnitude
        bin_times -= bin_times[0]
        with open(filename, "w") as fh:
            fh.write("Odor,Start time,End time,Channel,Cell,title,")
            fh.write(",".join(map(str, bin_times)))
            fh.write("\n")
            for name, ts, te, train, cell_or_cells, times in results:
                # A merged-cell selection becomes e.g. "2-3"; a single code
                # (non-iterable int) is written as-is.
                try:
                    cell = "-".join(map(str, cell_or_cells))
                except TypeError:
                    cell = f"{cell_or_cells}"
                items = [name, ts, te, train.channel_num, cell, train.title]
                fh.write(",".join(map(str, items)))
                fh.write(",")
                data = times.as_array(1 / quantities.s).squeeze()
                # Zero-pad shorter histograms so every row has equal columns.
                data = np.concatenate([np.array(data), np.zeros(len(bin_times) - len(data))])
                fh.write(",".join(map(str, data)))
                fh.write("\n")
if __name__ == "__main__":
    # Input .mat export and the CSV that dump_spike_bins_csv would produce.
    filename = r"C:\Users\Matthew Einhorn\Downloads\r34.mat"
    output_csv = r"C:\Users\Matthew Einhorn\Downloads\r34.csv"
    mat = Spike2Mat()
    mat.parse_file(filename)
    # List the available spike-train channels and their cell codes so the
    # user can pick what to pass to dump_spike_bins_csv below.
    for train in sorted(mat.trains, key=lambda t: t.channel_num):
        print(f"Channel={train.channel_num}, name=\"{train.title}\", cells={train.get_cell_ids()}")
    # Example invocation (uncomment and adjust the channel/cell selection):
    # mat.dump_spike_bins_csv(
    #     filename=output_csv,
    #     channel_and_cells={
    #         2: None,
    #         8: None,
    #         10: None,
    #         11: None,
    #         12: None,
    #         13: None,
    #         14: None,
    #         15: None,
    #         16: None,
    #         6001: None,
    #     },
    #     pre_duration=4, post_duration=8, bin_duration=.02
    # )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment