Skip to content

Instantly share code, notes, and snippets.

@matham
Created March 12, 2024 23:45
Show Gist options
  • Save matham/22fd8088dede335135f0cd8b23f943ac to your computer and use it in GitHub Desktop.
Save matham/22fd8088dede335135f0cd8b23f943ac to your computer and use it in GitHub Desktop.
import csv
import re

import h5py
import matplotlib.pyplot as plt
import numpy as np
import quantities
from elephant.statistics import time_histogram
from neo import SpikeTrain
def arr2str(arr):
    """Decode an array of character codes (e.g. a MATLAB char dataset) to a str.

    The array is flattened and each element is converted with ``chr``.
    Decoding stops at the first ``0`` code (terminator / padding), and leading
    and trailing whitespace is stripped from the result.

    :param arr: Array-like of integer character codes, such as an
        ``h5py.Dataset`` of shape ``(n, 1)``.
    :return: The decoded string ("" if the array is empty or starts with 0).
    """
    # ravel instead of squeeze: squeeze reduces a single-character (1, 1)
    # array to a 0-d array, on which len() raises TypeError.
    codes = np.ravel(np.asarray(arr))
    zeros, = np.nonzero(codes == 0)
    end = zeros[0] if len(zeros) else len(codes)
    return "".join(map(chr, codes[:end])).strip()
class MatChannel:
    """Base class for one channel group read from an HDF5-based ``.mat`` file.

    Parses the channel number from the group name, which must end in
    ``_Ch<N>``, and decodes the ``comment`` and ``title`` char datasets.
    """

    channel_num: int = 0  # N parsed from the trailing "_Ch<N>" of the group name
    comment: str = ""  # decoded ``comment`` dataset
    title: str = ""  # decoded ``title`` dataset

    def __init__(self, channel: h5py.Group):
        """Read the metadata shared by every channel type.

        :param channel: The h5py group holding the channel's datasets.
        :raises ValueError: If the group name does not end in ``_Ch<N>``.
        """
        super().__init__()
        # Explicit check rather than ``assert``: asserts are stripped under -O,
        # which would turn a bad name into an opaque AttributeError below.
        m = re.match(r".+_Ch(\d+)$", channel.name)
        if m is None:
            raise ValueError(f"Pattern name_ChN not found in {channel.name!r}")
        self.channel_num = int(m.group(1))
        self.comment = arr2str(channel["comment"])
        self.title = arr2str(channel["title"])

    def __str__(self):
        return f"<#{self.channel_num} {self.title} {self.__class__}@{id(self)}>"

    def __repr__(self):
        return self.__str__()
class MatMarker(MatChannel):
    """A marker channel: event times plus the marker code of each event."""

    times: np.ndarray = None  # 1-D array of event times
    codes: np.ndarray = None  # first row of the ``codes`` dataset, one per event
    length: int = 0  # event count from the file's ``length`` field
    resolution: float = 0  # the file's ``resolution`` field

    def __init__(self, channel: h5py.Group):
        """Load the marker times and codes and check them against ``length``.

        :param channel: The h5py group holding the channel's datasets.
        :raises AssertionError: If the number of times or codes differs from
            ``length``.
        """
        super().__init__(channel=channel)
        # Only the first row of the ``codes`` dataset is kept.
        self.codes = np.asarray(channel["codes"][0, :])
        # atleast_1d: a single-event (1, 1) dataset would otherwise squeeze to a
        # 0-d array, on which len() and boolean masking fail.
        self.times = np.atleast_1d(np.asarray(channel['times']).squeeze())
        self.resolution = channel['resolution'][0, 0]
        self.length = int(channel['length'][0, 0])
        assert len(self.times) == self.length, (
            f"{channel.name}: {len(self.times)} times, expected {self.length}")
        assert len(self.codes) == self.length, (
            f"{channel.name}: {len(self.codes)} codes, expected {self.length}")
class MatTextMarker(MatMarker):
    """A marker channel whose events additionally carry one text label each."""

    items: int = 0  # the file's ``items`` field
    text: list[str] = None  # decoded label of each event, in event order

    def __init__(self, channel: h5py.Group):
        """Load the marker data, then decode one text label per event.

        :param channel: The h5py group holding the channel's datasets.
        :raises AssertionError: If the number of labels differs from ``length``.
        """
        super().__init__(channel=channel)
        self.items = int(channel['items'][0, 0])
        label_matrix = np.asarray(channel['text'])
        # Column j of the ``text`` dataset holds the character codes of label j;
        # swapping the first two axes lets us iterate those columns directly.
        self.text = list(map(arr2str, label_matrix.swapaxes(0, 1)))
        assert len(self.text) == self.length
class MatAnalogDataBase(MatChannel):
    """Common loader for channels that have an ``interval`` dataset.

    The ``times`` and ``values`` datasets are kept as h5py datasets
    (``times_raw`` / ``values_raw``). They are only read into NumPy arrays, and
    then cached, the first time the :attr:`times` / :attr:`values` properties
    are accessed, so the HDF5 file must still be open at that point.
    """

    interval: float = 0
    length: int = 0  # must equal the second dimension of ``times`` and ``values``
    offset: float = 0
    scale: float = 0
    units: str = ""
    times_raw: h5py.Dataset = None
    values_raw: h5py.Dataset = None

    _times: np.ndarray = None  # cache for the ``times`` property

    @property
    def times(self) -> np.ndarray:
        """The ``times`` dataset as a squeezed, at-least-1-D array (read on first use)."""
        times = self._times
        if times is None:
            # atleast_1d: a (1, 1) dataset (a single sample/spike) would squeeze
            # to a 0-d array, which cannot be len()'d, indexed or masked.
            self._times = np.atleast_1d(np.asarray(self.times_raw).squeeze())
            times = self._times
        return times

    _values: np.ndarray = None  # cache for the ``values`` property

    @property
    def values(self) -> np.ndarray:
        """The ``values`` dataset as an array of its stored shape (read on first use)."""
        values = self._values
        if values is None:
            self._values = np.asarray(self.values_raw)
            values = self._values
        return values

    def __init__(self, channel: h5py.Group):
        """Read the interval/scale/offset/units metadata and keep the raw datasets.

        :param channel: The h5py group holding the channel's datasets.
        :raises AssertionError: If ``times`` or ``values`` does not have
            ``length`` columns.
        """
        super().__init__(channel=channel)
        self.interval = channel['interval'][0, 0]
        self.length = int(channel['length'][0, 0])
        self.offset = channel['offset'][0, 0]
        self.scale = channel['scale'][0, 0]
        self.units = arr2str(channel["units"])
        self.times_raw = channel['times']
        self.values_raw = channel["values"]
        assert self.times_raw.shape[1] == self.length
        assert self.values_raw.shape[1] == self.length
class MatAnalogData(MatAnalogDataBase):
    """Analog-data channel: the shared interval-channel data plus its ``start`` value."""

    start: float = 0  # scalar read from the file's ``start`` dataset

    def __init__(self, channel: h5py.Group):
        """Load the shared interval-channel data, then the ``start`` value."""
        super().__init__(channel=channel)
        start_dataset = channel['start']
        self.start = start_dataset[0, 0]
class MatSpikeTrain(MatAnalogDataBase):
    """A spike-train channel: per-spike times and cell codes, plus stored ``values``.

    ``values`` has ``items`` rows and one column per spike (``length``).
    """

    items: int = 0  # number of rows of the ``values`` dataset
    resolution: float = 0
    traces: int = 0
    trigger: float = 0
    codes: np.ndarray = None  # cell code of each spike (first row of ``codes``)

    def __init__(self, channel: h5py.Group):
        """Load the spike-train metadata and per-spike cell codes.

        :param channel: The h5py group holding the channel's datasets.
        :raises AssertionError: If the codes or ``values`` rows disagree with
            ``length`` / ``items``.
        """
        super().__init__(channel=channel)
        self.codes = np.asarray(channel["codes"][0, :])
        self.items = int(channel['items'][0, 0])
        self.resolution = channel['resolution'][0, 0]
        self.traces = int(channel['traces'][0, 0])
        self.trigger = channel['trigger'][0, 0]
        assert len(self.codes) == self.length
        assert self.values_raw.shape[0] == self.items

    def get_times_of_cell(self, cell_or_cells: list[int] | int) -> np.ndarray:
        """Return the times of the spikes whose code matches *cell_or_cells*.

        :param cell_or_cells: A single cell code, or an iterable of codes whose
            spikes are merged (kept in their original time order).
        :return: The selected spike times.
        """
        codes = self.codes
        try:
            wanted = list(cell_or_cells)
        except TypeError:
            # A scalar code (Python or NumPy integer) is not iterable.
            return self.times[codes == cell_or_cells]
        return self.times[np.isin(codes, wanted)]

    def get_cell_ids(self) -> set[int]:
        """Return the distinct cell codes in this channel as native Python numbers."""
        # tolist() converts NumPy scalars to Python numbers, matching the
        # annotation and giving clean reprs (NumPy 2 prints e.g. ``np.uint8(1)``).
        return set(self.codes.tolist())
class Spike2Mat:
    """All channels of a recording read from an HDF5-based ``.mat`` file.

    Call :meth:`parse_file` to populate the channel lists. Analog and spike
    channels read their ``times``/``values`` lazily from the open file, so call
    :meth:`close` only once that data is no longer needed.
    """

    analog_data: list[MatAnalogData] = None
    trains: list[MatSpikeTrain] = None
    markers: list[MatMarker] = None
    text_markers: list[MatTextMarker] = None
    # Files opened by parse_file; kept open because channel data is read lazily.
    h5_files: list[h5py.File] = None

    def __init__(self):
        self.analog_data = []
        self.trains = []
        self.markers = []
        self.text_markers = []
        self.h5_files = []

    def parse_file(self, filename):
        """Open *filename* and append each top-level channel group to the matching list.

        Groups are classified by the datasets they contain:
        ``interval`` and ``traces`` -> :class:`MatSpikeTrain`; ``interval`` only
        -> :class:`MatAnalogData`; ``text`` -> :class:`MatTextMarker`;
        anything else -> :class:`MatMarker`. The top-level ``file`` entry is skipped.

        The file is left open (release it with :meth:`close`) because analog and
        spike channels keep references to its datasets and read them on first use.
        """
        f = h5py.File(filename, 'r')
        # Keep the handle so it can actually be closed later.
        self.h5_files.append(f)
        for chan_name in f.keys():
            if chan_name == "file":
                continue
            channel = f[chan_name]
            if "interval" in channel:
                if "traces" in channel:
                    self.trains.append(MatSpikeTrain(channel=channel))
                else:
                    self.analog_data.append(MatAnalogData(channel=channel))
            elif "text" in channel:
                self.text_markers.append(MatTextMarker(channel=channel))
            else:
                self.markers.append(MatMarker(channel=channel))

    def close(self):
        """Close every file opened by :meth:`parse_file`.

        Lazily loaded channel data that has not been accessed yet can no longer
        be read afterwards.
        """
        for f in self.h5_files:
            f.close()
        self.h5_files.clear()

    def get_train_cell_times(
        self, channel_and_cells: dict[int, int | None | list[int | list[int]]]
    ) -> list[tuple[MatSpikeTrain, int | list[int], np.ndarray]]:
        """Select spike times per (channel, cell) from the parsed spike trains.

        :param channel_and_cells: Maps a spike-channel number to the cells wanted.
            The value can be:

            - a single cell code;
            - ``None`` for every cell present in the channel;
            - a list whose items are cell codes or lists of codes. A nested list
              selects the merged spikes of those cells as one entry.

            Channels missing from the mapping are skipped.
        :return: ``(train, cell_or_cells, spike_times)`` tuples, ordered by
            :attr:`trains` and then by the requested cells.
        """
        trains = []
        for train in self.trains:
            num = train.channel_num
            if num not in channel_and_cells:
                continue
            cells = channel_and_cells[num]
            if isinstance(cells, int):
                cells = [cells]
            elif cells is None:
                # get_cell_ids() returns a set; sort it so the output order
                # (and hence the CSV row order) is deterministic.
                cells = sorted(train.get_cell_ids())
            for cell_or_cells in cells:
                trains.append((train, cell_or_cells, train.get_times_of_cell(cell_or_cells)))
        return trains

    def get_event_interval_times(
        self, pre_duration, post_duration, marker_channel_num=32, label_channel_num=30, start_code=79,
        end_code=111
    ) -> list[tuple[str, float, float]]:
        """Return ``(label, window_start, window_end)`` for every trial.

        Trial starts and ends are the events of marker channel
        *marker_channel_num* whose code equals *start_code* / *end_code*. They
        are paired in order, and the i-th trial takes the i-th text of
        text-marker channel *label_channel_num*. Each window starts
        *pre_duration* before its start event and ends *post_duration* after
        its end event.

        :raises ValueError: If either channel is missing, the start/end/label
            counts differ, or a trial's end event precedes its start event.
        """
        mark_channel = next(
            (c for c in self.markers if c.channel_num == marker_channel_num), None)
        label_channel = next(
            (c for c in self.text_markers if c.channel_num == label_channel_num), None)
        if mark_channel is None:
            raise ValueError(f"Didn't find channel {marker_channel_num}")
        if label_channel is None:
            raise ValueError(f"Didn't find channel {label_channel_num}")

        start = mark_channel.codes == start_code
        end = mark_channel.codes == end_code
        n_start = int(np.count_nonzero(start))
        n_end = int(np.count_nonzero(end))
        if n_start != n_end:
            raise ValueError(f"Number of trial start/stop codes don't match ({n_start}, {n_end})")
        if n_start != label_channel.length:
            # Report the start-code count that was actually compared.
            raise ValueError(
                f"Time and label marker channels are different lengths ({n_start}, "
                f"{label_channel.length})")

        start_times = mark_channel.times[start]
        end_times = mark_channel.times[end]
        if np.any(end_times < start_times):
            raise ValueError("A trial end code occurs before its paired start code")
        return list(zip(label_channel.text, start_times - pre_duration, end_times + post_duration))

    @staticmethod
    def _format_cell(cell_or_cells) -> str:
        """Format one cell code, or several codes joined with ``-``, for the CSV."""
        try:
            return "-".join(map(str, cell_or_cells))
        except TypeError:
            return f"{cell_or_cells}"

    def dump_spike_bins_csv(
        self, filename: str, channel_and_cells: dict[int, int | None | list[int | list[int]]], pre_duration: float,
        post_duration: float, bin_duration: float, marker_channel_num=32, label_channel_num=30, start_code=79,
        end_code=111
    ):
        """Bin the selected spike trains over every trial window and write the rates to CSV.

        For each selection from :meth:`get_train_cell_times` and each trial from
        :meth:`get_event_interval_times`, spikes in the trial window are binned
        into ``bin_duration``-second bins with ``time_histogram(output='rate')``.

        Each CSV row is one (train/cell, trial) pair, with trains as the outer
        loop. A row holds the trial label, window start and end, channel number,
        cell(s) and channel title, followed by one rate per bin. Rows with fewer
        bins than the longest trial are zero-padded on the right. The header's
        bin columns are the longest trial's bin times, relative to its window start.

        :raises ValueError: If no spike train or no trial is selected (also
            propagated from :meth:`get_event_interval_times`).
        """
        trains = self.get_train_cell_times(channel_and_cells)
        trials = self.get_event_interval_times(
            pre_duration, post_duration, marker_channel_num, label_channel_num, start_code, end_code)
        if not trains:
            raise ValueError("No spike trains selected by channel_and_cells")
        if not trials:
            raise ValueError("No trials found in the marker channels")

        # neo requires every spike to lie within [t_start, t_stop]. The trial
        # windows are included as well, presumably so each histogram's
        # t_start/t_stop fall inside the train's range. One common range is used.
        range_starts = [ts for _, ts, _ in trials]
        range_ends = [te for _, _, te in trials]
        for _, _, times in trains:
            if len(times):  # a selected cell may have no spikes at all
                range_starts.append(np.min(times))
                range_ends.append(np.max(times))
        ts_min = min(range_starts) * quantities.s
        ts_max = max(range_ends) * quantities.s

        results = []
        for train, cell_or_cells, times in trains:
            neo_train = SpikeTrain(times=times, t_stop=ts_max, units=quantities.s, t_start=ts_min)
            for name, ts, te in trials:
                hist = time_histogram(
                    [neo_train], bin_size=bin_duration * quantities.s, t_start=ts * quantities.s,
                    t_stop=te * quantities.s, output='rate'
                )
                results.append((name, ts, te, train, cell_or_cells, hist))

        largest_bins = max((hist for *_, hist in results), key=lambda h: h.shape[0])
        bin_times = largest_bins.times.rescale(quantities.s).magnitude
        bin_times = bin_times - bin_times[0]
        n_bins = len(bin_times)

        # csv quotes labels/titles containing commas or quotes. Values are str()-ed
        # first so NumPy scalars are written as plain numbers, as before.
        with open(filename, "w", newline="", encoding="utf-8") as fh:
            writer = csv.writer(fh, lineterminator="\n")
            writer.writerow(
                ["Odor", "Start time", "End time", "Channel", "Cell", "title"]
                + [str(t) for t in bin_times])
            for name, ts, te, train, cell_or_cells, hist in results:
                # ravel, not squeeze: a one-bin histogram would squeeze to 0-d.
                rates = np.ravel(hist.as_array(1 / quantities.s))
                rates = np.concatenate([rates, np.zeros(n_bins - len(rates))])
                row = [name, ts, te, train.channel_num, self._format_cell(cell_or_cells), train.title]
                writer.writerow([str(v) for v in row] + [str(r) for r in rates])
if __name__ == "__main__":
    # Usage: python <script> [input.mat [output.csv]]
    # Without arguments, the original example paths below are used.
    import sys

    default_input = r"C:\Users\Matthew Einhorn\Downloads\r34.mat"
    default_output = r"C:\Users\Matthew Einhorn\Downloads\r34.csv"
    filename = sys.argv[1] if len(sys.argv) > 1 else default_input
    output_csv = sys.argv[2] if len(sys.argv) > 2 else default_output

    mat = Spike2Mat()
    mat.parse_file(filename)
    # List each spike channel and its cell codes, to help build the
    # channel_and_cells mapping for dump_spike_bins_csv below.
    for train in sorted(mat.trains, key=lambda t: t.channel_num):
        print(f"Channel={train.channel_num}, name=\"{train.title}\", cells={train.get_cell_ids()}")
    # mat.dump_spike_bins_csv(
    #     filename=output_csv,
    #     channel_and_cells={
    #         2: None,
    #         8: None,
    #         10: None,
    #         11: None,
    #         12: None,
    #         13: None,
    #         14: None,
    #         15: None,
    #         16: None,
    #         6001: None,
    #     },
    #     pre_duration=4, post_duration=8, bin_duration=.02
    # )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment