from typing import Tuple

import tensorflow as tf
from keras import layers


def uniform_temporal_subsample(x, num_samples, temporal_dim=-4):
    """
    Uniformly subsamples num_samples indices from the temporal dimension of the video.
    When num_samples is larger than the size of the temporal dimension of the video,
    frames are repeated via nearest neighbor interpolation.

    Args:
        x (tf.Tensor): A video tensor with dimension larger than one.
        num_samples (int): The number of equispaced samples to be selected.
        temporal_dim (int): Dimension along which to perform the temporal subsampling.

    Returns:
        An x-like tensor with a subsampled temporal dimension.
    """
    t = tf.shape(x)[temporal_dim]
    # Sample by nearest neighbor interpolation if num_samples > t.
    indices = tf.linspace(0.0, tf.cast(t - 1, tf.float32), num_samples)
    indices = tf.clip_by_value(indices, 0, tf.cast(t - 1, tf.float32))
    indices = tf.cast(tf.round(indices), tf.int32)
    return tf.gather(x, indices, axis=temporal_dim)
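
# Quick illustrative check (an addition, not part of the original gist): when
# num_samples exceeds the clip length, the rounded linspace indices repeat
# frames via nearest neighbor, e.g. a 4-frame clip sampled to 8 frames uses
# indices [0, 0, 1, 1, 2, 2, 3, 3].
clip = tf.reshape(tf.range(4.0), (4, 1, 1, 1))  # toy 4-frame "video"
print(tf.squeeze(uniform_temporal_subsample(clip, num_samples=8)))
# tf.Tensor([0. 0. 1. 1. 2. 2. 3. 3.], shape=(8,), dtype=float32)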


def uniform_temporal_subsample_repeated(frames, frame_ratios, temporal_dim=-4):
    """
    Prepare output as a list of tensors subsampled from the input frames. Each
    tensor holds a unique copy of the subsampled frames, corresponding to a
    unique pathway.

    Args:
        frames (tf.Tensor): Frames of images sampled from the video. Expected to
            be a tf.Tensor with dimension larger than one.
        frame_ratios (list): Temporal down-sampling ratio for each pathway.
        temporal_dim (int): Dimension along which to perform the temporal subsampling.

    Returns:
        frame_list (list): List of subsampled tensors, one per pathway.
    """
    temporal_length = tf.shape(frames)[temporal_dim]
    frame_list = []
    for ratio in frame_ratios:
        pathway = uniform_temporal_subsample(
            frames, temporal_length // ratio, temporal_dim
        )
        frame_list.append(pathway)
    return frame_list


class UniformTemporalSubsample(layers.Layer):
    def __init__(self, num_samples: int, temporal_dim: int = -4):
        super().__init__()
        self._num_samples = num_samples
        self._temporal_dim = temporal_dim

    def call(self, inputs):
        return uniform_temporal_subsample(
            inputs,
            num_samples=self._num_samples,
            temporal_dim=self._temporal_dim,
        )


class UniformTemporalSubsampleRepeated(layers.Layer):
    def __init__(self, frame_ratios: Tuple[int, ...], temporal_dim: int = -4):
        super().__init__()
        self._frame_ratios = frame_ratios
        self._temporal_dim = temporal_dim

    def call(self, inputs):
        return uniform_temporal_subsample_repeated(
            inputs,
            frame_ratios=self._frame_ratios,
            temporal_dim=self._temporal_dim,
        )


x = tf.ones(shape=(2, 16, 224, 224, 3))

layer1 = UniformTemporalSubsample(num_samples=8, temporal_dim=1)
layer2 = UniformTemporalSubsampleRepeated(frame_ratios=[2, 3, 5], temporal_dim=1)

y1 = layer1(x)
y2 = layer2(x)

print(y1.shape)
print([i.shape for i in y2])
# (2, 8, 224, 224, 3)
# [TensorShape([2, 8, 224, 224, 3]),
#  TensorShape([2, 5, 224, 224, 3]),
#  TensorShape([2, 3, 224, 224, 3])]
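
# A minimal sketch (an addition, not part of the original gist) showing the
# subsample layer used as a preprocessing step inside a Keras functional model;
# the input shape below is an assumption matching the toy tensor above.
inputs = layers.Input(shape=(16, 224, 224, 3))
outputs = UniformTemporalSubsample(num_samples=8, temporal_dim=1)(inputs)
model = tf.keras.Model(inputs, outputs)
print(model(x).shape)  # (2, 8, 224, 224, 3)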