Last active
June 2, 2025 22:06
-
-
Save matham/2a499bbba251117287857da0aa6c5aeb to your computer and use it in GitHub Desktop.
Export results for teaball experiments: sniffing, occupancy, etc.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass, field | |
import pprint | |
from typing import Callable, Union, Literal, Any | |
from pathlib import Path | |
import csv | |
from scipy.signal import decimate | |
from copy import deepcopy | |
import math | |
import tqdm | |
import numpy as np | |
from collections import defaultdict | |
from functools import partial | |
import matplotlib.pyplot as plt | |
from mpl_toolkits.axes_grid1 import make_axes_locatable | |
from scipy.ndimage import gaussian_filter | |
import json | |
# Measures accepted by Experiment.convert_to_categoricals: side-of-box occupancy,
# or freezing detected from the motion index or from the speed.
MEASURE_TYPE = Literal["occupancy", "motion_index_freezing", "speed_freezing"]
def save_or_show(save_fig_root: None | Path = None, save_fig_prefix: str = "", width_inch: int = 8, height_inch: int = 6):
    """Save the current matplotlib figure to disk, or show it interactively.

    :param save_fig_root: Output directory (created if missing). When falsy,
        the figure is laid out and displayed with ``plt.show()`` instead.
    :param save_fig_prefix: File stem; the image is written to
        ``save_fig_root / f"{save_fig_prefix}.png"`` at 300 dpi.
    :param width_inch: Figure width applied before saving.
    :param height_inch: Figure height applied before saving.
    """
    if not save_fig_root:
        plt.tight_layout()
        plt.show()
        return

    save_fig_root.mkdir(parents=True, exist_ok=True)
    current = plt.gcf()
    current.set_size_inches(width_inch, height_inch)
    current.tight_layout()
    current.savefig(save_fig_root / f"{save_fig_prefix}.png", bbox_inches='tight', dpi=300)
    # saved figures are closed so they do not accumulate in memory
    plt.close()
@dataclass
class Experiment:
    """One recorded session: period times, metadata, position data and box geometry.

    The ``pre_*``, ``trial_*`` and ``post_*`` fields bound the three periods of the
    session (named by ``triplet_name``). ``parse_experiment_spec_csv`` fills them
    with whole seconds (an ``m:s`` cell becomes ``m * 60 + s``). They are passed
    to ``SubjectPos.extract_range`` (presumably in the track's time units).
    """
    pre_start: float
    pre_end: float
    trial_start: float
    trial_end: float
    post_start: float
    post_end: float
    # str.format template, filled from ``metadata``, naming the position-track file
    filename_fmt: str
    # prefix for plot titles; parse_experiment_spec_csv stores it already formatted
    title_root: str
    # identifying columns (name -> value); keys/values become the id columns of exports
    metadata: dict[str, Any] = field(default_factory=dict)
    # display names of the pre / trial / post periods (used as panel titles)
    triplet_name: tuple[str, str, str] = "Habituation", "Trial", "Post Trial"

    # full position track and its per-period slices (set by read_pos_track)
    pos_data: "SubjectPos | None" = field(default=None, init=False, repr=False)
    pre_pos_data: "SubjectPos | None" = field(default=None, init=False, repr=False)
    trial_pos_data: "SubjectPos | None" = field(default=None, init=False, repr=False)
    post_pos_data: "SubjectPos | None" = field(default=None, init=False, repr=False)
    # sub-segments of the periods keyed by stage name (set by chop_pos_track)
    chopped_pos_data: dict[str, "SubjectPos"] = field(default_factory=dict, init=False, repr=False)

    # box geometry in image pixels (set by parse_box_metadata / set_box_metadata)
    box_center: tuple[int, int] | None = field(default=None, init=False)
    box_size: tuple[int, int] | None = field(default=None, init=False)
    image_size: tuple[int, int] | None = field(default=None, init=False)
    # how much we need to offset our box center so our box center aligns with the mean box center of all boxes
    image_offset: tuple[int, int] | None = field(default=None, init=False)
    # size of the largest image/canvas we need to fit all experiments, including their offsets
    enlarged_image_size: tuple[int, int] | None = field(default=None, init=False, repr=False)
    # size of the largest box of all the experiments
    largest_box_size: tuple[int, int] | None = field(default=None, init=False, repr=False)
def get_header_id_columns(self) -> list[str]:
    """Return the metadata keys, used as the identifying columns of exported CSV headers."""
    return [name for name in self.metadata]
def get_id_columns(self) -> list:
    """Return the metadata values, in the same order as ``get_header_id_columns``."""
    return [*self.metadata.values()]
def pos_path_filename(self, data_root: Path) -> Path:
    """Return the position-track file for this experiment under ``data_root``.

    The file name is ``filename_fmt`` formatted with the metadata fields.
    """
    name = self.filename_fmt.format(**self.metadata)
    return data_root / name
def box_metadata_filename(
    self, data_root: Path, box_filename_fmt: str = "{date}_{subject}_top_000000000.json",
) -> Path:
    """Return the path of this experiment's box-annotation JSON.

    :param data_root: Directory that holds the JSON files.
    :param box_filename_fmt: ``str.format`` template filled from ``metadata``.
        The default is the naming scheme this method always used, so existing
        callers are unaffected; pass another template for other layouts.
    """
    return data_root / box_filename_fmt.format(**self.metadata)
@classmethod
def parse_experiment_spec_csv(
    cls, filename: Path, metadata: list[str], filename_fmt: str,
    title_root_fmt: str = "#{subject} ({date})",
) -> list["Experiment"]:
    """Create one experiment per data row of an experiment-spec CSV.

    The first row is a header and is skipped; blank rows are ignored. Each
    remaining row holds the metadata values (one column per name in
    ``metadata``, in that order) followed by six times: pre start/end, trial
    start/end and post start/end. A time is an integer number of seconds or a
    colon-separated ``m:s`` (``h:m:s`` is accepted too).

    :param filename: CSV file to read.
    :param metadata: Names of the leading metadata columns.
    :param filename_fmt: Position-track file template stored on each experiment.
    :param title_root_fmt: Template formatted with the row's metadata to build
        each experiment's ``title_root``.
    :returns: The experiments, in file order.
    """
    experiments = []
    # newline="" is how the csv module expects its input files to be opened
    with open(filename, "r", newline="") as fh:
        reader = csv.reader(fh)
        next(reader, None)  # skip the header row
        for row in reader:
            if not any(cell.strip() for cell in row):
                continue
            metadata_values = {m: row[i] for i, m in enumerate(metadata)}
            times = []
            for t in row[len(metadata):]:
                # "h:m:s" / "m:s" -> seconds; a plain integer is already seconds
                seconds = 0
                for part in t.split(":"):
                    seconds = seconds * 60 + int(part)
                times.append(seconds)
            experiment = cls(
                metadata=metadata_values, filename_fmt=filename_fmt,
                pre_start=times[0], pre_end=times[1],
                trial_start=times[2], trial_end=times[3],
                post_start=times[4], post_end=times[5], title_root=title_root_fmt.format(**metadata_values),
            )
            experiments.append(experiment)
    return experiments
def parse_box_metadata(self, data_root: Path, box_names: tuple[str, ...] = ("farside", "nearside")):
    """Load the box rectangle and image size from this experiment's box JSON.

    The JSON (path from ``box_metadata_filename``) must contain a ``shapes`` list
    whose items have ``label`` and ``points``, plus ``imageWidth`` and
    ``imageHeight``. The box is the axis-aligned bounding box of the points of
    the shape labelled with one of ``box_names``. When several shapes match,
    the one listed last wins.

    Like ``set_box_metadata``, this also resets ``image_offset``,
    ``enlarged_image_size`` and ``largest_box_size`` to single-experiment values.
    ``position_to_grid_index`` therefore works without ``enlarge_canvas``;
    otherwise those fields would stay ``None``.

    :raises ValueError: if no shape carries any of ``box_names``.
    """
    filename = self.box_metadata_filename(data_root)
    # JSON is UTF-8 by specification; don't depend on the platform locale encoding
    with open(filename, "r", encoding="utf-8") as fh:
        data = json.load(fh)

    shape = None
    for s in data["shapes"]:
        if s["label"] in box_names:
            shape = s  # no break: the last matching shape wins
    if shape is None:
        raise ValueError(f"Cannot find {box_names} in the json file. {filename}")

    points = np.array(shape["points"])
    min_x, min_y = np.min(points, axis=0)
    max_x, max_y = np.max(points, axis=0)
    w = max_x - min_x
    h = max_y - min_y
    self.box_center = int(min_x + w / 2), int(min_y + h / 2)
    self.box_size = int(w), int(h)
    self.image_size = int(data["imageWidth"]), int(data["imageHeight"])
    # identity alignment until enlarge_canvas aligns several experiments
    self.image_offset = 0, 0
    self.enlarged_image_size = self.image_size
    self.largest_box_size = self.box_size
def set_box_metadata(
    self, box_left: int, box_top: int, box_right: int , box_bottom: int, image_size: tuple[int, int]
):
    """Set the box rectangle and image size directly, instead of from JSON.

    The canvas-alignment fields are reset as if this were the only experiment:
    no offset, canvas equal to the image, largest box equal to this box.
    """
    width = box_right - box_left
    height = box_bottom - box_top
    self.box_center = (int((box_left + box_right) / 2), int((box_top + box_bottom) / 2))
    self.box_size = (width, height)
    self.image_size = image_size
    self.image_offset = (0, 0)
    self.enlarged_image_size = image_size
    self.largest_box_size = self.box_size
@classmethod
def enlarge_canvas(cls, experiments: list["Experiment"]) -> tuple[int, int]:
    """Place all experiments on one canvas with their box centres aligned.

    Steps:

    1. Centre each image on a canvas as large as the largest image.
    2. Shift each image so its box centre lands on the (floored) mean of those
       centred box centres.
    3. Shift all offsets by the same amount so the smallest is 0.

    After step 3 every shifted image lies inside ``[0, canvas_size)``, and the
    box centres still coincide. Without that shift, negative offsets would map
    data below 0, where ``position_to_grid_index`` clamps it, and leave an
    unused margin at the far edge of the canvas.

    Sets ``image_offset``, ``enlarged_image_size`` and ``largest_box_size`` on
    every experiment. ``box_center``, ``box_size`` and ``image_size`` must
    already be set.

    :returns: The ``(width, height)`` of the common canvas.
    :raises ValueError: if ``experiments`` is empty.
    """
    if not experiments:
        raise ValueError("Need at least one experiment to build a canvas")
    box_centers = np.array([e.box_center for e in experiments])
    box_sizes = np.array([e.box_size for e in experiments])
    image_sizes = np.array([e.image_size for e in experiments])

    # offsets that centre every image on a canvas the size of the largest image
    max_image = np.max(image_sizes, axis=0)
    image_centers = np.floor(image_sizes / 2)
    adjusted_image_centers = np.floor(image_centers + (max_image[None, :] - image_sizes) / 2)
    adjusted_image_offsets = adjusted_image_centers - image_centers

    # extra shift that moves each centred box centre onto the mean box centre
    adjusted_box_centers = box_centers + adjusted_image_offsets
    mean_adjusted_box_centers = np.floor(np.mean(adjusted_box_centers, axis=0))
    aligned_box_offsets = mean_adjusted_box_centers - adjusted_box_centers

    # how much we need to offset our image data so our box center aligns with the mean box center of all boxes
    total_image_offset = adjusted_image_offsets + aligned_box_offsets
    min_offset = np.min(total_image_offset, axis=0)
    max_offset = np.max(total_image_offset, axis=0)
    final_image_size = max_image + max_offset - min_offset
    # start the left/top-most image at 0; every experiment moves by the same
    # amount, so the box centres stay aligned
    total_image_offset = total_image_offset - min_offset[None, :]
    max_box_size = np.max(box_sizes, axis=0)

    final_size = tuple(map(int, final_image_size))
    largest_box = tuple(map(int, max_box_size))
    for i, experiment in enumerate(experiments):
        experiment.image_offset = tuple(map(int, total_image_offset[i, :]))
        experiment.enlarged_image_size = final_size
        experiment.largest_box_size = largest_box
    return final_size
def position_to_side(
    self, x: int, y: int, split_horizontally: bool = True,
    categoricals: tuple[str, ...] = ("Near-side", "Far-side"),
):
    """Return which half of the box a point falls in.

    The box is split through ``box_center``: by x when ``split_horizontally``
    is true, otherwise by y. Coordinates strictly below the centre map to
    ``categoricals[0]``; everything else, including NaN, maps to
    ``categoricals[1]``.
    """
    cx, cy = self.box_center
    coord, center = (x, cx) if split_horizontally else (y, cy)
    return categoricals[0] if coord < center else categoricals[1]
def convert_to_categoricals(
    self, data_group: Union["SubjectPos", str], measure: MEASURE_TYPE,
    categoricals: tuple[str, ...] = ("Near-side", "Far-side"),
    measure_options: dict | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Label each sample of a position track with an index into ``categoricals``.

    :param data_group: Position data, or the name of an attribute of ``self``
        holding it (e.g. ``"trial_pos_data"``).
    :param measure: One of:

        - ``"occupancy"``: labels samples by box side via ``position_to_side``
          (option ``split_horizontally``, default True).
        - ``"motion_index_freezing"`` / ``"speed_freezing"``: label a sample as
          freezing when its motion index / speed is
          ``<= measure_options["threshold"]``.
    :param categoricals: Category names. For the freezing measures,
        ``measure_options["freeze_name"]`` selects the freezing category. Other
        samples get index 1 if that is index 0, otherwise index 0, so two
        categories are assumed.
    :param measure_options: Measure-specific options, as listed above.
    :returns: ``(times, time_diff, categoricals_index)``. For
        ``"motion_index_freezing"``, samples with a negative motion index are
        dropped, and ``time_diff`` is the gap to the following sample.
    :raises ValueError: for an unknown measure, or when the track timestamps
        decrease (``"motion_index_freezing"``).
    """
    if measure_options is None:
        measure_options = {}
    if isinstance(data_group, str):
        data_group = getattr(self, data_group)

    if measure == "occupancy":
        to_side = partial(
            self.position_to_side,
            split_horizontally=measure_options.get("split_horizontally", True), categoricals=categoricals,
        )
        times, time_diff, categoricals_index = data_group.transform_to_categorical_pos(to_side, categoricals)
        return times, time_diff, categoricals_index

    # Validate explicitly. The previous ``assert`` is stripped under ``python -O``,
    # and it let e.g. "foo_freezing" fall through to the speed branch.
    if measure not in ("motion_index_freezing", "speed_freezing"):
        raise ValueError(f"Unknown measure {measure}")

    freeze_i = categoricals.index(measure_options["freeze_name"])
    non_freeze_i = 1 if freeze_i == 0 else 0
    if measure == "motion_index_freezing":
        # the last sample has no following timestamp, so it gets no duration
        index = data_group.motion_index[:-1]
        valid = index >= 0  # negative values are treated as invalid (presumably "no value")
        times = data_group.times
        time_diff = times[1:] - times[:-1]
        if not np.all(time_diff >= 0):
            raise ValueError("Position track timestamps must be non-decreasing")
        index = index[valid]
        time_diff = time_diff[valid]
        times = times[:-1][valid]
    else:
        times, index, _, time_diff = data_group.calculate_speed()

    freezing = index <= measure_options["threshold"]
    categoricals_index = np.empty(len(times))
    categoricals_index[freezing] = freeze_i
    categoricals_index[np.logical_not(freezing)] = non_freeze_i
    return times, time_diff, categoricals_index
def position_to_grid_index(self, x, y, grid_width: int, grid_height: int) -> tuple[int, int]:
    """Map a tracked pixel position onto the shared occupancy grid.

    The position is scaled about this experiment's box centre so the box
    matches ``largest_box_size``, then shifted by ``image_offset``. The result
    is rounded and clamped to ``[0, grid_width - 1] x [0, grid_height - 1]``.
    """
    def project(value, center, size, largest, offset, limit):
        moved = (value - center) * (largest / size) + center + offset
        return int(min(max(round(moved), 0), limit - 1))

    offset_x, offset_y = self.image_offset
    cx, cy = self.box_center
    bw, bh = self.box_size
    largest_bw, largest_bh = self.largest_box_size
    grid_x = project(x, cx, bw, largest_bw, offset_x, grid_width)
    grid_y = project(y, cy, bh, largest_bh, offset_y, grid_height)
    return grid_x, grid_y
def read_pos_track(
    self, data_root: Path, hab_offset: float = 0, pre_duration: float | None = None, trial_offset: float = 0,
    trial_duration: float | None = None, post_offset: float = 0, post_duration: float | None = None,
    frame_rate: float = 30.
) -> None:
    """Load the position track and slice it into the pre / trial / post periods.

    Each period window runs from ``<period>_start + <offset>`` to ``<period>_end``.
    A positive duration keeps only the first ``duration`` of that window, and a
    negative one only the last ``-duration``. ``None`` or 0 keeps the whole
    window. The results go to ``pos_data`` and to
    ``pre_pos_data``/``trial_pos_data``/``post_pos_data``.
    """
    self.pos_data = SubjectPos.parse_csv_track(self.pos_path_filename(data_root), frame_rate=frame_rate)

    def clip(start, end, duration):
        # a falsy duration (None / 0) leaves the window untouched
        if not duration:
            return start, end
        if duration > 0:
            return start, min(end, start + duration)
        if duration < 0:
            return max(start, end + duration), end
        return start, end

    windows = (
        ("pre_pos_data", self.pre_start + hab_offset, self.pre_end, pre_duration),
        ("trial_pos_data", self.trial_start + trial_offset, self.trial_end, trial_duration),
        ("post_pos_data", self.post_start + post_offset, self.post_end, post_duration),
    )
    for attr, window_start, window_end, duration in windows:
        start, end = clip(window_start, window_end, duration)
        setattr(self, attr, self.pos_data.extract_range(start, end))
def chop_pos_track(
    self, hab_segments: list[tuple[float, float]]=(), trial_segments: list[tuple[float, float]]=(),
    post_segments: list[tuple[float, float]]=(),
) -> None:
    """Cut each period of ``pos_data`` into sub-segments.

    Segment bounds are relative to the period start, and segment ends are
    capped at the period end. The pieces are stored in a fresh
    ``chopped_pos_data`` dict. Keys are ``"Pre"``/``"Trial"``/``"Post"``, or
    ``"<name>_<int(start)>"`` when a period has more than one segment.
    """
    chopped = {}
    self.chopped_pos_data = chopped
    plan = (
        ("Pre", hab_segments, self.pre_start, self.pre_end),
        ("Trial", trial_segments, self.trial_start, self.trial_end),
        ("Post", post_segments, self.post_start, self.post_end),
    )
    for name, segments, period_start, period_end in plan:
        several = len(segments) > 1
        for seg_start, seg_end in segments:
            piece = self.pos_data.extract_range(period_start + seg_start, min(period_start + seg_end, period_end))
            chopped[f"{name}_{int(seg_start)}" if several else name] = piece
@classmethod | |
def _get_data_items(cls, obj: Union["Experiment", None] = None) -> list[tuple[Union["SubjectPos", str], str]]: | |
a, b, c = cls.triplet_name | |
if obj is None: | |
res = [ | |
("pre_pos_data", a), | |
("trial_pos_data", b), | |
("post_pos_data", c), | |
] | |
else: | |
res = [ | |
(obj.pre_pos_data, a), | |
(obj.trial_pos_data, b), | |
(obj.post_pos_data, c), | |
] | |
return res | |
@classmethod | |
def _iter_periods_and_groups(cls, experiments, periods, axs, filter_args: list[dict] | None): | |
n_groups = 1 if filter_args is None else len(filter_args) | |
it = iter(axs.flatten()) | |
for i, filter_group in enumerate(filter_args or [None, ]): | |
if n_groups > 1: | |
experiments_ = cls.filter(experiments, **filter_group) | |
else: | |
experiments_ = experiments | |
for j, (data_name, t) in enumerate(periods): | |
ax = next(it) | |
yield experiments_, i, filter_group, j, data_name, t, ax | |
def plot_occupancy(
    self, grid_width: int, grid_height: int,
    gaussian_sigma: float = 0, intensity_limit: float = 0, frame_normalize: bool = True,
    scale_to_one: bool = True, save_fig_root: None | Path = None, save_fig_prefix: str = "",
):
    """Plot this experiment's occupancy heat map, one panel per period.

    Positions are mapped with ``position_to_grid_index``. Only the first panel
    gets a colour bar and a y label. The figure goes through ``save_or_show``.
    """
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=True)
    panels = zip(self._get_data_items(self), axs.flatten())
    for col, ((period_data, period_title), ax) in enumerate(panels):
        leftmost = col == 0
        period_data.plot_occupancy(
            grid_width, grid_height, pos_to_index=self.position_to_grid_index, fig=fig, ax=ax,
            gaussian_sigma=gaussian_sigma,
            intensity_limit=intensity_limit, frame_normalize=frame_normalize, scale_to_one=scale_to_one,
            color_bar=leftmost,
            x_label="X (pixels)",
            y_label="Y (pixels)" if leftmost else "",
            title="",
        )
        ax.set_title(f"$\\bf{{{period_title}}}$")
    fig.suptitle(f"{self.title_root.format(**self.metadata)} occupancy density")
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_occupancy(
    cls, experiments: list["Experiment"], grid_width: int, grid_height: int,
    gaussian_sigma: float = 0, intensity_limit: float = 0, title: str = "Subjects occupancy density",
    frame_normalize: bool = True, experiment_normalize: bool = True, scale_to_one: bool = True,
    save_fig_root: None | Path = None, save_fig_prefix: str = "",
    filter_args: list[dict] | None = None, group_label: str = "",
):
    """Grid of pooled occupancy heat maps: rows are filter groups, columns are periods.

    Each cell passes one ``grid_width x grid_height`` array to every
    experiment's ``calculate_occupancy``, which presumably accumulates into it
    in place. With ``experiment_normalize`` the sum is divided by the number of
    experiments. An empty group is left as an all-zero map instead of being
    divided by zero, which would produce a NaN map. Only the top-left cell gets
    a colour bar.
    """
    n_groups = 1 if filter_args is None else len(filter_args)
    periods = cls._get_data_items()
    n_periods = len(periods)
    fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=True)
    for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
            experiments, periods, axs, filter_args):
        occupancy = np.zeros((grid_width, grid_height))
        for experiment in experiments_:
            getattr(experiment, data_name).calculate_occupancy(
                occupancy, experiment.position_to_grid_index, frame_normalize,
            )
        if experiment_normalize and len(experiments_) > 0:
            occupancy /= len(experiments_)
        group = group_label.format(**filter_group) if n_groups > 1 else ""
        SubjectPos.plot_occupancy_data(
            occupancy, fig, ax, gaussian_sigma, intensity_limit, scale_to_one,
            color_bar=not i and not j,
            title="",
            x_label="X (pixels)" if i == n_groups - 1 else "",
            y_label=f"{group}Y (pixels)" if not j else "",
        )
        if not i:
            ax.set_title(f"$\\bf{{{t}}}$")
    fig.suptitle(title)
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def export_multi_experiment_motion(
    cls, experiments: list["Experiment"], filename: Path,
    measure: Literal["motion_index", "speed"] = "motion_index", use_chopped_data: bool = False,
) -> None:
    """Write one CSV row per experiment and stage with the mean motion.

    The columns are the experiment metadata columns, ``Stage`` and
    ``Motion (mean)``. Stages are ``Pre``/``Trial``/``Post``, or with
    ``use_chopped_data`` the keys of ``chopped_pos_data``. For
    ``"motion_index"``, negative values are excluded from the mean. For
    ``"speed"``, the speeds returned by ``calculate_speed`` are averaged.

    :raises ValueError: if ``measure`` is neither ``"motion_index"`` nor
        ``"speed"``. This is checked before any file is created.
    """
    # an unknown measure used to leave ``mean_motion`` unbound (NameError) after
    # the header had already been written
    if measure not in ("motion_index", "speed"):
        raise ValueError(f"Unknown measure {measure}")
    header = experiments[0].get_header_id_columns() + ["Stage", "Motion (mean)"]
    filename.parent.mkdir(parents=True, exist_ok=True)
    with open(filename, "w", newline="") as fh:
        writer = csv.writer(fh, delimiter=",")
        writer.writerow(header)
        for experiment in experiments:
            if use_chopped_data:
                stages = list(experiment.chopped_pos_data.keys())
                stages_data = [experiment.chopped_pos_data[stage] for stage in stages]
            else:
                stages = "Pre", "Trial", "Post"
                stages_data = experiment.pre_pos_data, experiment.trial_pos_data, experiment.post_pos_data
            id_columns = experiment.get_id_columns()
            for stage, data in zip(stages, stages_data):
                if measure == "motion_index":
                    motion = data.motion_index
                    mean_motion = np.mean(motion[motion >= 0])
                else:
                    _, motion, _, _ = data.calculate_speed()
                    mean_motion = np.mean(motion)
                writer.writerow(map(str, id_columns + [stage, mean_motion]))
def plot_motion_index(
    self, y_limit: float | None = None, save_fig_root: None | Path = None, save_fig_prefix: str = ""
):
    """Plot the motion index over time, one panel per period.

    :param y_limit: When given, every panel's y axis is set to ``[0, y_limit]``.
    """
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
    panels = zip(self._get_data_items(self), axs.flatten())
    for col, ((period_data, period_title), ax) in enumerate(panels):
        # only the left-most panel keeps the default y label
        label_override = {} if col == 0 else {"y_label": ""}
        period_data.plot_motion_index(fig, ax, **label_override)
        if y_limit is not None:
            ax.set_ylim(0, y_limit)
        ax.set_title(f"$\\bf{{{period_title}}}$")
    fig.suptitle(f"{self.title_root.format(**self.metadata)} motion index")
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_motion(
    cls, experiments: list["Experiment"], y_limit: float | None = None, title: str = "Subjects motion index",
    save_fig_root: None | Path = None, save_fig_prefix: str = "",
    filter_args: list[dict] | None = None, group_label: str = "",
    measure: Literal["motion_index", "speed"] = "motion_index",
):
    """Grid of motion traces: rows are filter groups, columns are periods.

    Every experiment in a cell is drawn on the same axes with its
    ``plot_<measure>`` method. Only the bottom row keeps x labels. Only the
    first column gets a y label, prefixed with the group label when there are
    several groups. Only the top row gets period titles.

    :raises ValueError: for a measure other than ``"motion_index"`` or ``"speed"``.
    """
    if measure not in ("motion_index", "speed"):
        raise ValueError(f"Unknown measure {measure}")
    y_name = "Motion index" if measure == "motion_index" else "Speed (px / s)"
    n_groups = 1 if filter_args is None else len(filter_args)
    periods = cls._get_data_items()
    n_periods = len(periods)
    fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=False)
    for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
            experiments, periods, axs, filter_args):
        # labels depend only on the cell, so build them once rather than per experiment
        group = group_label.format(**filter_group) if n_groups > 1 else ""
        kwargs = {"y_label": "" if j else f"{group}{y_name}"}
        if i != n_groups - 1:
            kwargs["x_label"] = ""
        for experiment in experiments_:
            data = getattr(experiment, data_name)
            getattr(data, f"plot_{measure}")(fig, ax, **kwargs)
        if y_limit is not None:
            ax.set_ylim(0, y_limit)
        if not i:
            ax.set_title(f"$\\bf{{{t}}}$")
    fig.suptitle(title)
    save_or_show(save_fig_root, save_fig_prefix)
def plot_motion_index_histogram(
    self, n_bins=100, save_fig_root: None | Path = None, save_fig_prefix: str = "",
    hist_range: tuple[float, float] = (0, 3),
):
    """Log-scaled density histogram of the non-negative motion index, one panel per period."""
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=True)
    panels = zip(self._get_data_items(self), axs.flatten())
    for col, ((period_data, period_title), ax) in enumerate(panels):
        motion = period_data.motion_index
        ax.hist(motion[motion >= 0], bins=n_bins, density=True, log=True)
        ax.set_xlim(*hist_range)
        ax.set_xlabel("Motion index")
        if col == 0:
            ax.set_ylabel("Density")
        ax.set_title(f"$\\bf{{{period_title}}}$")
    fig.suptitle(f"{self.title_root.format(**self.metadata)} motion index density")
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_motion_index_histogram(
    cls, experiments: list["Experiment"], n_bins=100, title: str = "Subjects motion index",
    save_fig_root: None | Path = None, save_fig_prefix: str = "", hist_range: tuple[float, float]=(0, 3),
    filter_args: list[dict] | None = None, group_label: str = "",
):
    """Pooled motion-index density histograms: rows are filter groups, columns are periods.

    The non-negative motion-index samples of all experiments in a cell are
    pooled into one log-scaled density histogram. Cells without experiments
    stay empty.
    """
    n_groups = 1 if filter_args is None else len(filter_args)
    periods = cls._get_data_items()
    fig, axs = plt.subplots(n_groups, len(periods), sharey=True, sharex=True)
    last_row = n_groups - 1
    cells = cls._iter_periods_and_groups(experiments, periods, axs, filter_args)
    for group_exps, row, filter_group, col, data_name, period_title, ax in cells:
        raw = [getattr(experiment, data_name).motion_index for experiment in group_exps]
        pooled = [values[values >= 0] for values in raw]
        if pooled:
            ax.hist(np.concatenate(pooled), bins=n_bins, density=True, log=True)
        ax.set_xlim(*hist_range)
        prefix = group_label.format(**filter_group) if n_groups > 1 else ""
        if row == last_row:
            ax.set_xlabel("Motion index")
        if col == 0:
            ax.set_ylabel(f"{prefix}Density")
        if row == 0:
            ax.set_title(f"$\\bf{{{period_title}}}$")
    fig.suptitle(title)
    save_or_show(save_fig_root, save_fig_prefix)
def plot_speed(
    self, y_limit: float | None = None, save_fig_root: None | Path = None,
    save_fig_prefix: str = ""
):
    """Plot speed over time, one panel per period.

    :param y_limit: When given, every panel's y axis is set to ``[0, y_limit]``.
    """
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
    panels = zip(self._get_data_items(self), axs.flatten())
    for col, ((period_data, period_title), ax) in enumerate(panels):
        # only the left-most panel keeps the default y label
        label_override = {} if col == 0 else {"y_label": ""}
        period_data.plot_speed(fig, ax, **label_override)
        if y_limit is not None:
            ax.set_ylim(0, y_limit)
        ax.set_title(f"$\\bf{{{period_title}}}$")
    fig.suptitle(f"{self.title_root.format(**self.metadata)} speed")
    save_or_show(save_fig_root, save_fig_prefix)
def plot_speed_histogram(
    self, n_bins=100, save_fig_root: None | Path = None, save_fig_prefix: str = "",
    hist_range: tuple[float, float] = (0, 200),
):
    """Log-scaled density histogram of speed (from ``calculate_speed``), one panel per period."""
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=True)
    panels = zip(self._get_data_items(self), axs.flatten())
    for col, ((period_data, period_title), ax) in enumerate(panels):
        _, speed, _, _ = period_data.calculate_speed()
        ax.hist(speed, bins=n_bins, density=True, log=True)
        ax.set_xlim(*hist_range)
        ax.set_xlabel("Speed (px / s)")
        if col == 0:
            ax.set_ylabel("Density")
        ax.set_title(f"$\\bf{{{period_title}}}$")
    fig.suptitle(f"{self.title_root.format(**self.metadata)} Speed density")
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_speed_histogram(
    cls, experiments: list["Experiment"], n_bins=100,
    title: str = "Subjects Speed (px / s)", hist_range: tuple[float, float]=(0, 200),
    save_fig_root: None | Path = None, save_fig_prefix: str = "",
    filter_args: list[dict] | None = None, group_label: str = ""
):
    """Pooled speed density histograms: rows are filter groups, columns are periods.

    The speeds of all experiments in a cell are pooled into one log-scaled
    density histogram. A cell whose filter group has no experiments is left
    empty instead of crashing.
    """
    n_groups = 1 if filter_args is None else len(filter_args)
    periods = cls._get_data_items()
    n_periods = len(periods)
    fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=True)
    for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
            experiments, periods, axs, filter_args):
        speeds = []
        for experiment in experiments_:
            _, speed, _, _ = getattr(experiment, data_name).calculate_speed()
            speeds.append(speed)
        # np.concatenate([]) raises ValueError, so only plot when there is data
        if speeds:
            ax.hist(np.concatenate(speeds), bins=n_bins, density=True, log=True)
        ax.set_xlim(*hist_range)
        group = group_label.format(**filter_group) if n_groups > 1 else ""
        if i == n_groups - 1:
            ax.set_xlabel("Speed (px / s)")
        if not j:
            ax.set_ylabel(f"{group}Density")
        if not i:
            ax.set_title(f"$\\bf{{{t}}}$")
    fig.suptitle(title)
    save_or_show(save_fig_root, save_fig_prefix)
def plot_categorical_values(
    self, measure: MEASURE_TYPE = "occupancy", measure_options: dict = None,
    categoricals: tuple[str, ...] = ("Near-side", "Far-side"), save_fig_root: None | Path = None, save_fig_prefix: str = "",
):
    """Plot the per-sample category (from ``convert_to_categoricals``) over time, one panel per period."""
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
    for (period_data, period_title), ax in zip(self._get_data_items(self), axs.flatten()):
        times, _, labels = self.convert_to_categoricals(period_data, measure, categoricals, measure_options)
        period_data.plot_categorical_values(times, labels, categoricals, fig, ax)
        ax.set_title(period_title)
    fig.suptitle(f"Subject {self.title_root.format(**self.metadata)}")
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_side_of_box(
    cls, experiments: list["Experiment"], measure: MEASURE_TYPE = "occupancy", measure_options: dict = None,
    categoricals: tuple[str, ...] = ("Near-side", "Far-side"), title: str = "Subjects motion",
    save_fig_root: None | Path = None, save_fig_prefix: str = "",
):
    """Overlay every experiment's per-sample category trace, one panel per period."""
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
    panels = zip(cls._get_data_items(), axs.flatten())
    for (data_name, period_title), ax in panels:
        for experiment in experiments:
            times, _, labels = experiment.convert_to_categoricals(
                data_name, measure, categoricals, measure_options
            )
            period_data = getattr(experiment, data_name)
            period_data.plot_categorical_values(times, labels, categoricals, fig, ax)
        ax.set_title(period_title)
    fig.suptitle(title)
    save_or_show(save_fig_root, save_fig_prefix)
def _descriptor_to_point(self, point: tuple[str, str]): | |
cx, cy = self.box_center | |
bw, bh = self.box_size | |
match point[0]: | |
case "left": | |
x = cx - bw / 2 | |
case "right": | |
x = cx + bw / 2 | |
case "center": | |
x = cx | |
case _: | |
raise ValueError(f"Can't recognize {point[0]}") | |
match point[1]: | |
case "bottom": | |
y = cy + bh / 2 | |
case "top": | |
y = cy - bh / 2 | |
case "center": | |
y = cy | |
case _: | |
raise ValueError(f"Can't recognize {point[1]}") | |
return x, y | |
def plot_distance_from_point(
    self, point: tuple[str, str], save_fig_root: None | Path = None,
    save_fig_prefix: str = "", post_title: str = "distance from teaball corner",
):
    """Plot the distance to a box landmark over time, one panel per period.

    ``point`` is a ``(horizontal, vertical)`` descriptor resolved with
    ``_descriptor_to_point``. Only the first panel keeps its y label.
    """
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
    target_xy = self._descriptor_to_point(point)
    panels = zip(self._get_data_items(self), axs.flatten())
    for col, ((period_data, period_title), ax) in enumerate(panels):
        label_override = {} if col == 0 else {"y_label": ""}
        period_data.plot_distance_from_point(target_xy, fig, ax, **label_override)
        ax.set_title(f"$\\bf{{{period_title}}}$")
    fig.suptitle(f"{self.title_root.format(**self.metadata)} {post_title}")
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_distance_from_point(
    cls, experiments: list["Experiment"], point: tuple[str, str],
    title: str = "Subjects distance from teaball corner",
    save_fig_root: None | Path = None, save_fig_prefix: str = "",
    filter_args: list[dict] | None = None, group_label: str = "",
):
    """Grid of distance-to-landmark traces: rows are filter groups, columns are periods.

    ``point`` is resolved per experiment with ``_descriptor_to_point``, so the
    landmark follows each experiment's own box. Only the bottom row keeps x
    labels. Only the first column gets a y label, prefixed with the group label
    when there are several groups. Only the top row gets period titles.
    """
    n_groups = 1 if filter_args is None else len(filter_args)
    periods = cls._get_data_items()
    n_periods = len(periods)
    fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=False)
    for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
            experiments, periods, axs, filter_args):
        # labels depend only on the cell, so build them once rather than per experiment
        group = group_label.format(**filter_group) if n_groups > 1 else ""
        kwargs = {"y_label": "" if j else f"{group}Distance (px)"}
        if i != n_groups - 1:
            kwargs["x_label"] = ""
        for experiment in experiments_:
            point_xy = experiment._descriptor_to_point(point)
            getattr(experiment, data_name).plot_distance_from_point(point_xy, fig, ax, **kwargs)
        if not i:
            ax.set_title(f"$\\bf{{{t}}}$")
    fig.suptitle(title)
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod | |
def _get_categorical_percent( | |
cls, categoricals_index: list[np.ndarray], times_diff: list[np.ndarray], | |
sorted_categoricals: tuple[str, ...], unit: Literal["percent", "times"]="times", | |
): | |
if unit == "times": | |
percents = np.empty((len(categoricals_index), len(sorted_categoricals))) | |
for e, (exp_categoricals_index, time_diff) in enumerate(zip(categoricals_index, times_diff)): | |
if len(exp_categoricals_index) != len(time_diff): | |
raise ValueError("Provided time diff and categorical index are not the same length") | |
for i in range(len(sorted_categoricals)): | |
percents[e, i] = np.sum(time_diff[exp_categoricals_index == i]) | |
elif unit == "percent": | |
counts = np.array([ | |
[np.sum(arr == i) for i in range(len(sorted_categoricals))] | |
for arr in categoricals_index | |
]) | |
percents = counts / np.sum(counts, axis=1, keepdims=True) * 100 | |
else: | |
raise ValueError(f"Unknown unit {unit}") | |
mean_prop = np.mean(percents, axis=0).squeeze() | |
return mean_prop, [prop.squeeze() for prop in percents] | |
@classmethod
def _plot_percents(
    cls, categoricals_index: list[np.ndarray], times_diff: list[np.ndarray],
    sorted_categoricals: tuple[str, ...],
    fig: plt.Figure, ax: plt.Axes,
    x_label: str = "Teaball side", y_label: str = "% time spent",
    unit: Literal["percent", "times"] = "times",
):
    """Draw one bar per category (mean over experiments) onto ``ax``.

    With more than one experiment, each experiment's values are overlaid as
    black dots. An empty ``x_label``/``y_label`` leaves that axis label unset.
    ``fig`` is accepted for signature symmetry but not used.
    """
    mean_prop, per_experiment = cls._get_categorical_percent(
        categoricals_index, times_diff, sorted_categoricals, unit
    )
    positions = np.arange(len(sorted_categoricals))
    ax.bar(positions, mean_prop, tick_label=sorted_categoricals)
    if len(categoricals_index) > 1:
        for values in per_experiment:
            ax.plot(positions, values.squeeze(), "k.")
    if x_label:
        ax.set_xlabel(x_label)
    if y_label:
        ax.set_ylabel(y_label)
def plot_categorical_percent(
    self, measure: MEASURE_TYPE = "occupancy", measure_options: dict = None,
    categoricals: tuple[str, ...] = ("Near-side", "Far-side"),
    save_fig_root: None | Path = None, save_fig_prefix: str = "",
    unit: Literal["percent", "times"] = "times", x_label = "Odor side", action_label: str = "spent in side",
):
    """Bar plot of this experiment's time per category, one panel per period.

    ``unit="times"`` plots total time per category, and ``"percent"`` the share
    of samples. Only the first panel gets the y label.
    """
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
    y_label = "% time spent" if unit == "percent" else "total time spent (s)"
    panels = zip(self._get_data_items(self), axs.flatten())
    for col, ((period_data, period_title), ax) in enumerate(panels):
        _, durations, labels = self.convert_to_categoricals(period_data, measure, categoricals, measure_options)
        self._plot_percents(
            [labels], [durations], categoricals, fig, ax, x_label, "" if col else y_label, unit,
        )
        ax.set_title(f"$\\bf{{{period_title}}}$")
    amount = "total" if unit == "times" else "%"
    fig.suptitle(f"{self.title_root.format(**self.metadata)} {amount} time {action_label}")
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_categorical_percent(
    cls, experiments: list["Experiment"], measure: MEASURE_TYPE = "occupancy",
    measure_options: dict = None,
    categoricals: tuple[str, ...] = ("Near-side", "Far-side"), title: str = "Subjects motion",
    save_fig_root: None | Path = None, save_fig_prefix: str = "",
    filter_args: list[dict] | None = None, group_label: str = "",
    unit: Literal["percent", "times"] = "times", x_label: str = "Odor side",
):
    """Grid of per-category bar plots: rows are filter groups, columns are periods.

    In every cell, each experiment's samples are labelled with
    ``convert_to_categoricals`` and summarised by ``_plot_percents`` as a mean
    bar plus per-experiment dots. Only the bottom row gets the x label. Only
    the first column gets the y label, prefixed with the group label when there
    are several groups. Only the top row gets period titles.
    """
    n_groups = 1 if filter_args is None else len(filter_args)
    periods = cls._get_data_items()
    fig, axs = plt.subplots(n_groups, len(periods), sharey=True, sharex=False)
    y_name = "% time spent" if unit == "percent" else "total time spent (s)"
    cells = cls._iter_periods_and_groups(experiments, periods, axs, filter_args)
    for group_exps, row, filter_group, col, data_name, period_title, ax in cells:
        labels_per_exp = []
        durations_per_exp = []
        for experiment in group_exps:
            _, durations, labels = experiment.convert_to_categoricals(
                data_name, measure, categoricals, measure_options
            )
            labels_per_exp.append(labels)
            durations_per_exp.append(durations)
        prefix = group_label.format(**filter_group) if n_groups > 1 else ""
        cls._plot_percents(
            labels_per_exp, durations_per_exp, categoricals, fig, ax,
            x_label=x_label if row == n_groups - 1 else "",
            y_label="" if col else f"{prefix}{y_name}",
            unit=unit,
        )
        if row == 0:
            ax.set_title(f"$\\bf{{{period_title}}}$")
    fig.suptitle(title)
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def export_multi_experiment_categorical_percent(
    cls, experiments: list["Experiment"], filename: Path,
    measure: MEASURE_TYPE = "occupancy", measure_options: dict | None = None,
    categoricals: tuple[str, ...] = ("Near-side", "Far-side"),
    unit: Literal["percent", "times"] = "times",
    use_chopped_data: bool = False,
) -> None:
    """Write one CSV row per experiment and stage with the per-category summary.

    The columns are the experiment metadata columns, ``Stage``, then one column
    per categorical: ``"<name> sec"`` for ``unit="times"``, ``"<name> %"``
    otherwise. Stages are ``Pre``/``Trial``/``Post``, or with
    ``use_chopped_data`` the keys of ``chopped_pos_data``. Values come from
    ``convert_to_categoricals`` followed by ``_get_categorical_percent``.
    """
    sign = "sec" if unit == "times" else "%"
    header = experiments[0].get_header_id_columns() + ["Stage"] + [
        f"{label} {sign}" for label in categoricals
    ]
    filename.parent.mkdir(parents=True, exist_ok=True)
    with open(filename, "w", newline="") as fh:
        writer = csv.writer(fh, delimiter=",")
        writer.writerow(header)
        for experiment in experiments:
            if use_chopped_data:
                stages = list(experiment.chopped_pos_data.keys())
                stages_data = [experiment.chopped_pos_data[stage] for stage in stages]
            else:
                stages = "Pre", "Trial", "Post"
                stages_data = experiment.pre_pos_data, experiment.trial_pos_data, experiment.post_pos_data
            id_columns = experiment.get_id_columns()
            for stage, data in zip(stages, stages_data):
                _, times_diff, categoricals_index = experiment.convert_to_categoricals(
                    data, measure, categoricals, measure_options
                )
                mean_prop, _ = cls._get_categorical_percent([categoricals_index], [times_diff], categoricals, unit)
                # The summary is squeezed, so a single categorical comes back as a 0-d
                # array, which cannot be star-unpacked. atleast_1d keeps it one column.
                line = id_columns + [stage, *np.atleast_1d(mean_prop)]
                writer.writerow(map(str, line))
@classmethod
def plot_multi_experiment_merged_by_period_categorical_percent(
        cls, experiments: list["Experiment"], filter_args: list[dict] | None,
        measure: MEASURE_TYPE = "occupancy", measure_options: dict = None,
        categoricals: tuple[str, ...] = ("Near-side", "Far-side"), title: str = "Subjects motion",
        save_fig_root: None | Path = None, save_fig_prefix: str = "", group_label: str = "",
        unit: Literal["percent", "times"] = "times", x_label: str = "Teaball side",
):
    """Plot grouped bars of time spent per categorical, one subplot per period.

    There is one subplot per ``(data_name, title)`` pair from ``cls._get_data_items()``. Inside
    each subplot, every group in ``filter_args`` gets one bar per categorical, and each bar shows
    the mean value from ``cls._get_categorical_percent``. Each group is a metadata dict passed to
    ``cls.filter``. Individual experiments are overlaid as black dots.

    :param filter_args: the groups to compare. ``None`` or empty plots all ``experiments`` as a
        single, unconstrained group.
    :param group_label: format string filled with each group's metadata for the legend.
    """
    if not filter_args:
        # A single group with no metadata constraint; ``len(None)`` used to raise TypeError.
        filter_args = [{}]
    n_groups = len(filter_args)
    periods = cls._get_data_items()
    n_periods = len(periods)
    # squeeze=False keeps a 2-D array of axes even when there is only one period, so
    # ``flatten`` always works.
    fig, axs = plt.subplots(1, n_periods, sharey=True, sharex=True, squeeze=False)
    axes = list(axs.flatten())
    bar_width = 1 / (n_groups + 1)
    n_categoricals = len(categoricals)
    x_base = np.arange(n_categoricals)
    y_label = "% time spent" if unit == "percent" else "total time spent (s)"

    for i, filter_group in enumerate(filter_args):
        experiments_ = cls.filter(experiments, **filter_group)
        x = x_base + i * bar_width
        for j, (data_name, t) in enumerate(periods):
            categoricals_index_all = []
            times_all = []
            for experiment in experiments_:
                _, times_diff, categoricals_index = experiment.convert_to_categoricals(
                    data_name, measure, categoricals, measure_options
                )
                categoricals_index_all.append(categoricals_index)
                times_all.append(times_diff)

            ax = axes[j]
            mean_prop, percents = cls._get_categorical_percent(
                categoricals_index_all, times_all, categoricals, unit
            )
            ax.bar(x, mean_prop, bar_width, label=group_label.format(**filter_group))
            for prop in percents:
                ax.plot(x, prop, "k.")

            ax.set_xlabel(x_label)
            if not j:
                ax.set_ylabel(y_label)
            if not i:
                ax.set_title(f"$\\bf{{{t}}}$")

    for ax in axes:
        # Center each tick under its cluster of ``n_groups`` bars.
        ax.set_xticks(x_base + (n_groups - 1) * bar_width / 2, categoricals)
        # Pad both ends by half a bar, matching the merged-by-group plot. The previous upper
        # limit ``n_categoricals + bar_width`` left 2.5 bar widths of empty space on the right.
        ax.set_xlim(-bar_width, n_categoricals - bar_width)

    handles, labels = axes[-1].get_legend_handles_labels()
    fig.legend(handles, labels, ncols=n_groups, bbox_to_anchor=(0, 0), loc=2)
    fig.suptitle(title)
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_merged_by_group_categorical_percent(
        cls, experiments: list["Experiment"], filter_args: list[dict] | None,
        measure: MEASURE_TYPE = "occupancy", measure_options: dict = None,
        categoricals: tuple[str, ...] = ("Teaball-side", "Far-side"), title: str = "Subjects motion",
        save_fig_root: None | Path = None, save_fig_prefix: str = "", show_titles: bool = True,
        show_legend: bool = True, only_categorical: str | None = None, show_xlabel: bool = True,
        unit: Literal["percent", "times"] = "times", group_label: str = "", x_label: str = "Teaball side"
):
    """Plot per-period bars of time spent per categorical, one subplot (row) per group.

    Each row is one group from ``filter_args``, a metadata dict passed to ``cls.filter``. Inside
    each row, every period from ``cls._get_data_items()`` gets one bar per categorical, showing
    the mean from ``cls._get_categorical_percent``. Individual experiments are overlaid as black
    dots.

    :param filter_args: the groups to compare. ``None`` or empty plots all ``experiments`` as a
        single, unconstrained group.
    :param only_categorical: if given, only the bars for this entry of ``categoricals`` are drawn.
        It must be one of ``categoricals``, otherwise ``ValueError`` is raised.
    """
    if not filter_args:
        # A single group with no metadata constraint; ``len(None)`` used to raise TypeError.
        filter_args = [{}]
    n_groups = len(filter_args)
    periods = cls._get_data_items()
    n_periods = len(periods)
    # squeeze=False keeps a 2-D array of axes even for a single group, so ``flatten`` works.
    fig, axs = plt.subplots(n_groups, 1, sharey=True, sharex=True, squeeze=False)
    axes = list(axs.flatten())
    bar_width = 1 / (n_periods + 1)
    n_categoricals = len(categoricals)

    if only_categorical:
        # A dedicated name so the group loop index below is not shadowed.
        cat_pos = categoricals.index(only_categorical)
        categoricals_s = slice(cat_pos, cat_pos + 1)
        categoricals_used = [only_categorical]
    else:
        categoricals_s = slice(0, n_categoricals)
        categoricals_used = categoricals
    n_categoricals_used = len(categoricals_used)
    x_base = np.arange(n_categoricals_used)
    y_label = "% time spent" if unit == "percent" else "total time spent (s)"

    for i, filter_group in enumerate(filter_args):
        group = group_label.format(**filter_group) if n_groups > 1 else ""
        experiments_ = cls.filter(experiments, **filter_group)
        ax = axes[i]
        for j, (data_name, t) in enumerate(periods):
            categoricals_index_all = []
            times_all = []
            for experiment in experiments_:
                _, times_diff, categoricals_index = experiment.convert_to_categoricals(
                    data_name, measure, categoricals, measure_options
                )
                categoricals_index_all.append(categoricals_index)
                times_all.append(times_diff)

            mean_prop, percents = cls._get_categorical_percent(
                categoricals_index_all, times_all, categoricals, unit
            )
            x = x_base + j * bar_width
            ax.bar(x, mean_prop[categoricals_s], bar_width, label=t)
            for prop in percents:
                ax.plot(x, prop[categoricals_s], "k.")

            if i == n_groups - 1 and show_xlabel:
                ax.set_xlabel(x_label)
            ax.set_ylabel(f"{group}{y_label}")
            if not j and show_titles:
                ax.set_title(group_label.format(**filter_group))

    for ax in axes:
        # Center each tick under its cluster of ``n_periods`` bars; pad both ends by half a bar.
        ax.set_xticks(x_base + (n_periods - 1) * bar_width / 2, categoricals_used)
        ax.set_xlim(-bar_width, n_categoricals_used - bar_width)

    if show_legend:
        handles, labels = axes[-1].get_legend_handles_labels()
        fig.legend(handles, labels, ncols=1, bbox_to_anchor=(0, 0), loc=2)
    fig.suptitle(title)
    save_or_show(save_fig_root, save_fig_prefix, width_inch=3)
@classmethod
def export_multi_experiment_frames(cls, experiments: list["Experiment"], filename: Path) -> None:
    """Write a CSV with the frame count and duration of every experiment segment.

    Each experiment gives one row: its id columns, then the number of samples in the whole, pre,
    trial and post position data, then the duration of each. A duration is the last time stamp
    minus the first, or ``0`` for a segment with no samples. The header's id columns come from
    ``experiments[0]``.
    """
    value_columns = [
        "total_frames", "pre_frames", "trial_frames", "post_frames", "total_duration", "pre_duration",
        "trial_duration", "post_duration",
    ]
    header = experiments[0].get_header_id_columns() + value_columns

    filename.parent.mkdir(parents=True, exist_ok=True)
    with open(filename, "w", newline="") as out_file:
        csv_out = csv.writer(out_file, delimiter=",")
        csv_out.writerow(header)
        for exp in experiments:
            segments = (exp.pos_data, exp.pre_pos_data, exp.trial_pos_data, exp.post_pos_data)
            frame_counts = [len(seg.times) for seg in segments]
            durations = [
                (seg.times[-1] - seg.times[0]) if len(seg.times) else 0
                for seg in segments
            ]
            row = exp.get_id_columns() + frame_counts + durations
            csv_out.writerow([str(value) for value in row])
@classmethod
def filter(cls, experiments: list["Experiment"], **metadata):
    """Return the experiments whose ``metadata`` matches every given key/value pair.

    Keys passed with a ``None`` value are ignored. If no key constrains the selection, the
    ``experiments`` object is returned as-is; otherwise a new list is returned. A candidate that
    matches the earlier constraints but lacks a later key raises ``KeyError``.
    """
    wanted = {key: value for key, value in metadata.items() if value is not None}
    if not wanted:
        return experiments
    return [
        exp for exp in experiments
        if all(exp.metadata[key] == value for key, value in wanted.items())
    ]
@classmethod
def count_motion_index_range(cls, experiments: list["Experiment"]) -> dict[float, int]:
    """Histogram the valid motion index values of all experiments by decade.

    Negative (invalid) values are ignored. In the result, key ``0`` maps to the number of exact
    zeros. Each key ``10 ** -k``, for ``k = 0, 1, ...`` down to the decade of the smallest
    non-zero value, maps to the number of values in ``[10 ** -k, 10 ** -(k - 1))``. The ``1.0``
    bucket holds every value ``>= 1``. The counts sum to the number of valid values, and an empty
    ``experiments`` gives ``{0: 0}``.
    """
    items = [exp.pos_data.motion_index for exp in experiments]
    if not items:
        # np.concatenate raises on an empty list; no experiments means no data.
        return {0: 0}
    data = np.concatenate(items)
    data = data[data >= 0]
    n = len(data)

    zero_mask = data == 0
    counts = {0: int(np.sum(zero_mask))}
    data = data[~zero_mask]
    if not len(data):
        return counts

    min_order = math.log10(np.min(data))
    # When every non-zero value is >= 10, ``-min_order`` is negative and the old
    # ``range(ceil(-min_order) + 1)`` was empty. No bucket was created and the internal assertion
    # failed. Clamp so that at least the ``1.0`` (>= 1) bucket always exists.
    n_orders = max(0, math.ceil(-min_order))
    for order in range(n_orders + 1):
        val = math.pow(10, -order)
        mask = data >= val
        counts[val] = int(np.sum(mask))
        data = data[~mask]

    # Internal invariants: every valid value landed in exactly one bucket.
    assert not len(data)
    assert sum(counts.values()) == n
    return counts
@dataclass
class SubjectPos:
    """Time series of one tracked subject: position and motion index per sample.

    ``times`` holds one time stamp (seconds) per sample. ``track`` holds the matching ``(x, y)``
    pixel position, one row per sample, and ``motion_index`` the matching motion index. A negative
    coordinate or motion index marks a sample as invalid, and the calculations and plots below
    skip such samples.
    """

    filename: str | Path
    times: np.ndarray
    track: np.ndarray = field(repr=False)
    motion_index: np.ndarray = field(repr=False)
    # Not read by any method of this class; kept so the constructor/fields stay unchanged.
    _downsampled_key: Any = field(default=None, init=False, repr=False)
    _downsampled_values: np.ndarray = field(default=None, init=False, repr=False)
    # Class-wide matplotlib style for the time-series plots, configured by assigning
    # ``SubjectPos._point_marker = {...}``. It is intentionally NOT annotated, so it stays a plain
    # class attribute. As a dataclass field with ``default_factory``, the class-level attribute
    # would be deleted, and ``SubjectPos._point_marker`` would raise AttributeError until assigned.
    _point_marker = {}

    @property
    def min_x(self):
        """Smallest x coordinate in ``track`` (invalid, negative samples included)."""
        return np.min(self.track[:, 0])

    @property
    def min_y(self):
        """Smallest y coordinate in ``track`` (invalid, negative samples included)."""
        return np.min(self.track[:, 1])

    @property
    def max_x(self):
        """Largest x coordinate in ``track``."""
        return np.max(self.track[:, 0])

    @property
    def max_y(self):
        """Largest y coordinate in ``track``."""
        return np.max(self.track[:, 1])

    @classmethod
    def parse_csv_track(cls, filename: Path, subject_name: str = "mouse", frame_rate: float = 30.) -> "SubjectPos":
        """Parse a tracking CSV into a new instance.

        Every data row must have 7 columns: frame, instance, cx, cy, motion index, then two time
        columns. The header's last or second-to-last entry must be ``real_timestamp_sec``, and the
        last entry wins if both match. Only rows whose instance equals ``subject_name`` are kept,
        and positions are truncated to ints. Each sample's stored time is ``frame / frame_rate``.
        The real time stamp is only used to check that ``frame_rate`` matches within 5% and that
        times strictly increase.

        :raises ValueError: if the time stamp column is missing, the frame rate does not match the
            time stamps, or sequential times are not strictly increasing.
        """
        track = []
        times = []
        index = []
        print(f"Reading {filename}")
        # The csv module requires newline="" so quoted fields with embedded newlines parse right.
        with open(filename, "r", newline="") as fh:
            reader = csv.reader(fh)
            header = list(next(reader))
            # Resolve which trailing time column holds the real time stamp once, instead of per
            # row. The error is still raised only once a data row is seen, as before.
            if header and header[-1] == "real_timestamp_sec":
                use_t2 = True
            elif len(header) >= 2 and header[-2] == "real_timestamp_sec":
                use_t2 = False
            else:
                use_t2 = None

            last_t = None
            last_frame_t = None
            for row in reader:
                frame, instance, cx, cy, index_val, t1, t2 = row
                if use_t2 is None:
                    raise ValueError("Cannot find the real_timestamp_sec column")
                if instance != subject_name:
                    continue

                track.append((int(float(cx)), int(float(cy))))
                index.append(float(index_val))
                t = float(t2 if use_t2 else t1)
                frame_t = int(frame) / frame_rate
                if not (frame_t * 0.95 <= t <= frame_t * 1.05):
                    raise ValueError(f"{filename} frame rate of {frame_rate} does not match time stamp at \"{row}\"")
                if (last_t is not None and last_t >= t) or (last_frame_t is not None and last_frame_t >= frame_t):
                    raise ValueError(f"{filename} sequential times are not increasing \"{row}\"")
                last_t = t
                last_frame_t = frame_t
                times.append(frame_t)

        return cls(
            filename=filename, times=np.array(times), track=np.array(track), motion_index=np.array(index))

    def extract_range(self, t_start: float | None = None, t_end: float | None = None) -> "SubjectPos":
        """Return a new object with the samples where ``t_start <= time <= t_end``.

        A ``None`` bound means "from the first" or "to the last" sample. ``times`` is assumed to be
        sorted ascending, which ``parse_csv_track`` enforces. An object with no samples yields an
        empty result instead of raising IndexError.
        """
        i_s = 0 if t_start is None else int(np.sum(self.times < t_start))
        i_e = len(self.times) if t_end is None else int(np.sum(self.times <= t_end))
        return type(self)(
            filename=self.filename, times=self.times[i_s:i_e], track=self.track[i_s:i_e],
            motion_index=self.motion_index[i_s:i_e]
        )

    def calculate_occupancy(
            self, occupancy: np.ndarray, pos_to_index: Callable | None = None, frame_normalize: bool = True,
    ) -> None:
        """Accumulate this track's occupancy into ``occupancy`` in place.

        ``occupancy`` is indexed ``[x, y]`` with shape ``(grid_width, grid_height)``. Samples with
        a negative coordinate are skipped. A position is mapped to a cell with
        ``pos_to_index(x, y, grid_width, grid_height)`` when given, or used directly otherwise,
        and is then clamped to the last cell on each axis. Each sample adds ``1 / n_samples`` when
        ``frame_normalize`` is set (``n_samples`` includes the skipped invalid samples), and 1
        otherwise.
        """
        n = self.track.shape[0]
        if not n:
            return
        grid_width, grid_height = occupancy.shape
        frame_proportion = 1 / n if frame_normalize else 1

        for px, py in self.track:
            if px < 0 or py < 0:
                continue
            if pos_to_index is None:
                x, y = px, py
            else:
                x, y = pos_to_index(px, py, grid_width, grid_height)
            # Valid indices are 0..size-1. Clamping to the size itself raised IndexError for
            # positions on the far edge of the grid.
            x = int(min(x, grid_width - 1))
            y = int(min(y, grid_height - 1))
            occupancy[x, y] += frame_proportion

    def calculate_speed(self, downsample_factor: int = 8) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Return ``(times, speed, dist, elapsed)`` between consecutive valid samples.

        Samples with a negative coordinate are dropped first. ``dist`` is the Euclidean distance
        (px) between consecutive valid samples, ``elapsed`` the time between them, and
        ``speed = dist / elapsed``. ``times`` is the time stamp of the later sample of each pair.
        ``downsample_factor`` is currently unused and is kept for API compatibility.

        :raises ValueError: if consecutive valid time stamps are not strictly increasing.
        """
        track = self.track
        valid = np.logical_and(track[:, 0] >= 0, track[:, 1] >= 0)
        track = track[valid, :]
        times = self.times[valid]

        elapsed = times[1:] - times[:-1]
        if np.any(elapsed <= 0):
            raise ValueError(f"Found sequential time stamps whose difference is not positive {self}")

        dist = np.sqrt(np.sum(np.square(track[1:, :] - track[:-1, :]), axis=1, keepdims=False))
        speed = dist / elapsed
        return times[1:], speed, dist, elapsed

    @staticmethod
    def _resolve_fig_ax(fig: plt.Figure | None, ax: plt.Axes | None) -> tuple[plt.Figure, plt.Axes, bool]:
        """Return ``(fig, ax, created)``, creating a new figure and axes when ``ax`` is None.

        ``fig`` and ``ax`` must be given together or not at all. ``created`` tells the caller that
        it owns the new figure and should save or show it.
        """
        if ax is None:
            assert fig is None
            fig, ax = plt.subplots()
            return fig, ax, True
        assert fig is not None
        return fig, ax, False

    @classmethod
    def plot_occupancy_data(
            cls, occupancy: np.ndarray, fig: plt.Figure, ax: plt.Axes,
            gaussian_sigma: float = 0, intensity_limit: float = 0, scale_to_one: bool = True,
            x_label: str = "Box X", y_label: str = "Box Y", title: str = "Occupancy",
            color_bar: bool = True,
    ):
        """Draw an ``[x, y]`` occupancy grid on ``ax`` as an image.

        The grid is drawn transposed, so x is horizontal. It is optionally Gaussian-smoothed
        (``gaussian_sigma``) and scaled so its maximum is 1 (``scale_to_one``). When
        ``intensity_limit`` is non-zero it is used as the color scale maximum. ``occupancy``
        itself is never modified.
        """
        if gaussian_sigma:
            occupancy = gaussian_filter(occupancy, gaussian_sigma)
        if scale_to_one:
            max_val = occupancy.max()
            # Out-of-place: previously the caller's array was rescaled in place whenever no Gaussian
            # filter had made a copy first, and integer arrays failed on ``/=``.
            if max_val:
                occupancy = occupancy / max_val
            else:
                occupancy = np.zeros(occupancy.shape)

        im = ax.imshow(
            occupancy.T, aspect="auto", origin="upper", cmap="viridis", interpolation="sinc",
            interpolation_stage="data", vmax=intensity_limit or None
        )
        if x_label:
            ax.set_xlabel(x_label)
        if y_label:
            ax.set_ylabel(y_label)
        if title:
            ax.set_title(title)
        if color_bar:
            divider = make_axes_locatable(ax)
            cax = divider.append_axes('right', size='5%', pad=0.05)
            fig.colorbar(im, cax=cax, orientation='vertical')

    def plot_occupancy(
            self, grid_width: int, grid_height: int,
            pos_to_index: Callable | None = None, fig: plt.Figure | None = None, ax: plt.Axes | None = None,
            gaussian_sigma: float = 0, intensity_limit: float = 0, frame_normalize: bool = True,
            scale_to_one: bool = True, save_fig_root: None | Path = None, save_fig_prefix: str = "",
            x_label: str = "Box X", y_label: str = "Box Y", title: str = "Occupancy",
            color_bar: bool = True,
    ):
        """Compute this track's occupancy on a ``grid_width x grid_height`` grid and plot it.

        A new figure is created and saved/shown when ``ax`` is None; otherwise the plot goes on
        the given ``fig``/``ax``.
        """
        occupancy = np.zeros((grid_width, grid_height))
        fig, ax, show_plot = self._resolve_fig_ax(fig, ax)
        self.calculate_occupancy(occupancy, pos_to_index, frame_normalize)
        self.plot_occupancy_data(
            occupancy, fig, ax, gaussian_sigma, intensity_limit, scale_to_one, x_label, y_label, title, color_bar,
        )
        if show_plot:
            save_or_show(save_fig_root, save_fig_prefix)

    def plot_motion_index(
            self, fig: plt.Figure | None = None, ax: plt.Axes | None = None,
            save_fig_root: None | Path = None, save_fig_prefix: str = "",
            x_label: str = "Time (min)", y_label: str = "Motion index",
    ):
        """Plot the valid (non-negative) motion index against minutes since the first sample."""
        fig, ax, show_plot = self._resolve_fig_ax(fig, ax)
        valid = self.motion_index >= 0
        times = self.times[valid]
        if len(times):
            ax.plot(times / 60 - self.times[0] / 60, self.motion_index[valid], **SubjectPos._point_marker)
        if x_label:
            ax.set_xlabel(x_label)
        if y_label:
            ax.set_ylabel(y_label)
        if show_plot:
            save_or_show(save_fig_root, save_fig_prefix)

    def plot_speed(
            self, fig: plt.Figure | None = None, ax: plt.Axes | None = None,
            save_fig_root: None | Path = None, save_fig_prefix: str = "",
            x_label: str = "Time (min)", y_label: str = "Speed (px / s)",
    ):
        """Plot the speed from ``calculate_speed`` against minutes since the first sample."""
        fig, ax, show_plot = self._resolve_fig_ax(fig, ax)
        times, speed, _, _ = self.calculate_speed()
        if len(times):
            ax.plot(times / 60 - self.times[0] / 60, speed, **SubjectPos._point_marker)
        if x_label:
            ax.set_xlabel(x_label)
        if y_label:
            ax.set_ylabel(y_label)
        if show_plot:
            save_or_show(save_fig_root, save_fig_prefix)

    def transform_to_categorical_pos(
            self, position_to_side: Callable, categoricals: tuple[str, ...],
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Classify each sample, except the last, with ``position_to_side(x, y)``.

        Only samples with non-negative coordinates are returned. The result is
        ``(times, time_diff, categoricals_index)``. ``time_diff`` is the time until the next
        sample, so the last sample, which has no successor, is excluded. ``categoricals_index`` is
        the position in ``categoricals`` of the name returned by ``position_to_side``.

        :raises KeyError: if ``position_to_side`` returns a name not in ``categoricals``.
        :raises ValueError: if the time stamps decrease.
        """
        track = self.track[:-1, :]
        valid = np.logical_and(track[:, 0] >= 0, track[:, 1] >= 0)
        pos_categoricals_name = [position_to_side(*p) for p in track[valid, :]]
        name_to_index = {name: i for i, name in enumerate(categoricals)}
        categoricals_index = np.array([name_to_index[n] for n in pos_categoricals_name])

        time_diff = self.times[1:] - self.times[:-1]
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if np.any(time_diff < 0):
            raise ValueError(f"Found decreasing time stamps {self}")
        time_diff = time_diff[valid]
        times = self.times[:-1][valid]
        return times, time_diff, categoricals_index

    def plot_categorical_values(
            self, times, categoricals_index, categoricals: tuple[str, ...],
            fig: plt.Figure | None = None, ax: plt.Axes | None = None,
            save_fig_root: None | Path = None, save_fig_prefix: str = "",
    ):
        """Scatter each sample's categorical index against minutes since the first sample."""
        fig, ax, show_plot = self._resolve_fig_ax(fig, ax)
        ax.plot(times / 60 - self.times[0] / 60, categoricals_index, "*", alpha=.3, ms=5)
        ax.set_yticks(np.arange(len(categoricals)), categoricals)
        ax.set_xlabel("Time (min)")
        ax.set_ylabel("Side of box relative to teaball")
        if show_plot:
            save_or_show(save_fig_root, save_fig_prefix)

    def plot_distance_from_point(
            self, point_xy: tuple[int, int], fig: plt.Figure | None = None, ax: plt.Axes | None = None,
            save_fig_root: None | Path = None, save_fig_prefix: str = "",
            x_label: str = "Time (min)", y_label: str = "Distance (px)",
    ):
        """Plot the Euclidean distance of every valid sample from ``point_xy`` over time (minutes)."""
        fig, ax, show_plot = self._resolve_fig_ax(fig, ax)
        valid = np.logical_and(self.track[:, 0] >= 0, self.track[:, 1] >= 0)
        point = np.array(point_xy)[None, :]
        distance = np.sqrt(np.sum(np.square(self.track[valid, :] - point), axis=1))
        times = self.times[valid]
        if len(times):
            ax.plot(times / 60 - self.times[0] / 60, distance, **SubjectPos._point_marker)
        if x_label:
            ax.set_xlabel(x_label)
        if y_label:
            ax.set_ylabel(y_label)
        if show_plot:
            save_or_show(save_fig_root, save_fig_prefix)
def export_experiments_csv(
        experiments: list[Experiment], data_root: Path, occupancy_meas_args, cue_categoricals,
        freezing_index_meas_args, freezing_speed_meas_args, freeze_categoricals, cat_unit, include_chopped: bool
):
    """Write all per-experiment CSV summaries into ``data_root``.

    The files written are:
    - frame counts and durations;
    - box-side occupancy;
    - motion-index and speed freezing (thresholds taken from the ``"threshold"`` entry of each
      options dict);
    - mean motion index and speed.

    With ``include_chopped``, the occupancy, motion-index and speed summaries are also written
    from each experiment's chopped data (``*_pre.csv``).
    """
    Experiment.export_multi_experiment_frames(experiments, data_root / "experiment_metadata.csv")
    Experiment.export_multi_experiment_categorical_percent(
        experiments, data_root / "box_sides.csv", measure="occupancy", measure_options=occupancy_meas_args,
        categoricals=cue_categoricals, unit=cat_unit,
    )
    Experiment.export_multi_experiment_categorical_percent(
        experiments, data_root / f"motion_index_freezing_threshold_{freezing_index_meas_args['threshold']}.csv",
        measure="motion_index_freezing",
        measure_options=freezing_index_meas_args, categoricals=freeze_categoricals, unit=cat_unit,
    )
    Experiment.export_multi_experiment_categorical_percent(
        experiments, data_root / f"speed_freezing_threshold_{freezing_speed_meas_args['threshold']}.csv",
        measure="speed_freezing",
        measure_options=freezing_speed_meas_args,
        categoricals=freeze_categoricals, unit=cat_unit,
    )
    Experiment.export_multi_experiment_motion(experiments, data_root / "motion_index_mean.csv", measure="motion_index")
    Experiment.export_multi_experiment_motion(experiments, data_root / "speed_mean.csv", measure="speed")

    if include_chopped:
        # Use the same categoricals and unit as ``box_sides.csv``. The call previously fell back
        # to the defaults ("Near-side", "Far-side") / "times", so the two files' side columns
        # disagreed.
        Experiment.export_multi_experiment_categorical_percent(
            experiments, data_root / "box_sides_pre.csv", use_chopped_data=True,
            measure="occupancy", measure_options=occupancy_meas_args,
            categoricals=cue_categoricals, unit=cat_unit,
        )
        # The measure is given explicitly so it always matches the file name and the full-stage
        # export above.
        Experiment.export_multi_experiment_motion(
            experiments, data_root / "motion_index_mean_pre.csv", use_chopped_data=True, measure="motion_index",
        )
        Experiment.export_multi_experiment_motion(
            experiments, data_root / "speed_mean_pre.csv", use_chopped_data=True, measure="speed",
        )
def export_single_subject_figures(
        experiments: list[Experiment], figure_root: Path, grid_width, grid_height, occupancy_meas_args,
        cue_point, cat_unit, prefix_pat: str,
):
    """Save every per-subject diagnostic figure under ``figure_root / "subject" / <label>``.

    ``<label>`` is ``prefix_pat`` formatted with the experiment's metadata, and it also suffixes
    every figure's file name. The figures are four occupancy variants, the motion index and speed
    (histograms and time series), the distance from ``cue_point``, and box-side occupancy.
    """
    # (plot_occupancy keyword options, file-name prefix) for each occupancy rendering
    occupancy_variants = (
        ({"scale_to_one": False, "intensity_limit": 1e-5}, "intensity_limit100K"),
        ({"intensity_limit": 1e-6}, "intensity_limit1M"),
        ({"gaussian_sigma": 5}, "occupancy_sigma5"),
        ({"gaussian_sigma": 1}, "occupancy_sigma1"),
    )

    for experiment in tqdm.tqdm(experiments, desc="subject"):
        label = prefix_pat.format(**experiment.metadata)
        out_dir = figure_root / "subject" / label

        for options, name in occupancy_variants:
            experiment.plot_occupancy(
                grid_width, grid_height, **options,
                save_fig_root=out_dir, save_fig_prefix=f"{name}_{label}",
            )

        experiment.plot_motion_index_histogram(
            n_bins=50, hist_range=(0, 3),
            save_fig_root=out_dir, save_fig_prefix=f"motion_index_histogram_{label}",
        )
        experiment.plot_motion_index(
            y_limit=1, save_fig_root=out_dir, save_fig_prefix=f"motion_index_{label}",
        )
        experiment.plot_speed_histogram(
            n_bins=50, hist_range=(0, 800),
            save_fig_root=out_dir, save_fig_prefix=f"speed_histogram_{label}",
        )
        experiment.plot_speed(
            y_limit=200, save_fig_root=out_dir, save_fig_prefix=f"speed_{label}",
        )
        experiment.plot_distance_from_point(
            cue_point, save_fig_root=out_dir, save_fig_prefix=f"distance_{label}",
        )
        experiment.plot_categorical_percent(
            measure="occupancy", measure_options=occupancy_meas_args, unit=cat_unit,
            save_fig_root=out_dir, save_fig_prefix=f"side_{label}", x_label="",
        )
def export_multi_experiment_figures(
        experiments: list[Experiment], figure_root: Path, grid_width, grid_height, occupancy_meas_args,
        cat_unit, label: str, cue_categoricals, freezing_index_meas_args, cue_point,
        freezing_speed_meas_args, freeze_categoricals,
        filter_args: list[dict] | None = None, group_label: str = "", sub_dir: str = "grouped",
):
    """Save the multi-experiment figures for ``experiments`` under ``figure_root / sub_dir``.

    ``label`` is used in titles and file names. ``filter_args`` is a list of metadata dicts, one
    per group, and ``group_label`` is formatted with each group's metadata. The motion, speed and
    freezing figures are always produced. The merged-by-period / merged-by-group freezing figures
    compare groups, so they are only produced when ``filter_args`` is non-empty.

    ``grid_width``, ``grid_height``, ``occupancy_meas_args``, ``cue_categoricals`` and ``cue_point``
    are only used by the figure exports that are currently commented out.
    """
    # Experiment.plot_multi_experiment_occupancy(
    #     experiments, grid_width, grid_height,
    #     title=f"{label} occupancy density", scale_to_one=False,
    #     intensity_limit=1e-5, filter_args=filter_args, group_label=group_label,
    #     save_fig_root=figure_root / sub_dir / "occupancy" / "100K",
    #     save_fig_prefix=f"occupancy_{label}_intensity_limit100K",
    # )
    # Experiment.plot_multi_experiment_occupancy(
    #     experiments, grid_width, grid_height, filter_args=filter_args, group_label=group_label,
    #     title=f"{label} occupancy density", intensity_limit=1e-6,
    #     save_fig_root=figure_root / sub_dir / "occupancy" / "1M",
    #     save_fig_prefix=f"occupancy_{label}_intensity_limit1M",
    # )
    # Experiment.plot_multi_experiment_occupancy(
    #     experiments, grid_width, grid_height, title=f"{label} occupancy", gaussian_sigma=5,
    #     save_fig_root=figure_root / sub_dir / "occupancy" / "sigma5",
    #     save_fig_prefix=f"occupancy_{label}_sigma5",
    #     filter_args=filter_args, group_label=group_label,
    # )
    # Experiment.plot_multi_experiment_occupancy(
    #     experiments, grid_width, grid_height, title=f"{label} occupancy", gaussian_sigma=1,
    #     save_fig_root=figure_root / sub_dir / "occupancy" / "sigma1",
    #     save_fig_prefix=f"occupancy_{label}_sigma1",
    #     filter_args=filter_args, group_label=group_label,
    # )
    Experiment.plot_multi_experiment_motion_index_histogram(
        experiments, title=f"{label} motion index density", n_bins=50, filter_args=filter_args, group_label=group_label,
        save_fig_root=figure_root / sub_dir / "motion_index_histogram", hist_range=(0, 3),
        save_fig_prefix=f"motion_index_histogram_{label}",
    )
    Experiment.plot_multi_experiment_motion(
        experiments, y_limit=1, title=f"{label} motion index", measure="motion_index",
        save_fig_root=figure_root / sub_dir / "motion_index",
        save_fig_prefix=f"motion_index_{label}",
        filter_args=filter_args, group_label=group_label,
    )
    Experiment.plot_multi_experiment_speed_histogram(
        experiments, title=f"{label} speed density", n_bins=50, filter_args=filter_args, group_label=group_label,
        hist_range=(0, 800),
        save_fig_root=figure_root / sub_dir / "speed_histogram",
        save_fig_prefix=f"speed_histogram_{label}",
    )
    Experiment.plot_multi_experiment_motion(
        experiments, y_limit=200, title=f"{label} speed", measure="speed",
        filter_args=filter_args, group_label=group_label,
        save_fig_root=figure_root / sub_dir / "speed",
        save_fig_prefix=f"speed_{label}",
    )
    # Experiment.plot_multi_experiment_distance_from_point(
    #     experiments, cue_point, title=f"{label} distance from cue",
    #     filter_args=filter_args, group_label=group_label,
    #     save_fig_root=figure_root / sub_dir / "distance", save_fig_prefix=f"distance_{label}",
    # )
    #
    # Experiment.plot_multi_experiment_categorical_percent(
    #     experiments, title=f"Side of box duration", measure="occupancy", measure_options=occupancy_meas_args,
    #     categoricals=cue_categoricals, unit=cat_unit,
    #     filter_args=filter_args, group_label=group_label, x_label="",
    #     save_fig_root=figure_root / sub_dir / "side" / "side",
    #     save_fig_prefix=f"side_{label}",
    # )
    Experiment.plot_multi_experiment_categorical_percent(
        experiments, title=f"Motion index freezing duration. Threshold={freezing_index_meas_args['threshold']}",
        measure="motion_index_freezing",
        measure_options=freezing_index_meas_args,
        categoricals=freeze_categoricals, unit=cat_unit,
        filter_args=filter_args, group_label=group_label, x_label="",
        save_fig_root=figure_root / sub_dir / "motion_index_freezing" / "freezing",
        save_fig_prefix=f"freezing_{label}",
    )
    Experiment.plot_multi_experiment_categorical_percent(
        experiments, title=f"Speed freezing duration. Threshold={freezing_speed_meas_args['threshold']}",
        measure="speed_freezing",
        measure_options=freezing_speed_meas_args,
        categoricals=freeze_categoricals, unit=cat_unit,
        filter_args=filter_args, group_label=group_label, x_label="",
        save_fig_root=figure_root / sub_dir / "speed_freezing" / "freezing",
        save_fig_prefix=f"freezing_{label}",
    )

    # The merged figures below compare groups and call ``len(filter_args)``. Without groups (the
    # ungrouped exports pass ``None``) they raised TypeError, so skip them, as this function
    # originally did.
    if not filter_args:
        return

    # Experiment.plot_multi_experiment_merged_by_period_categorical_percent(
    #     experiments, title=f"Side of box duration", measure="occupancy", measure_options=occupancy_meas_args,
    #     categoricals=cue_categoricals, x_label="", unit="times",
    #     filter_args=filter_args, group_label=group_label,
    #     save_fig_root=figure_root / sub_dir / "side" / "side_merged_by_period",
    #     save_fig_prefix=f"side_merged_by_period_{label}",
    # )
    # Experiment.plot_multi_experiment_merged_by_group_categorical_percent(
    #     experiments, title=f"Total time spent in cue side", show_titles=False, show_legend=True,
    #     measure="occupancy", measure_options=occupancy_meas_args, unit="times",
    #     categoricals=cue_categoricals, x_label="Cue side",
    #     filter_args=filter_args, only_categorical="Cue-side", show_xlabel=False,
    #     save_fig_root=figure_root / sub_dir / "side" / "side_merged_by_group",
    #     save_fig_prefix=f"side_merged_by_group_{label}", group_label=group_label,
    # )
    Experiment.plot_multi_experiment_merged_by_period_categorical_percent(
        experiments, title=f"Motion index freezing duration. Threshold={freezing_index_meas_args['threshold']}",
        measure="motion_index_freezing",
        measure_options=freezing_index_meas_args,
        categoricals=freeze_categoricals, x_label="", unit="times",
        filter_args=filter_args, group_label=group_label,
        save_fig_root=figure_root / sub_dir / "motion_index_freezing" /
        "freezing_merged_by_period",
        save_fig_prefix=f"freezing_merged_by_period_{label}",
    )
    Experiment.plot_multi_experiment_merged_by_group_categorical_percent(
        experiments,
        title=f"Total time spent freezing (motion index). Threshold={freezing_index_meas_args['threshold']}",
        show_titles=False, show_legend=True,
        measure="motion_index_freezing", measure_options=freezing_index_meas_args, unit="times",
        categoricals=freeze_categoricals, x_label="Freezing",
        filter_args=filter_args, only_categorical="Freezing", show_xlabel=False,
        save_fig_root=figure_root / sub_dir / "motion_index_freezing" /
        "freezing_merged_by_group",
        save_fig_prefix=f"freezing_merged_by_group_{label}", group_label=group_label,
    )
    Experiment.plot_multi_experiment_merged_by_period_categorical_percent(
        experiments, title=f"Speed freezing duration. Threshold={freezing_speed_meas_args['threshold']}",
        measure="speed_freezing",
        measure_options=freezing_speed_meas_args,
        categoricals=freeze_categoricals, x_label="", unit="times",
        filter_args=filter_args, group_label=group_label,
        save_fig_root=figure_root / sub_dir / "speed_freezing" /
        "freezing_merged_by_period",
        save_fig_prefix=f"freezing_merged_by_period_{label}",
    )
    Experiment.plot_multi_experiment_merged_by_group_categorical_percent(
        experiments,
        title=f"Total time spent freezing (speed). Threshold={freezing_speed_meas_args['threshold']}",
        show_titles=False, show_legend=True,
        measure="speed_freezing", measure_options=freezing_speed_meas_args, unit="times",
        categoricals=freeze_categoricals, x_label="Freezing",
        filter_args=filter_args, only_categorical="Freezing", show_xlabel=False,
        save_fig_root=figure_root / sub_dir / "speed_freezing" /
        "freezing_merged_by_group",
        save_fig_prefix=f"freezing_merged_by_group_{label}", group_label=group_label,
    )
def run_yidan():
    """Run the full export pipeline for the 2025 Yidan dataset (batches 3, 5, 6 and 7).

    Loads every batch's experiments and position tracks, then writes the CSV summaries. It then
    writes the grouped figures for each condition/strain/exposure combination, followed by the
    multi-group single-figure layouts. Paths and analysis parameters are hard-coded below.
    """
    SubjectPos._point_marker = {"marker": ".", "markersize": 2, "linestyle": "", "markeredgecolor": 'none'}
    Experiment.triplet_name = "Pre Trial", "Trial", "Post Trial"

    habituation_chunks = [(k * 30, (k + 1) * 30) for k in range(10)]
    metadata_columns = ["date", "subject", "strain", "condition", "cell_label", "exposure", "sex"]
    filename_fmt = "{date}_{subject}_top_tracked.csv"
    occupancy_meas_args = {"split_horizontally": True}
    freezing_index_meas_args = {"threshold": 0, "freeze_name": "Freezing"}
    freezing_speed_meas_args = {"threshold": 0, "freeze_name": "Freezing"}
    cue_categoricals = "Cue-side", "Far-side"
    freeze_categoricals = "Freezing", "Non-freezing"
    cue_point = ("left", "bottom")

    root = Path(r'D:\code_data\yidan_2025')
    figure_root = root / "results" / "figures"
    data_root = root / "results" / "data"

    experiments = []
    for batch in (3, 5, 6, 7):
        batch_root = root / f"batch{batch}"
        batch_experiments = Experiment.parse_experiment_spec_csv(
            root / f"behavioral data timestamp - b{batch} timestamps.csv", metadata_columns, filename_fmt,
            "#{subject} ({exposure} - {date})"
        )
        experiments.extend(batch_experiments)
        for experiment in batch_experiments:
            # Batch 3 reads its box metadata from the batch's json folder; the others use a fixed box.
            if batch == 3:
                experiment.parse_box_metadata(batch_root / "json")
            else:
                experiment.set_box_metadata(200, 132, 784, 486, (960, 600))
            experiment.read_pos_track(
                batch_root, hab_offset=3 * 60, pre_duration=2 * 60, trial_offset=0, trial_duration=2 * 60,
                post_offset=0, post_duration=2 * 60, frame_rate=41.52,
            )
            experiment.chop_pos_track(hab_segments=habituation_chunks)

    grid_width, grid_height = Experiment.enlarge_canvas(experiments)
    pprint.pprint(Experiment.count_motion_index_range(experiments))
    export_experiments_csv(
        experiments, data_root, occupancy_meas_args, cue_categoricals, freezing_index_meas_args,
        freezing_speed_meas_args, freeze_categoricals, "times", True,
    )
    # Per-subject figures are disabled; re-enable with:
    # export_single_subject_figures(
    #     experiments, figure_root, grid_width, grid_height, occupancy_meas_args, ("left", "bottom"), "times",
    #     "{subject}_exposure-{exposure}_{date}",
    # )

    def export_figures(subset, label, filter_args, group_label, sub_dir):
        # Every figure export shares the same measurement settings; only the subset/grouping vary.
        export_multi_experiment_figures(
            subset, figure_root, grid_width, grid_height, occupancy_meas_args, "times", label,
            cue_categoricals, freezing_index_meas_args, cue_point, freezing_speed_meas_args,
            freeze_categoricals, filter_args, group_label, sub_dir,
        )

    conditions = "Blank", "TMT", "2MBA", "IAMM"
    strains = "TRAP1",
    exposures = "1", "2"
    sexes = "M", "F"

    # One ungrouped figure set per condition x strain x exposure.
    for condition in conditions:
        for strain in strains:
            for expo in exposures:
                subset = Experiment.filter(experiments, condition=condition, exposure=expo, strain=strain)
                export_figures(subset, f"{condition}-{expo}-{strain}", None, "", "grouped")

    # Single-figure layouts: (fixed metadata selecting the subset, groups compared, sub-directory).
    by_condition = [{"condition": c} for c in conditions]
    by_exposure = [{"exposure": e} for e in exposures]
    by_exposure_and_sex = [{"exposure": e, "sex": s} for e in exposures for s in sexes]
    layouts = [
        *(({"exposure": e, "strain": s}, by_condition, "by_exposure") for s in strains for e in exposures),
        *(({"condition": c, "strain": s}, by_exposure, "by_condition") for c in conditions for s in strains),
        *(
            ({"condition": c, "strain": s}, by_exposure_and_sex, "by_condition_w_sex")
            for c in conditions for s in strains
        ),
    ]
    for fixed, groups, sub_dir in layouts:
        subset = Experiment.filter(experiments, **fixed)
        group_label = "$\\bf" + "\n".join("{{{" + key + "}}}" for key in groups[0].keys()) + "$\n\n"
        export_figures(subset, "-".join(fixed.values()), groups, group_label, f"grouped_single_figure/{sub_dir}")
def run_cynthia(root: None | Path | str = None):
    """Load Cynthia's tracking experiments and export the grouped figures.

    Parses the experiment timestamp spec CSV, reads each experiment's position
    track, prints the motion-index range summary, and then exports two kinds
    of figures under ``root / "results" / "figures"``:

    * ``"grouped"`` figures, one per condition x cue x stage combination.
    * ``"grouped_single_figure/<sub_dir>"`` figures, which place several
      filtered sub-groups on a single figure.

    Side effects: this function overwrites the class-level attributes
    ``SubjectPos._point_marker`` and ``Experiment.triplet_name``. Other code in
    the module that reads them off the class will see the new values.

    :param root: Root directory holding the timestamp CSV, the ``tracking``
        folder and the ``results`` output folder. ``None`` (the default) uses
        the original hard-coded data directory.
    """
    if root is None:
        root = Path(r'D:\code_data\cynthia_sp_2025')
    root = Path(root)

    SubjectPos._point_marker = {"marker": ".", "markersize": 1.5, "linestyle": ""}
    # Eight consecutive windows of 30 (presumably seconds). They are only used
    # by the commented-out ``chop_pos_track`` call below.
    chopped_pre_times = [(i * 30, (i + 1) * 30) for i in range(8)]
    metadata = ["date", "subject", "stage", "condition", "cue", "sex"]
    filename_fmt = "{subject}_{stage}_{date}_tracked.csv"
    grid_width, grid_height = 641, 480
    triplet_name = "Pre Trial", "Trial", "Post Trial"
    # Kept for any code that reads the names off the class itself.
    Experiment.triplet_name = triplet_name
    occupancy_meas_args = {"split_horizontally": False}
    freezing_index_meas_args = {"threshold": 0, "freeze_name": "Freezing"}
    freezing_speed_meas_args = {"threshold": 0, "freeze_name": "Freezing"}
    cue_categoricals = "Cue-side", "Far-side"
    freeze_categoricals = "Freezing", "Non-freezing"

    csv_experiment_times = root / "Behavioral timestamps Cynthia.csv"
    tracking_data_root = root / "tracking"
    figure_root = root / "results" / "figures"
    # Only used by the commented-out ``export_experiments_csv`` call below.
    data_root = root / "results" / "data"

    experiments = Experiment.parse_experiment_spec_csv(
        csv_experiment_times, metadata, filename_fmt, "#{subject} ({stage})",
    )
    for experiment in experiments:
        # ``triplet_name`` is a dataclass field. Its default is bound into the
        # generated ``Experiment.__init__`` when the class is created, so the
        # class-attribute assignment above does not reach instances built with
        # the default. Set it on each instance explicitly.
        experiment.triplet_name = triplet_name
        experiment.set_box_metadata(0, 0, grid_width, grid_height, (grid_width, grid_height))
        experiment.read_pos_track(
            tracking_data_root, frame_rate=15, pre_duration=-28, post_duration=28,
        )
        # experiment.chop_pos_track(hab_segments=chopped_pre_times)
    pprint.pprint(Experiment.count_motion_index_range(experiments))
    # export_experiments_csv(
    #     experiments, data_root, occupancy_meas_args, cue_categoricals, freezing_index_meas_args,
    #     freezing_speed_meas_args, freeze_categoricals, "times", False,
    # )
    # export_single_subject_figures(
    #     experiments, figure_root, grid_width, grid_height, occupancy_meas_args, ("center", "top"), "times",
    #     "{subject}_{stage}",
    # )

    conditions = "control", "cued", "backward"
    cues = "odor", "tone"
    stages = "train", "test"

    # One "grouped" figure set per condition x cue x stage combination.
    cond_x_cue_x_stage = [(c, s, st) for c in conditions for s in cues for st in stages]
    for condition, cue, stage in cond_x_cue_x_stage:
        experiments_ = Experiment.filter(experiments, condition=condition, cue=cue, stage=stage)
        label = f"{condition}-{cue}-{stage}"
        export_multi_experiment_figures(
            experiments_, figure_root, grid_width, grid_height, occupancy_meas_args, "times", label,
            cue_categoricals, freezing_index_meas_args, ("center", "top"), freezing_speed_meas_args,
            freeze_categoricals, None, "", "grouped",
        )

    # Each item is (filter selecting the figure's experiments, list of filter
    # dicts passed on as ``filter_args``, output sub-directory name).
    stage_filters = [{"stage": s} for s in stages]
    items = [({"condition": c, "cue": s}, stage_filters, "train_vs_test") for c in conditions for s in cues]
    condition_filters = [{"condition": c} for c in conditions]
    items.extend(
        [({"cue": c, "stage": s}, condition_filters, "by_condition") for c in cues for s in stages]
    )
    for groups, group_filters, sub_dir in items:
        experiments_ = Experiment.filter(experiments, **groups)
        # Bold mathtext header with one line per filter key. The tripled braces
        # are presumably there so a later ``str.format`` fills in each
        # sub-group's value -- verify in export_multi_experiment_figures.
        group_label = "$\\bf" + "\n".join("{{{" + f + "}}}" for f in group_filters[0].keys()) + "$\n\n"
        label = "-".join(groups.values())
        export_multi_experiment_figures(
            experiments_, figure_root, grid_width, grid_height, occupancy_meas_args, "times", label,
            cue_categoricals, freezing_index_meas_args, ("center", "top"), freezing_speed_meas_args,
            freeze_categoricals, group_filters, group_label, f"grouped_single_figure/{sub_dir}"
        )
if __name__ == "__main__":
    # Script entry point: only the Cynthia analysis runs by default.
    # To run the other analysis instead, call ``run_yidan()`` here.
    run_cynthia()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment