Created
January 12, 2026 23:12
-
-
Save river/2b07ff0ee7f63cfaf0543b72aa398e12 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Video loading utilities using decord. | |
| Provides efficient video decoding with configurable sampling strategies | |
| and preprocessing for the video encoder. | |
| """ | |
| from enum import Enum | |
| from typing import List, Literal | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from decord import VideoReader, cpu | |
class SamplingStrategy(Enum):
    """Frame-sampling strategies understood by `get_frame_indices` / `load_video_frames`."""
    # Spread n_frames evenly across the whole clip.
    UNIFORM = "uniform"
    # Take the first n_frames (subject to frame_stride).
    FIRST_N = "first_n"
    # Take the last n_frames (subject to frame_stride).
    LAST_N = "last_n"
    # Random contiguous strided window; downgraded to FIRST_N at inference time.
    RANDOM_N = "random_n"
def load_video_frames(
    video_path: str,
    n_frames: int = 16,
    img_size: int = 224,
    sampling_strategy: SamplingStrategy | str = SamplingStrategy.FIRST_N,
    frame_stride: int = 2,
    is_train: bool = False,
) -> torch.Tensor:
    """
    Load and preprocess video frames for the video encoder.

    Args:
        video_path: Path to the video file.
        n_frames: Number of frames to sample from the video.
        img_size: Side length the frames are resized to (square output).
        sampling_strategy: Strategy for sampling frames.
            - uniform: uniformly sample n_frames across the video
            - first_n: take the first n_frames with stride
            - last_n: take the last n_frames with stride
            - random_n: randomly sample n_frames from the video (NOTE: only if
              is_train is True; otherwise, will be converted to first_n for
              deterministic inference)
        frame_stride: Stride for sampling frames (i.e. sample every frame_stride frames).
        is_train: If False, random_n sampling strategy will be converted to first_n.

    Returns:
        Tensor of shape (C, T, H, W) with dtype torch.float32 and values in
        [0, 255] (unnormalized — apply mean/std normalization downstream).
        Note: frames are decoded as uint8 but converted to float before the
        bilinear resize, so the output is float, not uint8.
    """
    # Decode on CPU with decord; frames are fetched lazily by index.
    vr = VideoReader(video_path, ctx=cpu(0))
    total_frames = len(vr)

    # Normalize the strategy argument, and force determinism at inference time.
    if isinstance(sampling_strategy, str):
        sampling_strategy = SamplingStrategy(sampling_strategy)
    if sampling_strategy == SamplingStrategy.RANDOM_N and not is_train:
        sampling_strategy = SamplingStrategy.FIRST_N

    indices = get_frame_indices(
        total_frames=total_frames,
        n_frames=n_frames,
        sampling_strategy=sampling_strategy,
        frame_stride=frame_stride,
    )

    # decord returns an NDArray with shape (T, H, W, C) in uint8.
    frames = vr.get_batch(indices).asnumpy()

    # (T, H, W, C) uint8 -> (T, C, H, W) float32; keep the [0, 255] range so
    # normalization stats can be applied by the caller.
    frames = torch.from_numpy(frames).permute(0, 3, 1, 2).float()

    # Resize only if needed; bilinear interpolation requires float input.
    _, _, H, W = frames.shape
    if H != img_size or W != img_size:
        frames = F.interpolate(
            frames, size=(img_size, img_size), mode="bilinear", align_corners=False
        )

    # (T, C, H, W) -> (C, T, H, W), the layout expected by the video encoder.
    return frames.permute(1, 0, 2, 3)
def get_frame_indices(
    total_frames: int,
    n_frames: int,
    sampling_strategy: SamplingStrategy,
    frame_stride: int = 1,
) -> List[int]:
    """
    Generate frame indices based on sampling strategy.

    Always returns exactly n_frames indices: short videos are padded by
    repeating the last valid index (or 0 for safety), and overly long
    selections are truncated.
    """
    if total_frames <= 0:
        raise ValueError(f"empty video with {total_frames=}")
    if n_frames <= 0 or frame_stride <= 0:
        raise ValueError(f"{n_frames=} and {frame_stride=}; both must be positive")

    span = n_frames * frame_stride  # frames covered by a full strided window

    if sampling_strategy == SamplingStrategy.UNIFORM:
        # Evenly spaced indices over the whole clip; short clips take every frame.
        if total_frames <= n_frames:
            picked = list(range(total_frames))
        else:
            picked = np.linspace(0, total_frames - 1, n_frames, dtype=int).tolist()
    elif sampling_strategy == SamplingStrategy.FIRST_N:
        # Strided window anchored at the start, clipped to the clip length.
        picked = list(range(0, min(span, total_frames), frame_stride))
    elif sampling_strategy == SamplingStrategy.LAST_N:
        # Strided window anchored so it ends at the final frame.
        picked = list(range(max(0, total_frames - span), total_frames, frame_stride))
    elif sampling_strategy == SamplingStrategy.RANDOM_N:
        # Largest start for which the full strided window still fits.
        max_start = total_frames - (n_frames - 1) * frame_stride - 1
        if max_start <= 0:
            # Window doesn't fit: fall back to a strided sweep from frame 0.
            picked = list(range(0, total_frames, frame_stride))
        else:
            start = np.random.default_rng().integers(0, max_start + 1)
            picked = list(range(start, start + span, frame_stride))

    # Guarantee exactly n_frames: repeat the last index to pad, slice to truncate.
    shortfall = n_frames - len(picked)
    if shortfall > 0:
        picked = picked + [picked[-1] if picked else 0] * shortfall
    return picked[:n_frames]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment