Skip to content

Instantly share code, notes, and snippets.

@river
Created January 12, 2026 23:12
Show Gist options
  • Select an option

  • Save river/2b07ff0ee7f63cfaf0543b72aa398e12 to your computer and use it in GitHub Desktop.

Select an option

Save river/2b07ff0ee7f63cfaf0543b72aa398e12 to your computer and use it in GitHub Desktop.
"""
Video loading utilities using decord.
Provides efficient video decoding with configurable sampling strategies
and preprocessing for the video encoder.
"""
from enum import Enum
from typing import List, Literal
import numpy as np
import torch
import torch.nn.functional as F
from decord import VideoReader, cpu
class SamplingStrategy(Enum):
    """Frame-sampling strategies understood by `get_frame_indices` / `load_video_frames`."""

    UNIFORM = "uniform"    # evenly spaced indices across the whole video
    FIRST_N = "first_n"    # first n frames, taken with a stride
    LAST_N = "last_n"      # last n frames, taken with a stride
    RANDOM_N = "random_n"  # random strided window (converted to FIRST_N when not training)
def load_video_frames(
    video_path: str,
    n_frames: int = 16,
    img_size: int = 224,
    sampling_strategy: SamplingStrategy | str = SamplingStrategy.FIRST_N,
    frame_stride: int = 2,
    is_train: bool = False,
) -> torch.Tensor:
    """
    Load and preprocess video frames for the video encoder.

    Args:
        video_path: Path to the video file.
        n_frames: Number of frames to sample from the video.
        img_size: Side length the frames are resized to (square output).
        sampling_strategy: Strategy for sampling frames (enum member or its string value).
            - uniform: uniformly sample n_frames across the video
            - first_n: take the first n_frames with stride
            - last_n: take the last n_frames with stride
            - random_n: randomly sample n_frames from the video (NOTE: only if is_train is True;
              otherwise, will be converted to first_n for deterministic inference)
        frame_stride: Stride for sampling frames (i.e. sample every frame_stride frames).
        is_train: If False, random_n sampling strategy will be converted to first_n.

    Returns:
        Tensor of shape (C, T, H, W) with dtype torch.float32, values in [0, 255].
        (The code converts to float before resizing and never rescales, so the
        caller is responsible for normalization. NOTE: an earlier docstring
        claimed uint8 output, which contradicted the float conversion below.)

    Raises:
        ValueError: propagated from `get_frame_indices` for an empty video or
            non-positive n_frames / frame_stride.
    """
    # Use decord for video decoding on CPU.
    vr = VideoReader(video_path, ctx=cpu(0))
    total_frames = len(vr)

    # Normalize the strategy to the enum; force determinism at inference time.
    if isinstance(sampling_strategy, str):
        sampling_strategy = SamplingStrategy(sampling_strategy)
    if sampling_strategy == SamplingStrategy.RANDOM_N and not is_train:
        sampling_strategy = SamplingStrategy.FIRST_N

    indices = get_frame_indices(
        total_frames=total_frames,
        n_frames=n_frames,
        sampling_strategy=sampling_strategy,
        frame_stride=frame_stride,
    )

    # decord returns an NDArray with shape (T, H, W, C) in uint8.
    frames = vr.get_batch(indices).asnumpy()
    # (T, H, W, C) -> (T, C, H, W) so the tensor matches torch image conventions.
    frames = torch.from_numpy(frames).permute(0, 3, 1, 2)
    # Convert to float32 but keep the [0, 255] range for downstream normalization stats.
    frames = frames.float()

    # Resize only when the decoded size differs from the target.
    _, _, H, W = frames.shape
    if H != img_size or W != img_size:
        frames = F.interpolate(
            frames, size=(img_size, img_size), mode="bilinear", align_corners=False
        )

    # (T, C, H, W) -> (C, T, H, W) expected by the video encoder.
    frames = frames.permute(1, 0, 2, 3)
    return frames
def get_frame_indices(
    total_frames: int,
    n_frames: int,
    sampling_strategy: SamplingStrategy,
    frame_stride: int = 1,
) -> List[int]:
    """
    Compute the frame indices to decode for the given sampling strategy.

    Always yields exactly `n_frames` indices: selections from short videos are
    padded by repeating the last index, and over-long selections are truncated.
    """
    if total_frames <= 0:
        raise ValueError(f"empty video with {total_frames=}")
    if n_frames <= 0 or frame_stride <= 0:
        raise ValueError(f"{n_frames=} and {frame_stride=}; both must be positive")

    if sampling_strategy is SamplingStrategy.UNIFORM:
        # Evenly spaced indices over the full clip; frame_stride is not used here.
        if total_frames <= n_frames:
            indices = list(range(total_frames))
        else:
            indices = np.linspace(0, total_frames - 1, n_frames, dtype=int).tolist()
    elif sampling_strategy is SamplingStrategy.FIRST_N:
        stop = min(n_frames * frame_stride, total_frames)
        indices = list(range(0, stop, frame_stride))
    elif sampling_strategy is SamplingStrategy.LAST_N:
        first = max(0, total_frames - n_frames * frame_stride)
        indices = list(range(first, total_frames, frame_stride))
    elif sampling_strategy is SamplingStrategy.RANDOM_N:
        # Latest start index that still fits a full strided window in the clip.
        latest_start = total_frames - (n_frames - 1) * frame_stride - 1
        if latest_start <= 0:
            # Clip too short for a random window: stride from the beginning instead.
            indices = list(range(0, total_frames, frame_stride))
        else:
            offset = np.random.default_rng().integers(0, latest_start + 1)
            indices = list(
                range(offset, offset + n_frames * frame_stride, frame_stride)
            )

    # Guarantee exactly n_frames: repeat the final index, then truncate.
    if len(indices) < n_frames:
        filler = indices[-1] if indices else 0
        indices.extend([filler] * (n_frames - len(indices)))
    return indices[:n_frames]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment