Export results for teaball experiments: sniffing, occupancy, etc.
from dataclasses import dataclass, field
import pprint
from typing import Callable, Union, Literal, Any
from pathlib import Path
import csv
from scipy.signal import decimate
from copy import deepcopy
import math
import tqdm
import numpy as np
from collections import defaultdict
from functools import partial
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.ndimage import gaussian_filter
import json
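# Pipeline overview: parse per-subject tracking CSVs, align box geometry across
# experiments, and export occupancy, motion-index, speed and freezing summaries as
# CSV files and matplotlib figures.
# MEASURE_TYPE selects how frames are classified into categories:
# "occupancy" classifies by which side of the box the subject is on, while
# "motion_index_freezing" and "speed_freezing" classify freezing vs. non-freezing by
# thresholding the motion index or the computed speed.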
MEASURE_TYPE = Literal["occupancy", "motion_index_freezing", "speed_freezing"]
def save_or_show(save_fig_root: None | Path = None, save_fig_prefix: str = "", width_inch: int = 8, height_inch: int = 6):
if save_fig_root:
save_fig_root.mkdir(parents=True, exist_ok=True)
fig = plt.gcf()
fig.set_size_inches(width_inch, height_inch)
fig.tight_layout()
fig.savefig(
save_fig_root / f"{save_fig_prefix}.png", bbox_inches='tight',
dpi=300
)
plt.close()
else:
plt.tight_layout()
plt.show()
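# Experiment bundles one recording session: the pre/trial/post time windows (in
# seconds), the metadata used to locate the tracking CSV, the box geometry parsed
# from the box-annotation JSON (or set manually), and the SubjectPos slices
# extracted for each period.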
@dataclass
class Experiment:
pre_start: float
pre_end: float
trial_start: float
trial_end: float
post_start: float
post_end: float
filename_fmt: str
title_root: str
metadata: dict[str, Any] = field(default_factory=dict)
triplet_name: tuple[str, str, str] = "Habituation", "Trial", "Post Trial"
pos_data: "SubjectPos" = field(default=None, init=False, repr=False)
pre_pos_data: "SubjectPos" = field(default=None, init=False, repr=False)
trial_pos_data: "SubjectPos" = field(default=None, init=False, repr=False)
post_pos_data: "SubjectPos" = field(default=None, init=False, repr=False)
chopped_pos_data: dict[str, "SubjectPos"] = field(default_factory=dict, init=False, repr=False)
box_center: tuple[int, int] = field(default=None, init=False)
box_size: tuple[int, int] = field(default=None, init=False)
image_size: tuple[int, int] = field(default=None, init=False)
image_offset: tuple[int, int] = field(default=None, init=False)
# how much we need to offset our image data so our box center aligns with the mean box center of all boxes
enlarged_image_size: tuple[int, int] = field(default=None, init=False, repr=False)
# size of the largest image/canvas we need to fit all experiments, including their offsets
largest_box_size: tuple[int, int] = field(default=None, init=False, repr=False)
# size of the largest box of all the experiments
def get_header_id_columns(self) -> list[str]:
return list(self.metadata.keys())
def get_id_columns(self) -> list:
return list(self.metadata.values())
def pos_path_filename(self, data_root: Path) -> Path:
return data_root / self.filename_fmt.format(**self.metadata)
def box_metadata_filename(self, data_root: Path) -> Path:
return data_root / "{date}_{subject}_top_000000000.json".format(**self.metadata)
@classmethod
def parse_experiment_spec_csv(
cls, filename: Path, metadata: list[str], filename_fmt: str,
title_root_fmt: str = "#{subject} ({date})",
) -> list["Experiment"]:
experiments = []
with open(filename, "r") as fh:
reader = csv.reader(fh)
header = next(reader)
for row in reader:
metadata_values = {m: row[i] for i, m in enumerate(metadata)}
times = []
for t in row[len(metadata):]:
if ":" in t:
minutes, sec = map(int, t.split(":"))
times.append(minutes * 60 + sec)
else:
times.append(int(t))
experiment = cls(
metadata=metadata_values, filename_fmt=filename_fmt,
pre_start=times[0], pre_end=times[1],
trial_start=times[2], trial_end=times[3],
post_start=times[4], post_end=times[5], title_root=title_root_fmt.format(**metadata_values),
)
experiments.append(experiment)
return experiments
def parse_box_metadata(self, data_root: Path, box_names: tuple[str, ...] = ("farside", "nearside")):
filename = self.box_metadata_filename(data_root)
with open(filename, "r") as fh:
data = json.load(fh)
shape = None
for s in data["shapes"]:
for name in box_names:
if s["label"] == name:
shape = s
if shape is None:
raise ValueError(f"Cannot find {box_names} in the json file. {filename}")
points = np.array(shape["points"])
min_x, min_y = np.min(points, axis=0)
max_x, max_y = np.max(points, axis=0)
w = max_x - min_x
h = max_y - min_y
self.box_center = int(min_x + w / 2), int(min_y + h / 2)
self.box_size = int(w), int(h)
self.image_size = int(data["imageWidth"]), int(data["imageHeight"])
def set_box_metadata(
self, box_left: int, box_top: int, box_right: int, box_bottom: int, image_size: tuple[int, int]
):
self.box_center = int((box_right + box_left) / 2), int((box_bottom + box_top) / 2)
self.box_size = box_right - box_left, box_bottom - box_top
self.image_size = image_size
self.image_offset = 0, 0
self.enlarged_image_size = self.image_size
self.largest_box_size = self.box_size
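# enlarge_canvas aligns all experiments onto one shared canvas: each image is first
# centered within the largest image size, then shifted so every box center lands on
# the mean box center; the returned canvas is the largest image grown by the spread
# of those per-experiment offsets. The offsets and the largest box size are stored on
# each Experiment for the later grid mapping.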
@classmethod
def enlarge_canvas(cls, experiments: list["Experiment"]) -> tuple[int, int]:
box_centers = np.array([e.box_center for e in experiments])
box_sizes = np.array([e.box_size for e in experiments])
image_sizes = np.array([e.image_size for e in experiments])
max_image = np.max(image_sizes, axis=0)
image_centers = np.floor(image_sizes / 2)
adjusted_image_centers = np.floor(image_centers + (max_image[None, :] - image_sizes) / 2)
adjusted_image_offsets = adjusted_image_centers - image_centers
adjusted_box_centers = box_centers + adjusted_image_offsets
mean_adjusted_box_centers = np.floor(np.mean(adjusted_box_centers, axis=0))
aligned_box_offsets = mean_adjusted_box_centers - adjusted_box_centers
# how much we need to offset our image data so our box center aligns with the mean box center of all boxes
total_image_offset = adjusted_image_offsets + aligned_box_offsets
min_offset = np.min(total_image_offset, axis=0)
max_offset = np.max(total_image_offset, axis=0)
final_image_size = max_image + max_offset - min_offset
max_box_size = np.max(box_sizes, axis=0)
for i, experiment in enumerate(experiments):
experiment.image_offset = tuple(map(int, total_image_offset[i, :]))
experiment.enlarged_image_size = tuple(map(int, final_image_size))
experiment.largest_box_size = tuple(map(int, max_box_size))
return tuple(map(int, final_image_size))
def position_to_side(
self, x: int, y: int, split_horizontally: bool = True,
categoricals: tuple[str, ...] = ("Near-side", "Far-side"),
):
cx, cy = self.box_center
if split_horizontally:
i = 0 if x < cx else 1
else:
i = 0 if y < cy else 1
return categoricals[i]
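# convert_to_categoricals maps each tracked frame of a period onto a category index:
# for "occupancy" it uses position_to_side; for the freezing measures it thresholds the
# motion index or speed (values <= measure_options["threshold"] count as freezing).
# Returns (times, per-frame time deltas, category index per frame).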
def convert_to_categoricals(
self, data_group: Union["SubjectPos", str], measure: MEASURE_TYPE,
categoricals: tuple[str, ...] = ("Near-side", "Far-side"),
measure_options: dict | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
if measure_options is None:
measure_options = {}
if isinstance(data_group, str):
data_group = getattr(self, data_group)
if measure == "occupancy":
f = partial(
self.position_to_side,
split_horizontally=measure_options.get("split_horizontally", True), categoricals=categoricals,
)
times, time_diff, categoricals_index = data_group.transform_to_categorical_pos(f, categoricals)
elif "freezing" in measure:
freeze_i = categoricals.index(measure_options["freeze_name"])
non_freeze_i = 1 if freeze_i == 0 else 0
if measure == "motion_index_freezing":
index = data_group.motion_index[:-1]
valid = index >= 0
times = data_group.times
time_diff = times[1:] - times[:-1]
assert np.all(time_diff >= 0)
index = index[valid]
time_diff = time_diff[valid]
times = times[:-1][valid]
else:
assert measure == "speed_freezing"
times, index, _, time_diff = data_group.calculate_speed()
categoricals_index = np.empty(len(times))
freezing = index <= measure_options["threshold"]
categoricals_index[freezing] = freeze_i
categoricals_index[np.logical_not(freezing)] = non_freeze_i
else:
raise ValueError(f"Unknown measure {measure}")
return times, time_diff, categoricals_index
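# Map a raw (x, y) pixel position onto the shared occupancy grid: scale this box up to
# the largest box (about the box center), apply the canvas offset computed by
# enlarge_canvas, then clamp to the grid bounds.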
def position_to_grid_index(self, x, y, grid_width: int, grid_height: int) -> tuple[int, int]:
offset_x, offset_y = self.image_offset
cx, cy = self.box_center
bw, bh = self.box_size
largest_bw, largest_bh = self.largest_box_size
width_scale = largest_bw / bw
height_scale = largest_bh / bh
x = (x - cx) * width_scale + cx + offset_x
y = (y - cy) * height_scale + cy + offset_y
x = int(min(max(round(x), 0), grid_width - 1))
y = int(min(max(round(y), 0), grid_height - 1))
return x, y
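# read_pos_track loads the full track and slices out the pre/trial/post windows.
# A positive *_duration keeps at most that many seconds from the (offset) start of the
# period; a negative *_duration keeps at most that many seconds at the end of it.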
def read_pos_track(
self, data_root: Path, hab_offset: float = 0, pre_duration: float | None = None, trial_offset: float = 0,
trial_duration: float | None = None, post_offset: float = 0, post_duration: float | None = None,
frame_rate: float = 30.
) -> None:
self.pos_data = SubjectPos.parse_csv_track(self.pos_path_filename(data_root), frame_rate=frame_rate)
start = self.pre_start + hab_offset
end = self.pre_end
if pre_duration and pre_duration > 0:
end = min(end, start + pre_duration)
elif pre_duration and pre_duration < 0:
start = max(start, end + pre_duration)
self.pre_pos_data = self.pos_data.extract_range(start, end)
start = self.trial_start + trial_offset
end = self.trial_end
if trial_duration and trial_duration > 0:
end = min(end, start + trial_duration)
elif trial_duration and trial_duration < 0:
start = max(start, end + trial_duration)
self.trial_pos_data = self.pos_data.extract_range(start, end)
start = self.post_start + post_offset
end = self.post_end
if post_duration and post_duration > 0:
end = min(end, start + post_duration)
elif post_duration and post_duration < 0:
start = max(start, end + post_duration)
self.post_pos_data = self.pos_data.extract_range(start, end)
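# chop_pos_track extracts extra labelled segments from each period; segment bounds are
# given in seconds relative to the period start. Keys are e.g. "Pre_0", "Pre_30" when a
# period has more than one segment, otherwise just the period name.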
def chop_pos_track(
self, hab_segments: list[tuple[float, float]]=(), trial_segments: list[tuple[float, float]]=(),
post_segments: list[tuple[float, float]]=(),
) -> None:
data = self.chopped_pos_data = {}
for name in ("Pre", "Trial", "Post"):
match name:
case "Pre":
segments = hab_segments
ts = self.pre_start
te = self.pre_end
case "Trial":
segments = trial_segments
ts = self.trial_start
te = self.trial_end
case "Post":
segments = post_segments
ts = self.post_start
te = self.post_end
case _:
assert False
for start, end in segments:
item = self.pos_data.extract_range(ts + start, min(ts + end, te))
key = f"{name}_{int(start)}" if len(segments) > 1 else name
data[key] = item
@classmethod
def _get_data_items(cls, obj: Union["Experiment", None] = None) -> list[tuple[Union["SubjectPos", str], str]]:
a, b, c = cls.triplet_name
if obj is None:
res = [
("pre_pos_data", a),
("trial_pos_data", b),
("post_pos_data", c),
]
else:
res = [
(obj.pre_pos_data, a),
(obj.trial_pos_data, b),
(obj.post_pos_data, c),
]
return res
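# Yields one subplot cell per (filter group, period) pair, together with the subset of
# experiments matching that group's filter; used by the plot_multi_experiment_* methods.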
@classmethod
def _iter_periods_and_groups(cls, experiments, periods, axs, filter_args: list[dict] | None):
n_groups = 1 if filter_args is None else len(filter_args)
it = iter(axs.flatten())
for i, filter_group in enumerate(filter_args or [None, ]):
if n_groups > 1:
experiments_ = cls.filter(experiments, **filter_group)
else:
experiments_ = experiments
for j, (data_name, t) in enumerate(periods):
ax = next(it)
yield experiments_, i, filter_group, j, data_name, t, ax
def plot_occupancy(
self, grid_width: int, grid_height: int,
gaussian_sigma: float = 0, intensity_limit: float = 0, frame_normalize: bool = True,
scale_to_one: bool = True, save_fig_root: None | Path = None, save_fig_prefix: str = "",
):
fig, axs = plt.subplots(1, 3, sharey=True, sharex=True)
for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())):
data.plot_occupancy(
grid_width, grid_height, pos_to_index=self.position_to_grid_index, fig=fig, ax=ax,
gaussian_sigma=gaussian_sigma,
intensity_limit=intensity_limit, frame_normalize=frame_normalize, scale_to_one=scale_to_one,
color_bar=not i,
x_label="X (pixels)",
y_label="Y (pixels)" if not i else "",
title="",
)
ax.set_title(f"$\\bf{{{title}}}$")
label = self.title_root.format(**self.metadata)
fig.suptitle(f"{label} occupancy density")
save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_occupancy(
cls, experiments: list["Experiment"], grid_width: int, grid_height: int,
gaussian_sigma: float = 0, intensity_limit: float = 0, title: str = "Subjects occupancy density",
frame_normalize: bool = True, experiment_normalize: bool = True, scale_to_one: bool = True,
save_fig_root: None | Path = None, save_fig_prefix: str = "",
filter_args: list[dict] | None = None, group_label: str = "",
):
n_groups = 1 if filter_args is None else len(filter_args)
periods = cls._get_data_items()
n_periods = len(periods)
fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=True)
for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
experiments, periods, axs, filter_args):
occupancy = np.zeros((grid_width, grid_height))
for experiment in experiments_:
getattr(experiment, data_name).calculate_occupancy(
occupancy, experiment.position_to_grid_index, frame_normalize,
)
if experiment_normalize:
occupancy /= len(experiments_)
group = ""
if n_groups > 1:
group = group_label.format(**filter_group)
SubjectPos.plot_occupancy_data(
occupancy, fig, ax, gaussian_sigma, intensity_limit, scale_to_one,
color_bar=not i and not j,
title="",
x_label="X (pixels)" if i == n_groups - 1 else "",
y_label=f"{group}Y (pixels)" if not j else "",
)
if not i:
ax.set_title(f"$\\bf{{{t}}}$")
fig.suptitle(title)
save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def export_multi_experiment_motion(
cls, experiments: list["Experiment"], filename: Path,
measure: Literal["motion_index", "speed"] = "motion_index", use_chopped_data: bool = False,
) -> None:
header = experiments[0].get_header_id_columns() + ["Stage", "Motion (mean)"]
filename.parent.mkdir(parents=True, exist_ok=True)
with open(filename, "w", newline="") as fh:
writer = csv.writer(fh, delimiter=",")
writer.writerow(header)
for experiment in experiments:
if use_chopped_data:
stages = list(experiment.chopped_pos_data.keys())
stages_data = [experiment.chopped_pos_data[stage] for stage in stages]
else:
stages = "Pre", "Trial", "Post"
stages_data = experiment.pre_pos_data, experiment.trial_pos_data, experiment.post_pos_data
id_columns = experiment.get_id_columns()
for stage, data in zip(stages, stages_data):
if measure == "motion_index":
motion = data.motion_index
mean_motion = np.mean(motion[motion >= 0])
elif measure == "speed":
_, motion, _, _ = data.calculate_speed()
mean_motion = np.mean(motion)
line = id_columns + [stage, mean_motion]
writer.writerow(map(str, line))
def plot_motion_index(
self, y_limit: float | None = None, save_fig_root: None | Path = None, save_fig_prefix: str = ""
):
fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())):
data.plot_motion_index(
fig, ax, **{"y_label": ""} if i else {},
)
if y_limit is not None:
ax.set_ylim(0, y_limit)
ax.set_title(f"$\\bf{{{title}}}$")
label = self.title_root.format(**self.metadata)
fig.suptitle(f"{label} motion index")
save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_motion(
cls, experiments: list["Experiment"], y_limit: float | None = None, title: str = "Subjects motion index",
save_fig_root: None | Path = None, save_fig_prefix: str = "",
filter_args: list[dict] | None = None, group_label: str = "",
measure: Literal["motion_index", "speed"] = "motion_index",
):
n_groups = 1 if filter_args is None else len(filter_args)
periods = cls._get_data_items()
n_periods = len(periods)
fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=False)
for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
experiments, periods, axs, filter_args):
for experiment in experiments_:
group = ""
if n_groups > 1:
group = group_label.format(**filter_group)
kwargs = {"y_label": ""}
if i != n_groups - 1:
kwargs["x_label"] = ""
if not j:
label = "Motion index" if measure == "motion_index" else "Speed (px / s)"
kwargs["y_label"] = f"{group}{label}"
data = getattr(experiment, data_name)
getattr(data, f"plot_{measure}")(fig, ax, **kwargs)
if y_limit is not None:
ax.set_ylim(0, y_limit)
if not i:
ax.set_title(f"$\\bf{{{t}}}$")
fig.suptitle(title)
save_or_show(save_fig_root, save_fig_prefix)
def plot_motion_index_histogram(
self, n_bins=100, save_fig_root: None | Path = None, save_fig_prefix: str = "",
hist_range: tuple[float, float] = (0, 3),
):
fig, axs = plt.subplots(1, 3, sharey=True, sharex=True)
for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())):
data = data.motion_index
data = data[data >= 0]
ax.hist(data, bins=n_bins, density=True, log=True)
ax.set_xlim(*hist_range)
ax.set_xlabel("Motion index")
if not i:
ax.set_ylabel("Density")
ax.set_title(f"$\\bf{{{title}}}$")
label = self.title_root.format(**self.metadata)
fig.suptitle(f"{label} motion index density")
save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_motion_index_histogram(
cls, experiments: list["Experiment"], n_bins=100, title: str = "Subjects motion index",
save_fig_root: None | Path = None, save_fig_prefix: str = "", hist_range: tuple[float, float]=(0, 3),
filter_args: list[dict] | None = None, group_label: str = "",
):
n_groups = 1 if filter_args is None else len(filter_args)
periods = cls._get_data_items()
n_periods = len(periods)
fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=True)
for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
experiments, periods, axs, filter_args):
items = []
for experiment in experiments_:
data = getattr(experiment, data_name).motion_index
items.append(data[data >= 0])
if items:
ax.hist(np.concatenate(items), bins=n_bins, density=True, log=True)
ax.set_xlim(*hist_range)
group = ""
if n_groups > 1:
group = group_label.format(**filter_group)
if i == n_groups - 1:
ax.set_xlabel("Motion index")
if not j:
ax.set_ylabel(f"{group}Density")
if not i:
ax.set_title(f"$\\bf{{{t}}}$")
fig.suptitle(title)
save_or_show(save_fig_root, save_fig_prefix)
def plot_speed(
self, y_limit: float | None = None, save_fig_root: None | Path = None,
save_fig_prefix: str = ""
):
fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())):
data.plot_speed(
fig, ax, **{"y_label": ""} if i else {},
)
if y_limit is not None:
ax.set_ylim(0, y_limit)
ax.set_title(f"$\\bf{{{title}}}$")
label = self.title_root.format(**self.metadata)
fig.suptitle(f"{label} speed")
save_or_show(save_fig_root, save_fig_prefix)
def plot_speed_histogram(
self, n_bins=100, save_fig_root: None | Path = None, save_fig_prefix: str = "",
hist_range: tuple[float, float] = (0, 200),
):
fig, axs = plt.subplots(1, 3, sharey=True, sharex=True)
for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())):
_, data, _, _ = data.calculate_speed()
ax.hist(data, bins=n_bins, density=True, log=True)
ax.set_xlim(*hist_range)
ax.set_xlabel("Speed (px / s)")
if not i:
ax.set_ylabel("Density")
ax.set_title(f"$\\bf{{{title}}}$")
label = self.title_root.format(**self.metadata)
fig.suptitle(f"{label} Speed density")
save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_speed_histogram(
cls, experiments: list["Experiment"], n_bins=100,
title: str = "Subjects Speed (px / s)", hist_range: tuple[float, float]=(0, 200),
save_fig_root: None | Path = None, save_fig_prefix: str = "",
filter_args: list[dict] | None = None, group_label: str = ""
):
n_groups = 1 if filter_args is None else len(filter_args)
periods = cls._get_data_items()
n_periods = len(periods)
fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=True)
for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
experiments, periods, axs, filter_args):
items = []
for experiment in experiments_:
_, data, _, _ = getattr(experiment, data_name).calculate_speed()
items.append(data)
if items:
    ax.hist(np.concatenate(items), bins=n_bins, density=True, log=True)
ax.set_xlim(*hist_range)
group = ""
if n_groups > 1:
group = group_label.format(**filter_group)
if i == n_groups - 1:
ax.set_xlabel("Speed (px / s)")
if not j:
ax.set_ylabel(f"{group}Density")
if not i:
ax.set_title(f"$\\bf{{{t}}}$")
fig.suptitle(f"{title}")
save_or_show(save_fig_root, save_fig_prefix)
def plot_categorical_values(
self, measure: MEASURE_TYPE = "occupancy", measure_options: dict = None,
categoricals: tuple[str, ...] = ("Near-side", "Far-side"), save_fig_root: None | Path = None, save_fig_prefix: str = "",
):
fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
for (data, title), ax in zip(self._get_data_items(self), axs.flatten()):
times, _, categoricals_index = self.convert_to_categoricals(data, measure, categoricals, measure_options)
data.plot_categorical_values(times, categoricals_index, categoricals, fig, ax)
ax.set_title(title)
label = self.title_root.format(**self.metadata)
fig.suptitle(f"Subject {label}")
save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_side_of_box(
cls, experiments: list["Experiment"], measure: MEASURE_TYPE = "occupancy", measure_options: dict = None,
categoricals: tuple[str, ...] = ("Near-side", "Far-side"), title: str = "Subjects motion",
save_fig_root: None | Path = None, save_fig_prefix: str = "",
):
fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
for (data_name, t), ax in zip(cls._get_data_items(), axs.flatten()):
for experiment in experiments:
times, _, categoricals_index = experiment.convert_to_categoricals(
data_name, measure, categoricals, measure_options
)
getattr(experiment, data_name).plot_categorical_values(times, categoricals_index, categoricals, fig, ax)
ax.set_title(t)
fig.suptitle(title)
save_or_show(save_fig_root, save_fig_prefix)
def _descriptor_to_point(self, point: tuple[str, str]):
cx, cy = self.box_center
bw, bh = self.box_size
match point[0]:
case "left":
x = cx - bw / 2
case "right":
x = cx + bw / 2
case "center":
x = cx
case _:
raise ValueError(f"Can't recognize {point[0]}")
match point[1]:
case "bottom":
y = cy + bh / 2
case "top":
y = cy - bh / 2
case "center":
y = cy
case _:
raise ValueError(f"Can't recognize {point[1]}")
return x, y
def plot_distance_from_point(
self, point: tuple[str, str], save_fig_root: None | Path = None,
save_fig_prefix: str = "", post_title: str = "distance from teaball corner",
):
fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
point_xy = self._descriptor_to_point(point)
for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())):
data.plot_distance_from_point(
point_xy, fig, ax, **{"y_label": ""} if i else {},
)
ax.set_title(f"$\\bf{{{title}}}$")
label = self.title_root.format(**self.metadata)
fig.suptitle(f"{label} {post_title}")
save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_distance_from_point(
cls, experiments: list["Experiment"], point: tuple[str, str],
title: str = "Subjects distance from teaball corner",
save_fig_root: None | Path = None, save_fig_prefix: str = "",
filter_args: list[dict] | None = None, group_label: str = "",
):
n_groups = 1 if filter_args is None else len(filter_args)
periods = cls._get_data_items()
n_periods = len(periods)
fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=False)
for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
experiments, periods, axs, filter_args):
for experiment in experiments_:
group = ""
if n_groups > 1:
group = group_label.format(**filter_group)
kwargs = {"y_label": ""}
if i != n_groups - 1:
kwargs["x_label"] = ""
if not j:
kwargs["y_label"] = f"{group}Distance (px)"
point_xy = experiment._descriptor_to_point(point)
getattr(experiment, data_name).plot_distance_from_point(point_xy, fig, ax, **kwargs)
if not i:
ax.set_title(f"$\\bf{{{t}}}$")
fig.suptitle(title)
save_or_show(save_fig_root, save_fig_prefix)
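# With unit="times" the per-category value is the summed per-frame time deltas
# (seconds spent in each category); with unit="percent" it is the fraction of frames in
# each category scaled to 100. Returns the mean across experiments plus the
# per-experiment values.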
@classmethod
def _get_categorical_percent(
cls, categoricals_index: list[np.ndarray], times_diff: list[np.ndarray],
sorted_categoricals: tuple[str, ...], unit: Literal["percent", "times"]="times",
):
if unit == "times":
percents = np.empty((len(categoricals_index), len(sorted_categoricals)))
for e, (exp_categoricals_index, time_diff) in enumerate(zip(categoricals_index, times_diff)):
if len(exp_categoricals_index) != len(time_diff):
raise ValueError("Provided time diff and categorical index are not the same length")
for i in range(len(sorted_categoricals)):
percents[e, i] = np.sum(time_diff[exp_categoricals_index == i])
elif unit == "percent":
counts = np.array([
[np.sum(arr == i) for i in range(len(sorted_categoricals))]
for arr in categoricals_index
])
percents = counts / np.sum(counts, axis=1, keepdims=True) * 100
else:
raise ValueError(f"Unknown unit {unit}")
mean_prop = np.mean(percents, axis=0).squeeze()
return mean_prop, [prop.squeeze() for prop in percents]
@classmethod
def _plot_percents(
cls, categoricals_index: list[np.ndarray], times_diff: list[np.ndarray],
sorted_categoricals: tuple[str, ...],
fig: plt.Figure, ax: plt.Axes,
x_label: str = "Teaball side", y_label: str = "% time spent",
unit: Literal["percent", "times"] = "times",
):
mean_prop, percents = cls._get_categorical_percent(categoricals_index, times_diff, sorted_categoricals, unit)
ax.bar(np.arange(len(sorted_categoricals)), mean_prop, tick_label=sorted_categoricals)
if len(categoricals_index) > 1:
for prop in percents:
ax.plot(np.arange(len(sorted_categoricals)), prop.squeeze(), "k.")
if x_label:
ax.set_xlabel(x_label)
if y_label:
ax.set_ylabel(y_label)
def plot_categorical_percent(
self, measure: MEASURE_TYPE = "occupancy", measure_options: dict = None,
categoricals: tuple[str, ...] = ("Near-side", "Far-side"),
save_fig_root: None | Path = None, save_fig_prefix: str = "",
unit: Literal["percent", "times"] = "times", x_label = "Odor side", action_label: str = "spent in side",
):
fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
y_label = "% time spent" if unit == "percent" else "total time spent (s)"
for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())):
_, time_diff, categoricals_index = self.convert_to_categoricals(data, measure, categoricals, measure_options)
self._plot_percents(
[categoricals_index], [time_diff], categoricals, fig, ax, x_label, "" if i else y_label, unit,
)
ax.set_title(f"$\\bf{{{title}}}$")
label = self.title_root.format(**self.metadata)
tp = "total" if unit == "times" else "%"
fig.suptitle(f"{label} {tp} time {action_label}")
save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_categorical_percent(
cls, experiments: list["Experiment"], measure: MEASURE_TYPE = "occupancy",
measure_options: dict = None,
categoricals: tuple[str, ...] = ("Near-side", "Far-side"), title: str = "Subjects motion",
save_fig_root: None | Path = None, save_fig_prefix: str = "",
filter_args: list[dict] | None = None, group_label: str = "",
unit: Literal["percent", "times"] = "times", x_label: str = "Odor side",
):
n_groups = 1 if filter_args is None else len(filter_args)
periods = cls._get_data_items()
n_periods = len(periods)
fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=False)
for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
experiments, periods, axs, filter_args):
categoricals_index_all = []
times_all = []
for experiment in experiments_:
_, times_diff, categoricals_index = experiment.convert_to_categoricals(
data_name, measure, categoricals, measure_options
)
categoricals_index_all.append(categoricals_index)
times_all.append(times_diff)
group = ""
if n_groups > 1:
group = group_label.format(**filter_group)
kwargs = {"y_label": "", "x_label": x_label}
if i != n_groups - 1:
kwargs["x_label"] = ""
if not j:
label = "% time spent" if unit == "percent" else "total time spent (s)"
kwargs["y_label"] = f"{group}{label}"
cls._plot_percents(categoricals_index_all, times_all, categoricals, fig, ax, **kwargs, unit=unit)
if not i:
ax.set_title(f"$\\bf{{{t}}}$")
fig.suptitle(title)
save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def export_multi_experiment_categorical_percent(
cls, experiments: list["Experiment"], filename: Path,
measure: MEASURE_TYPE = "occupancy", measure_options: dict = None,
categoricals: tuple[str, ...] = ("Near-side", "Far-side"),
unit: Literal["percent", "times"] = "times",
use_chopped_data: bool = False,
) -> None:
sign = "sec" if unit == "times" else "%"
header = experiments[0].get_header_id_columns() + ["Stage"] + [
f"{label} {sign}" for label in categoricals
]
filename.parent.mkdir(parents=True, exist_ok=True)
with open(filename, "w", newline="") as fh:
writer = csv.writer(fh, delimiter=",")
writer.writerow(header)
for experiment in experiments:
if use_chopped_data:
stages = list(experiment.chopped_pos_data.keys())
stages_data = [experiment.chopped_pos_data[stage] for stage in stages]
else:
stages = "Pre", "Trial", "Post"
stages_data = experiment.pre_pos_data, experiment.trial_pos_data, experiment.post_pos_data
id_columns = experiment.get_id_columns()
for stage, data in zip(stages, stages_data):
_, times_diff, categoricals_index = experiment.convert_to_categoricals(
data, measure, categoricals, measure_options
)
mean_prop, _ = cls._get_categorical_percent([categoricals_index], [times_diff], categoricals, unit)
line = id_columns + [stage, *mean_prop]
writer.writerow(map(str, line))
@classmethod
def plot_multi_experiment_merged_by_period_categorical_percent(
cls, experiments: list["Experiment"], filter_args: list[dict],
measure: MEASURE_TYPE = "occupancy", measure_options: dict = None,
categoricals: tuple[str, ...] = ("Near-side", "Far-side"), title: str = "Subjects motion",
save_fig_root: None | Path = None, save_fig_prefix: str = "", group_label: str = "",
unit: Literal["percent", "times"] = "times", x_label: str = "Teaball side",
):
n_groups = len(filter_args)
periods = cls._get_data_items()
n_periods = len(periods)
fig, axs = plt.subplots(1, n_periods, sharey=True, sharex=True)
axes = list(axs.flatten())
bar_width = 1 / (n_groups + 1)
n_categoricals = len(categoricals)
for i, filter_group in enumerate(filter_args):
experiments_ = cls.filter(experiments, **filter_group)
for j, (data_name, t) in enumerate(periods):
categoricals_index_all = []
times_all = []
for experiment in experiments_:
_, times_diff, categoricals_index = experiment.convert_to_categoricals(
data_name, measure, categoricals, measure_options
)
categoricals_index_all.append(categoricals_index)
times_all.append(times_diff)
ax = axes[j]
mean_prop, percents = cls._get_categorical_percent(
categoricals_index_all, times_all, categoricals, unit
)
ax.bar(
i * bar_width + np.arange(n_categoricals), mean_prop, bar_width,
label=group_label.format(**filter_group),
)
for prop in percents:
ax.plot(i * bar_width + np.arange(n_categoricals), prop, "k.")
ax.set_xlabel(x_label)
if not j:
label = "% time spent" if unit == "percent" else "total time spent (s)"
ax.set_ylabel(f"{label}")
if not i:
ax.set_title(f"$\\bf{{{t}}}$")
for ax in axes:
ax.set_xticks(np.arange(n_categoricals) + (1 - bar_width) / 2 - bar_width / 2, categoricals)
ax.set_xlim(-bar_width, n_categoricals + bar_width)
handles, labels = axes[-1].get_legend_handles_labels()
fig.legend(handles, labels, ncols=n_groups, bbox_to_anchor=(0, 0), loc=2)
fig.suptitle(title)
save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_merged_by_group_categorical_percent(
cls, experiments: list["Experiment"], filter_args: list[dict],
measure: MEASURE_TYPE = "occupancy", measure_options: dict = None,
categoricals: tuple[str, ...] = ("Teaball-side", "Far-side"), title: str = "Subjects motion",
save_fig_root: None | Path = None, save_fig_prefix: str = "", show_titles: bool = True,
show_legend: bool = True, only_categorical: str | None = None, show_xlabel: bool = True,
unit: Literal["percent", "times"] = "times", group_label: str = "", x_label: str = "Teaball side"
):
n_groups = len(filter_args)
periods = cls._get_data_items()
n_periods = len(periods)
fig, axs = plt.subplots(n_groups, 1, sharey=True, sharex=True)
axes = list(axs.flatten())
bar_width = 1 / (n_periods + 1)
n_categoricals = len(categoricals)
categoricals_s = slice(0, n_categoricals)
n_categoricals_used = n_categoricals
categoricals_used = categoricals
if only_categorical:
i = categoricals.index(only_categorical)
categoricals_s = slice(i, i + 1)
n_categoricals_used = 1
categoricals_used = [only_categorical]
for i, filter_group in enumerate(filter_args):
group = ""
if n_groups > 1:
group = group_label.format(**filter_group)
experiments_ = cls.filter(experiments, **filter_group)
for j, (data_name, t) in enumerate(periods):
categoricals_index_all = []
times_all = []
for experiment in experiments_:
_, times_diff, categoricals_index = experiment.convert_to_categoricals(
data_name, measure, categoricals, measure_options
)
categoricals_index_all.append(categoricals_index)
times_all.append(times_diff)
ax = axes[i]
mean_prop, percents = cls._get_categorical_percent(
categoricals_index_all, times_all, categoricals, unit
)
ax.bar(j * bar_width + np.arange(n_categoricals_used), mean_prop[categoricals_s], bar_width, label=t)
for prop in percents:
ax.plot(j * bar_width + np.arange(n_categoricals_used), prop[categoricals_s], "k.")
if i == n_groups - 1 and show_xlabel:
ax.set_xlabel(x_label)
label = "% time spent" if unit == "percent" else "total time spent (s)"
ax.set_ylabel(f"{group}{label}")
if not j and show_titles:
ax.set_title(group_label.format(**filter_group))
for ax in axes:
ax.set_xticks(np.arange(n_categoricals_used) + (1 - bar_width) / 2 - bar_width / 2, categoricals_used)
ax.set_xlim(-bar_width, n_categoricals_used - bar_width)
if show_legend:
handles, labels = axes[-1].get_legend_handles_labels()
fig.legend(handles, labels, ncols=1, bbox_to_anchor=(0, 0), loc=2)
fig.suptitle(title)
save_or_show(save_fig_root, save_fig_prefix, width_inch=3)
@classmethod
def export_multi_experiment_frames(cls, experiments: list["Experiment"], filename: Path) -> None:
header = experiments[0].get_header_id_columns() + [
"total_frames", "pre_frames", "trial_frames", "post_frames", "total_duration", "pre_duration",
"trial_duration", "post_duration",
]
filename.parent.mkdir(parents=True, exist_ok=True)
with open(filename, "w", newline="") as fh:
writer = csv.writer(fh, delimiter=",")
writer.writerow(header)
for experiment in experiments:
data_obj = (
experiment.pos_data, experiment.pre_pos_data, experiment.trial_pos_data, experiment.post_pos_data
)
line = experiment.get_id_columns() + [
*(len(obj.times) for obj in data_obj),
*((obj.times[-1] - obj.times[0]) if len(obj.times) else 0 for obj in data_obj),
]
writer.writerow(map(str, line))
@classmethod
def filter(cls, experiments: list["Experiment"], **metadata):
for key, value in metadata.items():
if value is not None:
experiments = [t for t in experiments if t.metadata[key] == value]
return experiments
@classmethod
def count_motion_index_range(cls, experiments: list["Experiment"]) -> dict[float, int]:
items = []
for exp in experiments:
motion = exp.pos_data.motion_index
items.append(motion[motion >= 0])
data = np.concatenate(items)
n = len(data)
counts = {}
zero_mask = data == 0
counts[0] = np.sum(zero_mask).item()
data = data[np.logical_not(zero_mask)]
if not len(data):
return counts
min_val = np.min(data)
assert min_val > 0
min_order = math.log10(min_val)
for order in range(int(math.ceil(-min_order)) + 1):
val = math.pow(10, -order)
mask = data >= val
counts[val] = np.sum(mask).item()
data = data[np.logical_not(mask)]
assert not len(data)
assert sum(counts.values()) == n
return counts
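# SubjectPos holds the tracking data of a single recording: per-frame timestamps, the
# (x, y) centroid track and the motion index. Negative coordinates or motion-index
# values appear to mark frames where the subject was not detected and are skipped by
# the analysis methods.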
@dataclass
class SubjectPos:
filename: str | Path
times: np.ndarray
track: np.ndarray = field(repr=False)
motion_index: np.ndarray = field(repr=False)
_downsampled_key: Any = field(default=None, init=False, repr=False)
_downsampled_values: np.ndarray = field(default=None, init=False, repr=False)
_point_marker: dict = field(default_factory=dict, init=False, repr=False)
@property
def min_x(self):
return np.min(self.track[:, 0])
@property
def min_y(self):
return np.min(self.track[:, 1])
@property
def max_x(self):
return np.max(self.track[:, 0])
@property
def max_y(self):
return np.max(self.track[:, 1])
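# parse_csv_track expects rows of (frame, instance, cx, cy, motion_index, t1, t2) with a
# "real_timestamp_sec" column last or second to last; timestamps are sanity-checked
# against frame_number / frame_rate and must be strictly increasing. The returned times
# are derived from the frame number and frame_rate.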
@classmethod
def parse_csv_track(cls, filename: Path, subject_name: str = "mouse", frame_rate: float = 30.) -> "SubjectPos":
track = []
times = []
index = []
print(f"Reading {filename}")
with open(filename, "r") as fh:
reader = csv.reader(fh)
header = list(next(reader))
last_t = None
last_frame_t = None
for i, row in enumerate(reader):
# if i % 2:
# continue
frame, instance, cx, cy, index_val, t1, t2 = row
# frame = int(frame) // 2
if header[-1] == "real_timestamp_sec":
t = t2
elif header[-2] == "real_timestamp_sec":
t = t1
else:
raise ValueError("Cannot find the real_timestamp_sec column")
if instance != subject_name:
continue
track.append((int(float(cx)), int(float(cy))))
index.append(float(index_val))
t = float(t)
frame_t = int(frame) / frame_rate
if not (frame_t * 0.95 <= t <= frame_t * 1.05):
raise ValueError(f"{filename} frame rate of {frame_rate} does not match time stamp at \"{row}\"")
if last_t is not None and last_t >= t or last_frame_t is not None and last_frame_t >= frame_t:
raise ValueError(f"{filename} sequential times are not increasing \"{row}\"")
last_t = t
last_frame_t = frame_t
times.append(frame_t)
times = np.array(times)
return SubjectPos(
filename=filename, times=times, track=np.array(track), motion_index=np.array(index))
def extract_range(self, t_start: float | None = None, t_end: float | None = None):
if t_start is None:
t_start = self.times[0]
if t_end is None:
t_end = self.times[-1] + 1
i_s = np.sum(self.times < t_start)
i_e = np.sum(self.times <= t_end)
return SubjectPos(
filename=self.filename, times=self.times[i_s:i_e], track=self.track[i_s:i_e],
motion_index=self.motion_index[i_s:i_e]
)
def calculate_occupancy(
self, occupancy: np.ndarray, pos_to_index: Callable | None = None, frame_normalize: bool = True,
) -> None:
n = self.track.shape[0]
if not n:
return
grid_width, grid_height = occupancy.shape
frame_proportion = 1
if frame_normalize:
frame_proportion = 1 / n
for i in range(n):
if self.track[i, 0] < 0 or self.track[i, 1] < 0:
continue
if pos_to_index is None:
x, y = self.track[i, :]
else:
x, y = pos_to_index(*self.track[i, :], grid_width, grid_height)
# clamp to valid grid indices
x = int(min(x, grid_width - 1))
y = int(min(y, grid_height - 1))
occupancy[x, y] += frame_proportion
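# calculate_speed drops frames with negative (missing) positions and returns, for each
# pair of consecutive valid samples: the later timestamp, speed (px/s), distance (px)
# and elapsed time (s).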
def calculate_speed(self, downsample_factor: int = 8) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
# key = downsample_factor,
#
# if key != self._downsampled_key or self._downsampled_values is None:
# if len(self.times) <= 27:
# return np.empty(0), np.empty(0)
#
# values = np.concatenate([np.arange(len(self.times))[:, None], self.times[:, None], self.track], axis=1)
# decimated = decimate(values, downsample_factor, axis=0)
#
# self._downsampled_key = key
# self._downsampled_values = decimated
#
# indices = self._downsampled_values[:, 0]
# times = self._downsampled_values[:, 1]
# track = self._downsampled_values[:, 2:]
track = self.track
valid = np.logical_and(track[:, 0] >= 0, track[:, 1] >= 0)
track = track[valid, :]
times = self.times[valid]
elapsed = times[1:] - times[:-1]
if np.any(elapsed <= 0):
raise ValueError(f"Found sequential time stamps whose difference is not positive {self}")
dist = np.sqrt(np.sum(np.square(track[1:, :] - track[:-1, :]), axis=1, keepdims=False))
speed = dist / elapsed
return times[1:], speed, dist, elapsed
@classmethod
def plot_occupancy_data(
cls, occupancy: np.ndarray, fig: plt.Figure, ax: plt.Axes,
gaussian_sigma: float = 0, intensity_limit: float = 0, scale_to_one: bool = True,
x_label: str = "Box X", y_label: str = "Box Y", title: str = "Occupancy",
color_bar: bool = True,
):
if gaussian_sigma:
occupancy = gaussian_filter(occupancy, gaussian_sigma)
if scale_to_one:
max_val = occupancy.max()
if max_val:
occupancy /= max_val
else:
occupancy[:] = 0
im = ax.imshow(
occupancy.T, aspect="auto", origin="upper", cmap="viridis", interpolation="sinc",
interpolation_stage="data", vmax=intensity_limit or None
)
if x_label:
ax.set_xlabel(x_label)
if y_label:
ax.set_ylabel(y_label)
if title:
ax.set_title(title)
if color_bar:
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size='5%', pad=0.05)
fig.colorbar(im, cax=cax, orientation='vertical')
def plot_occupancy(
self, grid_width: int, grid_height: int,
pos_to_index: Callable | None = None, fig: plt.Figure | None = None, ax: plt.Axes | None = None,
gaussian_sigma: float = 0, intensity_limit: float = 0, frame_normalize: bool = True,
scale_to_one: bool = True, save_fig_root: None | Path = None, save_fig_prefix: str = "",
x_label: str = "Box X", y_label: str = "Box Y", title: str = "Occupancy",
color_bar: bool = True,
):
occupancy = np.zeros((grid_width, grid_height))
show_plot = ax is None
if ax is None:
assert fig is None
fig, ax = plt.subplots()
else:
assert fig is not None
self.calculate_occupancy(occupancy, pos_to_index, frame_normalize)
self.plot_occupancy_data(
occupancy, fig, ax, gaussian_sigma, intensity_limit, scale_to_one, x_label, y_label, title, color_bar,
)
if show_plot:
save_or_show(save_fig_root, save_fig_prefix)
def plot_motion_index(
self, fig: plt.Figure | None = None, ax: plt.Axes | None = None,
save_fig_root: None | Path = None, save_fig_prefix: str = "",
x_label: str = "Time (min)", y_label: str = "Motion index",
):
show_plot = ax is None
if ax is None:
assert fig is None
fig, ax = plt.subplots()
else:
assert fig is not None
valid = self.motion_index >= 0
times = self.times[valid]
if len(times):
ax.plot(times / 60 - self.times[0] / 60, self.motion_index[valid], **SubjectPos._point_marker)
if x_label:
ax.set_xlabel(x_label)
if y_label:
ax.set_ylabel(y_label)
if show_plot:
save_or_show(save_fig_root, save_fig_prefix)
def plot_speed(
self, fig: plt.Figure | None = None, ax: plt.Axes | None = None,
save_fig_root: None | Path = None, save_fig_prefix: str = "",
x_label: str = "Time (min)", y_label: str = "Speed (px / s)",
):
show_plot = ax is None
if ax is None:
assert fig is None
fig, ax = plt.subplots()
else:
assert fig is not None
times, speed, _, _ = self.calculate_speed()
if len(times):
ax.plot(times / 60 - self.times[0] / 60, speed, **SubjectPos._point_marker)
if x_label:
ax.set_xlabel(x_label)
if y_label:
ax.set_ylabel(y_label)
if show_plot:
save_or_show(save_fig_root, save_fig_prefix)
def transform_to_categorical_pos(
self, position_to_side: Callable, categoricals: tuple[str, ...],
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
track = self.track[:-1, :]
valid = np.logical_and(track[:, 0] >= 0, track[:, 1] >= 0)
pos_categoricals_name = [position_to_side(*p) for p in track[valid, :]]
categoricals = {name: i for i, name in enumerate(categoricals)}
categoricals_index = np.array([categoricals[n] for n in pos_categoricals_name])
time_diff = self.times[1:] - self.times[:-1]
assert np.all(time_diff >= 0)
time_diff = time_diff[valid]
times = self.times[:-1][valid]
return times, time_diff, categoricals_index
def plot_categorical_values(
self, times, categoricals_index, categoricals: tuple[str, ...],
fig: plt.Figure | None = None, ax: plt.Axes | None = None,
save_fig_root: None | Path = None, save_fig_prefix: str = "",
):
show_plot = ax is None
if ax is None:
assert fig is None
fig, ax = plt.subplots()
else:
assert fig is not None
ax.plot(times / 60 - self.times[0] / 60, categoricals_index, "*", alpha=.3, ms=5)
ax.set_yticks(np.arange(len(categoricals)), categoricals)
ax.set_xlabel("Time (min)")
ax.set_ylabel("Side of box relative to teaball")
if show_plot:
save_or_show(save_fig_root, save_fig_prefix)
def plot_distance_from_point(
self, point_xy: tuple[int, int], fig: plt.Figure | None = None, ax: plt.Axes | None = None,
save_fig_root: None | Path = None, save_fig_prefix: str = "",
x_label: str = "Time (min)", y_label: str = "Distance (px)",
):
show_plot = ax is None
if ax is None:
assert fig is None
fig, ax = plt.subplots()
else:
assert fig is not None
valid = np.logical_and(self.track[:, 0] >= 0, self.track[:, 1] >= 0)
point = np.array(point_xy)[None, :]
distance = np.sqrt(np.sum(np.square(self.track[valid, :] - point), axis=1))
times = self.times[valid]
if len(times):
ax.plot(times / 60 - self.times[0] / 60, distance, **SubjectPos._point_marker)
if x_label:
ax.set_xlabel(x_label)
if y_label:
ax.set_ylabel(y_label)
if show_plot:
save_or_show(save_fig_root, save_fig_prefix)
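# export_experiments_csv writes the summary CSV exports: frame counts and durations,
# side-of-box occupancy, freezing by motion index and by speed, and mean motion/speed,
# optionally also for the chopped pre-period segments.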
def export_experiments_csv(
experiments: list[Experiment], data_root: Path, occupancy_meas_args, cue_categoricals,
freezing_index_meas_args, freezing_speed_meas_args, freeze_categoricals, cat_unit, include_chopped: bool
):
Experiment.export_multi_experiment_frames(experiments, data_root / "experiment_metadata.csv")
Experiment.export_multi_experiment_categorical_percent(
experiments, data_root / "box_sides.csv", measure="occupancy", measure_options=occupancy_meas_args,
categoricals=cue_categoricals, unit=cat_unit,
)
Experiment.export_multi_experiment_categorical_percent(
experiments, data_root / f"motion_index_freezing_threshold_{freezing_index_meas_args['threshold']}.csv",
measure="motion_index_freezing",
measure_options=freezing_index_meas_args, categoricals=freeze_categoricals, unit=cat_unit,
)
Experiment.export_multi_experiment_categorical_percent(
experiments, data_root / f"speed_freezing_threshold_{freezing_speed_meas_args['threshold']}.csv",
measure="speed_freezing",
measure_options=freezing_speed_meas_args,
categoricals=freeze_categoricals, unit=cat_unit,
)
Experiment.export_multi_experiment_motion(experiments, data_root / "motion_index_mean.csv", measure="motion_index")
Experiment.export_multi_experiment_motion(experiments, data_root / "speed_mean.csv", measure="speed")
if include_chopped:
Experiment.export_multi_experiment_categorical_percent(
experiments, data_root / "box_sides_pre.csv", use_chopped_data=True,
measure="occupancy", measure_options=occupancy_meas_args,
)
Experiment.export_multi_experiment_motion(
experiments, data_root / "motion_index_mean_pre.csv", use_chopped_data=True,
)
Experiment.export_multi_experiment_motion(
experiments, data_root / "speed_mean_pre.csv", use_chopped_data=True, measure="speed",
)
def export_single_subject_figures(
experiments: list[Experiment], figure_root: Path, grid_width, grid_height, occupancy_meas_args,
cue_point, cat_unit, prefix_pat: str,
):
for experiment in tqdm.tqdm(experiments, desc="subject"):
label = prefix_pat.format(**experiment.metadata)
experiment.plot_occupancy(
grid_width, grid_height, scale_to_one=False, intensity_limit=1e-5,
save_fig_root=figure_root / "subject" / label,
save_fig_prefix=f"intensity_limit100K_{label}",
)
experiment.plot_occupancy(
grid_width, grid_height, intensity_limit=1e-6,
save_fig_root=figure_root / "subject" / label,
save_fig_prefix=f"intensity_limit1M_{label}",
)
experiment.plot_occupancy(
grid_width, grid_height, gaussian_sigma=5,
save_fig_root=figure_root / "subject" / label,
save_fig_prefix=f"occupancy_sigma5_{label}",
)
experiment.plot_occupancy(
grid_width, grid_height, gaussian_sigma=1,
save_fig_root=figure_root / "subject" / label,
save_fig_prefix=f"occupancy_sigma1_{label}",
)
experiment.plot_motion_index_histogram(
n_bins=50, hist_range=(0, 3),
save_fig_root=figure_root / "subject" / label,
save_fig_prefix=f"motion_index_histogram_{label}",
)
experiment.plot_motion_index(
y_limit=1,
save_fig_root=figure_root / "subject" / label,
save_fig_prefix=f"motion_index_{label}",
)
experiment.plot_speed_histogram(
n_bins=50, hist_range=(0, 800),
save_fig_root=figure_root / "subject" / label,
save_fig_prefix=f"speed_histogram_{label}",
)
experiment.plot_speed(
y_limit=200,
save_fig_root=figure_root / "subject" / label,
save_fig_prefix=f"speed_{label}",
)
experiment.plot_distance_from_point(
cue_point,
save_fig_root=figure_root / "subject" / label,
save_fig_prefix=f"distance_{label}",
)
experiment.plot_categorical_percent(
measure="occupancy", measure_options=occupancy_meas_args, unit=cat_unit,
save_fig_root=figure_root / "subject" / label,
save_fig_prefix=f"side_{label}", x_label="",
)
def export_multi_experiment_figures(
experiments: list[Experiment], figure_root: Path, grid_width, grid_height, occupancy_meas_args,
cat_unit, label: str, cue_categoricals, freezing_index_meas_args, cue_point,
freezing_speed_meas_args, freeze_categoricals,
filter_args: list[dict] | None = None, group_label: str = "", sub_dir: str = "grouped",
):
# Experiment.plot_multi_experiment_occupancy(
# experiments, grid_width, grid_height,
# title=f"{label} occupancy density", scale_to_one=False,
# intensity_limit=1e-5, filter_args=filter_args, group_label=group_label,
# save_fig_root=figure_root / sub_dir / "occupancy" / "100K",
# save_fig_prefix=f"occupancy_{label}_intensity_limit100K",
# )
# Experiment.plot_multi_experiment_occupancy(
# experiments, grid_width, grid_height, filter_args=filter_args, group_label=group_label,
# title=f"{label} occupancy density", intensity_limit=1e-6,
# save_fig_root=figure_root / sub_dir / "occupancy" / "1M",
# save_fig_prefix=f"occupancy_{label}_intensity_limit1M",
# )
# Experiment.plot_multi_experiment_occupancy(
# experiments, grid_width, grid_height, title=f"{label} occupancy", gaussian_sigma=5,
# save_fig_root=figure_root / sub_dir / "occupancy" / "sigma5",
# save_fig_prefix=f"occupancy_{label}_sigma5",
# filter_args=filter_args, group_label=group_label,
# )
# Experiment.plot_multi_experiment_occupancy(
# experiments, grid_width, grid_height, title=f"{label} occupancy", gaussian_sigma=1,
# save_fig_root=figure_root / sub_dir / "occupancy" / "sigma1",
# save_fig_prefix=f"occupancy_{label}_sigma1",
# filter_args=filter_args, group_label=group_label,
# )
Experiment.plot_multi_experiment_motion_index_histogram(
experiments, title=f"{label} motion index density", n_bins=50, filter_args=filter_args, group_label=group_label,
save_fig_root=figure_root / sub_dir / "motion_index_histogram", hist_range=(0, 3),
save_fig_prefix=f"motion_index_histogram_{label}",
)
Experiment.plot_multi_experiment_motion(
experiments, y_limit=1, title=f"{label} motion index", measure="motion_index",
save_fig_root=figure_root / sub_dir / "motion_index",
save_fig_prefix=f"motion_index_{label}",
filter_args=filter_args, group_label=group_label,
)
Experiment.plot_multi_experiment_speed_histogram(
experiments, title=f"{label} speed density", n_bins=50, filter_args=filter_args, group_label=group_label,
hist_range=(0, 800),
save_fig_root=figure_root / sub_dir / "speed_histogram",
save_fig_prefix=f"speed_histogram_{label}",
)
Experiment.plot_multi_experiment_motion(
experiments, y_limit=200, title=f"{label} speed", measure="speed",
filter_args=filter_args, group_label=group_label,
save_fig_root=figure_root / sub_dir / "speed",
save_fig_prefix=f"speed_{label}",
)
# Experiment.plot_multi_experiment_distance_from_point(
# experiments, cue_point, title=f"{label} distance from cue",
# filter_args=filter_args, group_label=group_label,
# save_fig_root=figure_root / sub_dir / "distance", save_fig_prefix=f"distance_{label}",
# )
#
# Experiment.plot_multi_experiment_categorical_percent(
# experiments, title=f"Side of box duration", measure="occupancy", measure_options=occupancy_meas_args,
# categoricals=cue_categoricals, unit=cat_unit,
# filter_args=filter_args, group_label=group_label, x_label="",
# save_fig_root=figure_root / sub_dir / "side" / "side",
# save_fig_prefix=f"side_{label}",
# )
Experiment.plot_multi_experiment_categorical_percent(
experiments, title=f"Motion index freezing duration. Threshold={freezing_index_meas_args['threshold']}",
measure="motion_index_freezing",
measure_options=freezing_index_meas_args,
categoricals=freeze_categoricals, unit=cat_unit,
filter_args=filter_args, group_label=group_label, x_label="",
save_fig_root=figure_root / sub_dir / "motion_index_freezing" / "freezing",
save_fig_prefix=f"freezing_{label}",
)
Experiment.plot_multi_experiment_categorical_percent(
experiments, title=f"Speed freezing duration. Threshold={freezing_speed_meas_args['threshold']}",
measure="speed_freezing",
measure_options=freezing_speed_meas_args,
categoricals=freeze_categoricals, unit=cat_unit,
filter_args=filter_args, group_label=group_label, x_label="",
save_fig_root=figure_root / sub_dir / "speed_freezing" / "freezing",
save_fig_prefix=f"freezing_{label}",
)
#
# if not filter_args:
# return
#
# Experiment.plot_multi_experiment_merged_by_period_categorical_percent(
# experiments, title=f"Side of box duration", measure="occupancy", measure_options=occupancy_meas_args,
# categoricals=cue_categoricals, x_label="", unit="times",
# filter_args=filter_args, group_label=group_label,
# save_fig_root=figure_root / sub_dir / "side" / "side_merged_by_period",
# save_fig_prefix=f"side_merged_by_period_{label}",
# )
# Experiment.plot_multi_experiment_merged_by_group_categorical_percent(
# experiments, title=f"Total time spent in cue side", show_titles=False, show_legend=True,
# measure="occupancy", measure_options=occupancy_meas_args, unit="times",
# categoricals=cue_categoricals, x_label="Cue side",
# filter_args=filter_args, only_categorical="Cue-side", show_xlabel=False,
# save_fig_root=figure_root / sub_dir / "side" / "side_merged_by_group",
# save_fig_prefix=f"side_merged_by_group_{label}", group_label=group_label,
# )
Experiment.plot_multi_experiment_merged_by_period_categorical_percent(
experiments, title=f"Motion index freezing duration. Threshold={freezing_index_meas_args['threshold']}",
measure="motion_index_freezing",
measure_options=freezing_index_meas_args,
categoricals=freeze_categoricals, x_label="", unit="times",
filter_args=filter_args, group_label=group_label,
save_fig_root=figure_root / sub_dir / "motion_index_freezing" /
"freezing_merged_by_period",
save_fig_prefix=f"freezing_merged_by_period_{label}",
)
Experiment.plot_multi_experiment_merged_by_group_categorical_percent(
experiments,
title=f"Total time spent freezing (motion index). Threshold={freezing_index_meas_args['threshold']}",
show_titles=False, show_legend=True,
measure="motion_index_freezing", measure_options=freezing_index_meas_args, unit="times",
categoricals=freeze_categoricals, x_label="Freezing",
filter_args=filter_args, only_categorical="Freezing", show_xlabel=False,
save_fig_root=figure_root / sub_dir / "motion_index_freezing" /
"freezing_merged_by_group",
save_fig_prefix=f"freezing_merged_by_group_{label}", group_label=group_label,
)
Experiment.plot_multi_experiment_merged_by_period_categorical_percent(
experiments, title=f"Speed freezing duration. Threshold={freezing_speed_meas_args['threshold']}",
measure="speed_freezing",
measure_options=freezing_speed_meas_args,
categoricals=freeze_categoricals, x_label="", unit="times",
filter_args=filter_args, group_label=group_label,
save_fig_root=figure_root / sub_dir / "speed_freezing" /
"freezing_merged_by_period",
save_fig_prefix=f"freezing_merged_by_period_{label}",
)
Experiment.plot_multi_experiment_merged_by_group_categorical_percent(
experiments,
title=f"Total time spent freezing (speed). Threshold={freezing_speed_meas_args['threshold']}",
show_titles=False, show_legend=True,
measure="speed_freezing", measure_options=freezing_speed_meas_args, unit="times",
categoricals=freeze_categoricals, x_label="Freezing",
filter_args=filter_args, only_categorical="Freezing", show_xlabel=False,
save_fig_root=figure_root / sub_dir / "speed_freezing" /
"freezing_merged_by_group",
save_fig_prefix=f"freezing_merged_by_group_{label}", group_label=group_label,
)
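# run_yidan and run_cynthia are the dataset-specific entry points: they set paths,
# metadata columns, box geometry and measure options, load the tracking data, and then
# call the export helpers above.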
def run_yidan():
SubjectPos._point_marker = {"marker": ".", "markersize": 2, "linestyle": "", "markeredgecolor": 'none'}
all_experiments = []
chopped_pre_times = [(i * 30, (i + 1) * 30) for i in range(10)]
metadata = ["date", "subject", "strain", "condition", "cell_label", "exposure", "sex"]
filename_fmt = "{date}_{subject}_top_tracked.csv"
Experiment.triplet_name = "Pre Trial", "Trial", "Post Trial"
occupancy_meas_args = {"split_horizontally": True}
freezing_index_meas_args = {"threshold": 0, "freeze_name": "Freezing"}
freezing_speed_meas_args = {"threshold": 0, "freeze_name": "Freezing"}
cue_categoricals = "Cue-side", "Far-side"
freeze_categoricals = "Freezing", "Non-freezing"
root = Path(r'D:\code_data\yidan_2025')
figure_root = root / "results" / "figures"
data_root = root / "results" / "data"
for num in (3, 5, 6, 7):
csv_experiment_times = root / f"behavioral data timestamp - b{num} timestamps.csv"
tracking_data_root = root / f"batch{num}"
json_data_root = root / f"batch{num}" / "json"
experiments = Experiment.parse_experiment_spec_csv(
csv_experiment_times, metadata, filename_fmt, "#{subject} ({exposure} - {date})"
)
all_experiments.extend(experiments)
for experiment in experiments:
if num == 3:
experiment.parse_box_metadata(json_data_root)
else:
experiment.set_box_metadata(200, 132, 784, 486, (960, 600))
experiment.read_pos_track(
tracking_data_root, hab_offset=3 * 60, pre_duration=2 * 60, trial_offset=0, trial_duration=2 * 60,
post_offset=0, post_duration=2 * 60, frame_rate=41.52,
)
experiment.chop_pos_track(hab_segments=chopped_pre_times)
experiments = all_experiments
grid_width, grid_height = Experiment.enlarge_canvas(experiments)
pprint.pprint(Experiment.count_motion_index_range(experiments))
export_experiments_csv(
experiments, data_root, occupancy_meas_args, cue_categoricals, freezing_index_meas_args,
freezing_speed_meas_args, freeze_categoricals, "times", True,
)
# export_single_subject_figures(
# experiments, figure_root, grid_width, grid_height, occupancy_meas_args, ("left", "bottom"), "times",
# "{subject}_exposure-{exposure}_{date}",
# )
conditions = "Blank", "TMT", "2MBA", "IAMM"
strains = "TRAP1",
exposures = "1", "2"
sex = "M", "F"
cond_x_strain_x_expo = [(c, s, e) for c in conditions for s in strains for e in exposures]
for condition, strain, expo in cond_x_strain_x_expo:
experiments_ = Experiment.filter(experiments, condition=condition, exposure=expo, strain=strain)
label = f"{condition}-{expo}-{strain}"
export_multi_experiment_figures(
experiments_, figure_root, grid_width, grid_height, occupancy_meas_args, "times", label,
cue_categoricals, freezing_index_meas_args, ("left", "bottom"), freezing_speed_meas_args,
freeze_categoricals, None, "", "grouped",
)
filter_args = [{"condition": c} for c in conditions]
items = [
({"exposure": e, "strain": s}, filter_args, "by_exposure") for s in strains for e in exposures
]
filter_args = [{"exposure": e} for e in exposures]
items.extend([
({"condition": c, "strain": s}, filter_args, "by_condition") for c in conditions for s in strains
])
filter_args = [{"exposure": e, "sex": s} for e in exposures for s in sex]
items.extend([
({"condition": c, "strain": s}, filter_args, "by_condition_w_sex") for c in conditions for s in strains
])
for groups, filter_args, sub_dir in items:
experiments_ = Experiment.filter(experiments, **groups)
group_label = "$\\bf" + "\n".join("{{{" + f + "}}}" for f in filter_args[0].keys()) + "$\n\n"
label = "-".join(groups.values())
export_multi_experiment_figures(
experiments_, figure_root, grid_width, grid_height, occupancy_meas_args, "times", label,
cue_categoricals, freezing_index_meas_args, ("left", "bottom"), freezing_speed_meas_args,
freeze_categoricals, filter_args, group_label, f"grouped_single_figure/{sub_dir}"
)
def run_cynthia():
SubjectPos._point_marker = {"marker": ".", "markersize": 1.5, "linestyle": ""}
chopped_pre_times = [(i * 30, (i + 1) * 30) for i in range(8)]
metadata = ["date", "subject", "stage", "condition", "cue", "sex"]
filename_fmt = "{subject}_{stage}_{date}_tracked.csv"
grid_width, grid_height = 641, 480
Experiment.triplet_name = "Pre Trial", "Trial", "Post Trial"
occupancy_meas_args = {"split_horizontally": False}
freezing_index_meas_args = {"threshold": 0, "freeze_name": "Freezing"}
freezing_speed_meas_args = {"threshold": 0, "freeze_name": "Freezing"}
cue_categoricals = "Cue-side", "Far-side"
freeze_categoricals = "Freezing", "Non-freezing"
root = Path(r'D:\code_data\cynthia_sp_2025')
csv_experiment_times = root / "Behavioral timestamps Cynthia.csv"
tracking_data_root = root / "tracking"
figure_root = root / "results" / "figures"
data_root = root / "results" / "data"
experiments = Experiment.parse_experiment_spec_csv(
csv_experiment_times, metadata, filename_fmt, "#{subject} ({stage})",
)
for experiment in experiments:
experiment.set_box_metadata(0, 0, grid_width, grid_height, (grid_width, grid_height))
experiment.read_pos_track(
tracking_data_root, frame_rate=15, pre_duration=-28, post_duration=28,
)
# experiment.chop_pos_track(hab_segments=chopped_pre_times)
pprint.pprint(Experiment.count_motion_index_range(experiments))
# export_experiments_csv(
# experiments, data_root, occupancy_meas_args, cue_categoricals, freezing_index_meas_args,
# freezing_speed_meas_args, freeze_categoricals, "times", False,
# )
# export_single_subject_figures(
# experiments, figure_root, grid_width, grid_height, occupancy_meas_args, ("center", "top"), "times",
# "{subject}_{stage}",
# )
conditions = "control", "cued", "backward"
cues = "odor", "tone"
stages = "train", "test"
cond_x_cue_x_stage = [(c, s, st) for c in conditions for s in cues for st in stages]
for condition, cue, stage in cond_x_cue_x_stage:
experiments_ = Experiment.filter(experiments, condition=condition, cue=cue, stage=stage)
label = f"{condition}-{cue}-{stage}"
export_multi_experiment_figures(
experiments_, figure_root, grid_width, grid_height, occupancy_meas_args, "times", label,
cue_categoricals, freezing_index_meas_args, ("center", "top"), freezing_speed_meas_args,
freeze_categoricals, None, "", "grouped",
)
filter_args = [{"stage": s} for s in stages]
items = [({"condition": c, "cue": s}, filter_args, "train_vs_test") for c in conditions for s in cues]
filter_args = [{"condition": c} for c in conditions]
items.extend(
[({"cue": c, "stage": s}, filter_args, "by_condition") for c in cues for s in stages]
)
for groups, filter_args, sub_dir in items:
experiments_ = Experiment.filter(experiments, **groups)
group_label = "$\\bf" + "\n".join("{{{" + f + "}}}" for f in filter_args[0].keys()) + "$\n\n"
label = "-".join(groups.values())
export_multi_experiment_figures(
experiments_, figure_root, grid_width, grid_height, occupancy_meas_args, "times", label,
cue_categoricals, freezing_index_meas_args, ("center", "top"), freezing_speed_meas_args,
freeze_categoricals, filter_args, group_label, f"grouped_single_figure/{sub_dir}"
)
if __name__ == "__main__":
# run_yidan()
run_cynthia()