Skip to content

Instantly share code, notes, and snippets.

@innat
Last active August 30, 2023 17:36
Show Gist options
  • Save innat/205075992360d8d7a241c7f1013866a8 to your computer and use it in GitHub Desktop.
Save innat/205075992360d8d7a241c7f1013866a8 to your computer and use it in GitHub Desktop.
from typing import Tuple
import tensorflow as tf
from keras import layers
def uniform_temporal_subsample(x, num_samples, temporal_dim=-4):
    """
    Pick num_samples equally spaced frames along one axis of a video tensor.

    Frame positions are spread evenly between the first and the last index of
    the temporal axis and then rounded with tf.round, so requesting more
    samples than there are frames repeats frames (nearest-neighbour
    interpolation).

    Args:
        x (tf.Tensor): A video tensor with dimensions larger than one.
        num_samples (int): The number of equispaced samples to be selected.
        temporal_dim (int): Axis of x along which frames are subsampled.

    Returns:
        An x-like Tensor whose temporal axis holds num_samples entries.
    """
    # Largest valid frame index, as a float so it can bound the linspace.
    last_index = tf.cast(tf.shape(x)[temporal_dim] - 1, tf.float32)
    # Fractional positions spread evenly over [0, last_index].
    positions = tf.linspace(0.0, last_index, num_samples)
    # Keep every position inside the valid index range before rounding.
    positions = tf.clip_by_value(positions, 0, last_index)
    frame_ids = tf.cast(tf.round(positions), tf.int32)
    return tf.gather(x, frame_ids, axis=temporal_dim)
def uniform_temporal_subsample_repeated(frames, frame_ratios, temporal_dim=-3):
    """
    Build one temporally subsampled copy of the input frames per ratio.

    For every entry of frame_ratios the temporal length is floor-divided by
    that ratio and the frames are uniformly subsampled down to the resulting
    count. Each output tensor corresponds to one pathway.

    NOTE(review): the default temporal_dim here is -3, whereas
    uniform_temporal_subsample and both layer classes in this file default to
    -4 -- confirm which axis is intended before relying on the default.

    Args:
        frames (tf.Tensor): frames of images sampled from the video. Expected
            to be a tf.Tensor with dimension larger than one.
        frame_ratios (list): temporal down-sampling ratio for each pathway.
        temporal_dim (int): dimension of temporal.

    Returns:
        list: one subsampled tensor per ratio, in the order of frame_ratios.
    """
    num_frames = tf.shape(frames)[temporal_dim]
    return [
        uniform_temporal_subsample(frames, num_frames // rate, temporal_dim)
        for rate in frame_ratios
    ]
class UniformTemporalSubsample(layers.Layer):
    """
    Keras layer wrapper around uniform_temporal_subsample.

    Args:
        num_samples (int): The number of equispaced frames to select.
        temporal_dim (int): Axis of the input along which frames are selected.
        **kwargs: Standard keras Layer arguments (e.g. name, dtype), forwarded
            to the base class.
    """

    def __init__(self, num_samples: int, temporal_dim: int = -4, **kwargs):
        # Forward base-layer options so callers can set name/dtype/etc.
        super().__init__(**kwargs)
        self._num_samples = num_samples
        self._temporal_dim = temporal_dim

    def call(self, inputs):
        """Return inputs subsampled to num_samples frames along temporal_dim."""
        return uniform_temporal_subsample(
            inputs,
            num_samples=self._num_samples,
            temporal_dim=self._temporal_dim,
        )

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update(
            {
                "num_samples": self._num_samples,
                "temporal_dim": self._temporal_dim,
            }
        )
        return config
class UniformTemporalSubsampleRepeated(layers.Layer):
    """
    Keras layer wrapper around uniform_temporal_subsample_repeated.

    Calling the layer returns a list with one subsampled tensor per entry of
    frame_ratios.

    Args:
        frame_ratios (sequence of int): temporal down-sampling ratio for each
            pathway.
        temporal_dim (int): Axis of the input along which frames are selected.
        **kwargs: Standard keras Layer arguments (e.g. name, dtype), forwarded
            to the base class.
    """

    def __init__(
        self, frame_ratios: Tuple[int, ...], temporal_dim: int = -4, **kwargs
    ):
        # Forward base-layer options so callers can set name/dtype/etc.
        super().__init__(**kwargs)
        self._frame_ratios = frame_ratios
        self._temporal_dim = temporal_dim

    def call(self, inputs):
        """Return a list of subsampled copies of inputs, one per ratio."""
        return uniform_temporal_subsample_repeated(
            inputs,
            frame_ratios=self._frame_ratios,
            temporal_dim=self._temporal_dim,
        )

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update(
            {
                # Stored as a plain list so the config stays serializable.
                "frame_ratios": list(self._frame_ratios),
                "temporal_dim": self._temporal_dim,
            }
        )
        return config
# Usage example: subsample a dummy 5-D tensor of ones along axis 1 (size 16).
# NOTE(review): layout is presumably (batch, frames, height, width, channels)
# -- the code only fixes the shape, not the meaning of each axis.
x = tf.ones(shape=(2, 16, 224, 224, 3))
# Keep 8 equally spaced entries of axis 1.
layer1 = UniformTemporalSubsample(num_samples=8, temporal_dim=1)
# One output per ratio: 16 // 2, 16 // 3 and 16 // 5 entries along axis 1.
layer2 = UniformTemporalSubsampleRepeated(frame_ratios=[2, 3, 5], temporal_dim=1)
y1 = layer1(x)
y2 = layer2(x)
print(y1.shape)
# Bare expression: its value is displayed only in a REPL/notebook; when run as
# a script the list of shapes is discarded.
[i.shape for i in y2]
# Expected output of the two statements above:
# (2, 8, 224, 224, 3)
# [TensorShape([2, 8, 224, 224, 3]),
# TensorShape([2, 5, 224, 224, 3]),
# TensorShape([2, 3, 224, 224, 3])]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment