Skip to content

Instantly share code, notes, and snippets.

@talmo
Last active April 7, 2026 17:11
Show Gist options
  • Select an option

  • Save talmo/3fc465d073b5c009a365b0cf66f5756d to your computer and use it in GitHub Desktop.

Select an option

Save talmo/3fc465d073b5c009a365b0cf66f5756d to your computer and use it in GitHub Desktop.
Detect potential track switches in a tracked SLP file and either add them as suggestions for review in the SLEAP GUI or delete the flagged frames.
# /// script
# requires-python = ">=3.9"
# dependencies = [
# "sleap-io",
# "numpy",
# ]
# ///
"""Detect potential track switches in a tracked SLP file.

This script identifies frames where tracked animals may have swapped identities by
computing motion metrics between consecutive frames. Frames that exceed user-specified
thresholds are added as suggestions to the output SLP file, making them easy to review
in the SLEAP GUI. Pass --delete to remove the flagged frames from the output instead.

Input:
    The input must be an SLP file containing predictions with tracks assigned and
    exactly one video. Multi-video SLP files are not supported — split them first or
    pass a single-video project. The file must have at least two tracks; otherwise
    there are no switches to detect.

Metrics:
    Two complementary metrics are available (at least one must be provided):

    --velocity      Centroid velocity (px/frame). The centroid is the mean position
                    across all visible nodes for a given track. This catches large
                    whole-body jumps that typically accompany identity swaps.
    --displacement  Max single-node displacement (px/frame). The largest Euclidean
                    distance any individual node moves between consecutive frames.
                    This is more sensitive to partial jumps where only some body
                    parts are affected.

    Both can be used together — a frame is flagged if *either* threshold is exceeded
    for *any* track. Motion is only measured where a track (or, for --displacement,
    a node) is present in both consecutive frames, so a track appearing or
    disappearing is not flagged on its own.

Running:
    This script uses inline dependencies (PEP 723) and can be run directly with uv
    without any prior installation:

        uv run detect_switches.py <input.slp> [options]

Examples:
    Flag frames where any track's centroid moves more than 50 px between frames:

        uv run detect_switches.py predictions.slp --velocity 50

    Flag frames using both metrics with a custom output path:

        uv run detect_switches.py predictions.slp --velocity 50 --displacement 80 \
            -o predictions.flagged.slp

    Delete flagged frames instead of adding suggestions:

        uv run detect_switches.py predictions.slp --velocity 50 --delete

The output file defaults to <input>.switches.slp if -o is not specified.
"""
from __future__ import annotations

import argparse
import warnings
from pathlib import Path

import numpy as np
import sleap_io as sio
def detect_switches(
    labels: sio.Labels,
    velocity_threshold: float | None = None,
    displacement_threshold: float | None = None,
) -> list[int]:
    """Find frames with potential track switches based on motion thresholds.

    Motion is measured between every pair of consecutive rows of the track array
    returned by ``labels.numpy()``. Frame ``t`` is flagged when the motion from
    frame ``t - 1`` into frame ``t`` exceeds a threshold for any track. A track
    (for velocity) or node (for displacement) that is missing, i.e. NaN, in either
    frame of a pair contributes zero motion. Gaps in tracking are therefore not
    flagged on their own.

    Args:
        labels: Tracked labels with a single video.
        velocity_threshold: Max allowed centroid velocity (px/frame). Centroid is the
            mean position across all visible nodes for a given track.
        displacement_threshold: Max allowed displacement for any single node between
            consecutive frames (px/frame).

    Returns:
        Sorted list of frame indices, as plain Python ``int``s, that exceed at least
        one threshold. Empty when there are fewer than two frames or no nodes.

    Raises:
        ValueError: If neither threshold is provided.
    """
    if velocity_threshold is None and displacement_threshold is None:
        raise ValueError("At least one of velocity or displacement threshold required.")

    # (n_frames, n_tracks, n_nodes, 2); points of absent tracks/nodes are NaN.
    trx = labels.numpy()
    n_frames, _, n_nodes, _ = trx.shape

    # Without consecutive-frame pairs or nodes there is no motion to measure.
    # The n_nodes guard also keeps np.nanmax below from raising on a zero-size axis.
    if n_frames < 2 or n_nodes == 0:
        return []

    flagged: set[int] = set()

    if velocity_threshold is not None:
        # Centroid per frame: (n_frames, n_tracks, 2). A track absent from a frame
        # yields an all-NaN slice, which is expected, so silence nanmean's warning.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            centroids = np.nanmean(trx, axis=2)
        # Centroid speed per step: (n_frames-1, n_tracks). nansum maps a NaN step
        # (track missing in either frame) to 0, i.e. "no motion".
        velocity = np.sqrt(np.nansum(np.diff(centroids, axis=0) ** 2, axis=-1))
        # Row i is the step INTO frame i + 1. Convert to plain ints so callers get
        # real `int`s (NumPy integer scalars are not JSON-serializable).
        steps = np.flatnonzero(np.any(velocity > velocity_threshold, axis=1))
        flagged.update(int(i) + 1 for i in steps)

    if displacement_threshold is not None:
        # Per-node displacement per step: (n_frames-1, n_tracks, n_nodes). Missing
        # nodes again contribute 0 via nansum.
        node_disp = np.sqrt(np.nansum(np.diff(trx, axis=0) ** 2, axis=-1))
        # Largest single-node jump per track: (n_frames-1, n_tracks).
        max_disp = np.nanmax(node_disp, axis=2)
        steps = np.flatnonzero(np.any(max_disp > displacement_threshold, axis=1))
        flagged.update(int(i) + 1 for i in steps)

    return sorted(flagged)
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser used by `main`."""
    parser = argparse.ArgumentParser(
        description=(
            "Detect potential track switches in a tracked, single-video SLP file. "
            "Flagged frames are added as suggestions to the output SLP for review "
            "in the SLEAP GUI, or removed from it with --delete."
        ),
    )
    parser.add_argument(
        "input",
        type=str,
        help="Path to a tracked SLP file with a single video.",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default=None,
        help="Output SLP path. Defaults to input.switches.slp.",
    )
    parser.add_argument(
        "--velocity",
        type=float,
        default=None,
        help="Centroid velocity threshold in px/frame.",
    )
    parser.add_argument(
        "--displacement",
        type=float,
        default=None,
        help="Max single-node displacement threshold in px/frame.",
    )
    parser.add_argument(
        "--delete",
        action="store_true",
        help="Delete flagged frames instead of adding suggestions.",
    )
    return parser


def main():
    """Command-line entry point: detect switches, then annotate or prune the SLP.

    Parses arguments and loads the input SLP. The file must contain exactly one
    video and at least two tracks. `detect_switches` then runs on it. By default
    a suggestion is appended for every flagged frame not already suggested.
    With ``--delete``, the labeled frames at the flagged indices are removed
    instead. The result is saved to ``--output`` or, by default, to
    ``<input>.switches.slp``.
    """
    parser = _build_parser()
    args = parser.parse_args()
    if args.velocity is None and args.displacement is None:
        parser.error("Provide at least one of --velocity or --displacement.")
    for flag, value in (
        ("--velocity", args.velocity),
        ("--displacement", args.displacement),
    ):
        # A negative threshold flags every frame after the first (with --delete,
        # that removes them all). NaN flags nothing. `not value >= 0` rejects both.
        if value is not None and not value >= 0:
            parser.error(f"{flag} must be a non-negative number, got {value}.")

    # Load and validate
    labels = sio.load_slp(args.input)
    if len(labels.videos) == 0:
        parser.error(f"No videos found in {args.input}.")
    if len(labels.videos) > 1:
        parser.error(
            f"Expected a single-video SLP file, but found {len(labels.videos)} videos "
            f"in {args.input}. Split the file or pass a single-video project."
        )
    video = labels.video
    if len(labels.tracks) < 2:
        parser.error(
            f"Expected a tracked SLP file with at least 2 tracks, but found "
            f"{len(labels.tracks)} in {args.input}. Run tracking first."
        )
    print(f"Loaded: {args.input}")
    print(f" Video: {video.filename}")
    print(f" Frames: {len(labels)}, Tracks: {len(labels.tracks)}")

    # Detect
    flagged = detect_switches(
        labels,
        velocity_threshold=args.velocity,
        displacement_threshold=args.displacement,
    )
    # Normalize to plain ints. The indices may come from NumPy index arrays, and
    # NumPy integer scalars are not JSON-serializable.
    flagged = [int(frame_idx) for frame_idx in flagged]
    print(f" Flagged frames: {len(flagged)}")

    if args.delete:
        # Remove flagged frames from the labels
        flagged_set = set(flagged)
        n_before = len(labels.labeled_frames)
        labels.labeled_frames = [
            lf for lf in labels.labeled_frames if lf.frame_idx not in flagged_set
        ]
        # Report what was actually removed instead of assuming exactly one labeled
        # frame per flagged index.
        n_deleted = n_before - len(labels.labeled_frames)
        print(f" Deleted {n_deleted} labeled frames.")
    else:
        # Skip frames that already have a suggestion for this video. Otherwise
        # re-running (e.g. on this script's own output) would add duplicates.
        existing = {
            int(sf.frame_idx) for sf in labels.suggestions if sf.video is video
        }
        new_frames = [idx for idx in flagged if idx not in existing]
        for frame_idx in new_frames:
            labels.suggestions.append(
                sio.SuggestionFrame(video=video, frame_idx=frame_idx)
            )
        n_skipped = len(flagged) - len(new_frames)
        print(
            f" Added {len(new_frames)} suggestions "
            f"({n_skipped} already present)."
        )

    # Save
    output = args.output or str(Path(args.input).with_suffix(".switches.slp"))
    labels.save(output)
    print(f"Saved: {output}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment