Last active
April 7, 2026 17:11
-
-
Save talmo/3fc465d073b5c009a365b0cf66f5756d to your computer and use it in GitHub Desktop.
Detect potential track switches in a tracked SLP file and add suggestions for review in SLEAP GUI
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# /// script
# requires-python = ">=3.9"
# dependencies = [
#   "sleap-io",
#   "numpy",
# ]
# ///
r"""Detect potential track switches in a tracked SLP file.

This script identifies frames where tracked animals may have swapped identities by
computing motion metrics between consecutive frames. Frames that exceed
user-specified thresholds are added as suggestions to the output SLP file, making
it easy to review them in the SLEAP GUI. With ``--delete``, the flagged labeled
frames are removed from the output instead.

Input:
    The input must be an SLP file containing predictions with tracks assigned and
    exactly one video. Multi-video SLP files are not supported: split them first
    or pass a single-video project.

    The file must have at least two tracks; otherwise there are no switches to
    detect.

Metrics:
    Two complementary metrics are available (at least one must be provided):

    --velocity      Centroid velocity (px/frame). The centroid is the mean
                    position across all visible nodes for a given track. This
                    catches large whole-body jumps that typically accompany
                    identity swaps.

    --displacement  Max single-node displacement (px/frame). The largest
                    Euclidean distance any individual node moves between
                    consecutive frames. This is more sensitive to partial jumps
                    where only some body parts are affected.

    Both can be used together. A frame is flagged if *either* threshold is
    exceeded for *any* track. Motion is measured from the previous frame into
    the flagged frame. A track (or node) that is missing in either of the two
    frames contributes no motion for that step, so a track reappearing after a
    gap is not flagged by itself.

Running:
    This script uses inline dependencies (PEP 723) and can be run directly with
    uv without any prior installation:

        uv run detect_switches.py <input.slp> [options]

Examples:
    Flag frames where any track's centroid moves more than 50 px between frames:

        uv run detect_switches.py predictions.slp --velocity 50

    Flag frames using both metrics with a custom output path:

        uv run detect_switches.py predictions.slp --velocity 50 --displacement 80 \
            -o predictions.flagged.slp

    Delete flagged frames instead of adding suggestions:

        uv run detect_switches.py predictions.slp --velocity 50 --delete

The output file defaults to the input path with its extension replaced by
``.switches.slp`` (e.g. ``predictions.slp`` -> ``predictions.switches.slp``) if
``-o`` is not specified.
"""
from __future__ import annotations

import argparse
import warnings
from pathlib import Path

import numpy as np
import sleap_io as sio
def detect_switches(
    labels: sio.Labels,
    velocity_threshold: float | None = None,
    displacement_threshold: float | None = None,
) -> list[int]:
    """Find frames with potential track switches based on motion thresholds.

    Motion is measured between consecutive rows of ``labels.numpy()``. A frame is
    flagged when the motion *into* it from the previous row exceeds a threshold
    for any track. NaN coordinates contribute zero motion for that step (they are
    dropped by ``nansum``). As a result, a track missing in either of two
    consecutive frames is never flagged for that step.

    Args:
        labels: Tracked labels with a single video. Only ``labels.numpy()`` is
            used; it is expected to return an array of shape
            ``(n_frames, n_tracks, n_nodes, 2)`` with NaN for missing points.
        velocity_threshold: Max allowed centroid velocity (px/frame). The centroid
            is the mean position across all visible nodes for a given track.
        displacement_threshold: Max allowed displacement for any single node
            between consecutive frames (px/frame).

    Returns:
        Sorted list of built-in ``int`` frame indices that exceed at least one
        threshold. Row ``i`` of ``labels.numpy()`` is treated as frame ``i``.

    Raises:
        ValueError: If neither threshold is provided.
    """
    if velocity_threshold is None and displacement_threshold is None:
        raise ValueError("At least one of velocity or displacement threshold required.")

    # (n_frames, n_tracks, n_nodes, 2)
    trx = labels.numpy()

    # Motion needs at least two frames. With zero nodes, the max over the node
    # axis below would raise on an empty reduction, so bail out early.
    if trx.shape[0] < 2 or trx.shape[2] == 0:
        return []

    flagged = set()

    if velocity_threshold is not None:
        # Centroid per frame: (n_frames, n_tracks, 2), averaged over visible nodes.
        # All-NaN slices (track absent in a frame) warn; the NaN result is intended.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            centroids = np.nanmean(trx, axis=2)
        # Centroid speed between consecutive frames: (n_frames-1, n_tracks).
        # nansum turns NaN steps (track missing on either side) into 0 motion.
        velocity = np.sqrt(np.nansum(np.diff(centroids, axis=0) ** 2, axis=-1))
        # Row i of `velocity` is the motion going INTO frame i+1.
        frame_mask = np.any(velocity > velocity_threshold, axis=1)
        flagged.update(np.flatnonzero(frame_mask) + 1)

    if displacement_threshold is not None:
        # Only materialize the full per-node difference array when it is needed.
        diff = np.diff(trx, axis=0)
        # Per-node displacement: (n_frames-1, n_tracks, n_nodes); NaN -> 0 motion.
        node_disp = np.sqrt(np.nansum(diff**2, axis=-1))
        # nansum already removed NaNs, so a plain max over nodes is sufficient.
        max_disp = node_disp.max(axis=2)
        frame_mask = np.any(max_disp > displacement_threshold, axis=1)
        flagged.update(np.flatnonzero(frame_mask) + 1)

    # Convert NumPy integers to built-in ints to honor the `list[int]` contract
    # (downstream consumers may JSON-serialize these indices).
    return sorted(int(frame_idx) for frame_idx in flagged)
def main():
    """Command-line entry point: load an SLP, flag frames, and save the result.

    Parses arguments and validates that the input has exactly one video and at
    least two tracks. It then runs ``detect_switches`` and either appends a
    suggestion per flagged frame (default) or removes the flagged labeled frames
    (``--delete``). Finally it saves to ``--output`` or to the input path with its
    extension replaced by ``.switches.slp``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Detect potential track switches in a tracked, single-video SLP file. "
            "Flagged frames are added as suggestions to the output SLP for review "
            "in the SLEAP GUI."
        ),
    )
    parser.add_argument(
        "input",
        type=str,
        help="Path to a tracked SLP file with a single video.",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default=None,
        help="Output SLP path. Defaults to input.switches.slp.",
    )
    parser.add_argument(
        "--velocity",
        type=float,
        default=None,
        help="Centroid velocity threshold in px/frame.",
    )
    parser.add_argument(
        "--displacement",
        type=float,
        default=None,
        help="Max single-node displacement threshold in px/frame.",
    )
    parser.add_argument(
        "--delete",
        action="store_true",
        help="Delete flagged frames instead of adding suggestions.",
    )
    args = parser.parse_args()

    if args.velocity is None and args.displacement is None:
        parser.error("Provide at least one of --velocity or --displacement.")

    # Load and validate
    labels = sio.load_slp(args.input)
    if len(labels.videos) == 0:
        parser.error(f"No videos found in {args.input}.")
    if len(labels.videos) > 1:
        parser.error(
            f"Expected a single-video SLP file, but found {len(labels.videos)} videos "
            f"in {args.input}. Split the file or pass a single-video project."
        )
    video = labels.video

    if len(labels.tracks) < 2:
        parser.error(
            f"Expected a tracked SLP file with at least 2 tracks, but found "
            f"{len(labels.tracks)} in {args.input}. Run tracking first."
        )

    print(f"Loaded: {args.input}")
    print(f" Video: {video.filename}")
    print(f" Frames: {len(labels)}, Tracks: {len(labels.tracks)}")

    # Detect
    flagged = detect_switches(
        labels,
        velocity_threshold=args.velocity,
        displacement_threshold=args.displacement,
    )
    # Normalize to built-in ints: indices derived from NumPy arrays may be NumPy
    # integer scalars, which JSON cannot serialize when suggestions are saved.
    flagged = [int(frame_idx) for frame_idx in flagged]
    print(f" Flagged frames: {len(flagged)}")

    if args.delete:
        # Remove flagged frames from the labels
        flagged_set = set(flagged)
        n_before = len(labels.labeled_frames)
        labels.labeled_frames = [
            lf for lf in labels.labeled_frames if lf.frame_idx not in flagged_set
        ]
        # Report what was actually removed rather than assuming a 1:1 match.
        n_deleted = n_before - len(labels.labeled_frames)
        print(f" Deleted {n_deleted} labeled frames.")
    else:
        # Skip frames already suggested for this video so that re-running the
        # script on its own output does not create duplicate suggestions.
        already_suggested = {
            int(s.frame_idx) for s in labels.suggestions if s.video is video
        }
        n_added = 0
        for frame_idx in flagged:
            if frame_idx in already_suggested:
                continue
            labels.suggestions.append(
                sio.SuggestionFrame(video=video, frame_idx=frame_idx)
            )
            n_added += 1
        print(f" Added {n_added} suggestions.")

    # Save
    output = args.output or str(Path(args.input).with_suffix(".switches.slp"))
    labels.save(output)
    print(f"Saved: {output}")
if __name__ == "__main__":
    # Run the CLI only when executed as a script (e.g. `uv run detect_switches.py`),
    # not when imported as a module.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment