Skip to content

Instantly share code, notes, and snippets.

@innat
Last active October 24, 2023 15:42
Show Gist options
  • Save innat/ce03a6c1642a48108d770f7c19d3363f to your computer and use it in GitHub Desktop.
Save innat/ce03a6c1642a48108d770f7c19d3363f to your computer and use it in GitHub Desktop.

Note

When writing video data to TFRecord format, the output TFRecord file can be much larger than the original video file, because every frame is stored as an individually JPEG-encoded image. For a quick demonstration, one may subsample frames (e.g. with a frame step) before encoding to keep the overall size small. In real work (research or production), however, all frames should be encoded into the TFRecord, so that frames at different indices can be sampled at training time. Check this discussion. The following code is tested with TF 2.12.

video data layout

Suppose each dataset split folder contains one sub-folder per class, with the video files of that class inside it:

+--root/
|  +--class_name1/
|  |  +--a.avi
|  |  +--b.avi
|  +--class_name2/
|  |  +--a.mp4
|  |  +--b.avi
|  +--class_name3/
|  |  +--a.avi
|  |  +--b.avi

data reading utility

To read the video files we use decord, which can be installed with `pip install decord`.

import glob
import json
import os

import numpy as np
import tensorflow as tf
from decord import VideoReader

# Height and width (in pixels) that read_video resizes every frame to.
input_size = 224

def read_video(file_path):
    """Decode all frames of a video file and resize them.

    Every frame index of the file is decoded with decord's ``VideoReader``,
    and the whole frame stack is passed to ``format_frames`` with an output
    size of ``(input_size, input_size)``.

    Args:
        file_path: Path of the video file to decode.

    Returns:
        Whatever ``format_frames`` returns for the full frame stack.
    """
    reader = VideoReader(file_path)
    every_frame = range(len(reader))
    clip = reader.get_batch(every_frame).asnumpy()
    target_size = (input_size, input_size)
    return format_frames(clip, output_size=target_size)

def format_frames(frame, output_size):
    """Resize a frame, or a stack of frames, with ``tf.image.resize``.

    Args:
        frame: Image or batch of images accepted by ``tf.image.resize``.
        output_size: ``(height, width)`` pair giving the target size.

    Returns:
        The tensor produced by ``tf.image.resize``.
    """
    target_hw = list(output_size)
    resized = tf.image.resize(frame, size=target_hw)
    return resized

sampling strategies

This function uniformly samples frames from a video. For example, if a video has shape (200, h, w, 3) and num_samples=16, it returns a video of shape (16, h, w, 3). The frame indices are computed as shown below.

def uniform_temporal_subsample(frames, num_samples, temporal_dim=-4):
    """Uniformly subsample ``num_samples`` frames along the temporal axis.

    The indices are ``num_samples`` evenly spaced points from the first to
    the last frame, rounded to integers (``tf.round`` rounds half to even).
    When ``num_samples`` is larger than the number of frames, indices repeat,
    i.e. frames are chosen by nearest-neighbor interpolation.

    Args:
        frames (tf.Tensor): Video tensor; its rank must make ``temporal_dim``
            a valid axis (for a ``(T, H, W, C)`` video, axis -4 is ``T``).
        num_samples (int): The number of equispaced samples to be selected.
        temporal_dim (int): Axis of ``frames`` holding time.

    Returns:
        A tensor like ``frames`` whose ``temporal_dim`` has size
        ``num_samples``.

    Raises:
        tf.errors.InvalidArgumentError: If the temporal axis has length 0.

    https://gist.github.com/innat/205075992360d8d7a241c7f1013866a8
    """
    t = tf.shape(frames)[temporal_dim]
    # With zero frames the range below would be [0, -1] and the gather would
    # hit index -1 (an error on CPU, silently undefined values on GPU), so
    # fail early with a clear message instead.
    tf.debugging.assert_positive(
        t, message='uniform_temporal_subsample: the video has no frames.'
    )
    last_index = tf.cast(t - 1, tf.float32)
    indices = tf.linspace(0.0, last_index, num_samples)
    # Guard against floating-point overshoot before rounding to integers.
    indices = tf.clip_by_value(indices, 0.0, last_index)
    indices = tf.cast(tf.round(indices), tf.int32)
    return tf.gather(frames, indices, axis=temporal_dim)

tfrecord utility

def process_record(videos, labels):
    """Build a ``tf.train.SequenceExample`` for one video clip.

    Every frame is JPEG-encoded and appended to the ``'video_frames'``
    feature list. The frame count, frame height/width and class label are
    stored as int64 context features.

    Args:
        videos: Frame stack indexed as ``[frame, height, width, channels]``.
            Values are expected in ``[0, 255]``; they are converted to uint8.
        labels (int): Integer class id of the clip.

    Returns:
        tf.train.SequenceExample: The populated (not yet serialized) proto.
    """
    seq_example = tf.train.SequenceExample()
    # saturate_cast clamps into [0, 255] before casting, so out-of-range
    # floats cannot wrap around; in-range values convert as with tf.cast.
    videos = tf.saturate_cast(videos, dtype=tf.uint8)
    # Read the sizes from the static shape rather than slicing out videos[0].
    num_frames, height, width = videos.shape[:3]

    frames_feature_list = (
        seq_example.feature_lists.feature_list.get_or_create('video_frames')
    )
    for frame in videos:
        # One JPEG-encoded image per frame, stored as a bytes feature.
        jpeg_bytes = tf.io.encode_jpeg(frame).numpy()
        frames_feature_list.feature.add().bytes_list.value.append(jpeg_bytes)

    context_features = seq_example.context.feature
    context_features['video/num_frames'].int64_list.value.append(num_frames)
    context_features['video/frame/height'].int64_list.value.append(height)
    context_features['video/frame/width'].int64_list.value.append(width)
    context_features['video/class/label'].int64_list.value.append(labels)
    return seq_example

write methods

def write_tfrecord(
    paths, label_map, output_path, set_name, num_frames
):
    """Encode a list of video files into one GZIP-compressed TFRecord file.

    The file is written to ``<output_path>/<set_name>.tfrec``. Each video's
    class id is looked up in ``label_map`` using the name of the folder that
    directly contains the video file.

    Args:
        paths: Video file paths, each laid out as ``.../<class_name>/<file>``.
        label_map (dict): Maps class-folder name to integer class id.
        output_path (str): Directory the ``.tfrec`` file is written into.
        set_name (str): Split name, used as the file stem (e.g. ``'train'``).
        num_frames (int): Number of frames uniformly subsampled per video.

    Raises:
        KeyError: If a video's parent folder is not a key of ``label_map``.
    """
    tfrec_options = tf.io.TFRecordOptions(compression_type='GZIP')
    # os.path.join avoids a doubled separator when output_path ends in '/'.
    tfrec_path = os.path.join(output_path, f'{set_name}.tfrec')

    # NOTE(review): tqdm is used here but not imported in this snippet; the
    # surrounding script must provide it (``from tqdm import tqdm``).
    with tf.io.TFRecordWriter(tfrec_path, tfrec_options) as writer:
        for path in tqdm(paths, desc=f'Writing {set_name} set to TFRecord.'):
            # Get the class name (the parent folder), then its class id.
            class_label = os.path.basename(os.path.dirname(path))
            class_id = label_map[class_label]

            frames = read_video(path)
            # Subsampling keeps this demonstration small; in actual research
            # work every frame should be stored so that different frame
            # indices can be sampled at training time.
            frames = uniform_temporal_subsample(frames, num_frames, temporal_dim=0)

            # Encode the clip as a SequenceExample and append it to the file.
            example = process_record(frames, class_id)
            writer.write(example.SerializeToString())

A high-level function that gathers the necessary metadata (class labels and video file paths) and invokes the write_tfrecord function.

def run_process(
    dataset_path='train', num_frames=16, class_sampling=None,
    root_dir='/root/user',
):
    """Collect the videos of one split and write them to a TFRecord file.

    Steps:
      1. List the class folders under ``<root_dir>/<dataset_path>`` (sorted)
         and give each an integer id.
      2. Create ``tfoutput/`` and save the mapping to ``label_map.json``.
      3. Find video files (by extension, case-insensitively) anywhere below
         the split folder.
      4. Keep only files whose parent folder is one of the selected classes.
      5. Shuffle the file order when ``dataset_path == 'train'``.
      6. Write ``tfoutput/<dataset_path>.tfrec`` via ``write_tfrecord``.

    Args:
        dataset_path (str): Split folder name under ``root_dir``; also used
            as the output file stem. ``'train'`` enables shuffling.
        num_frames (int): Number of frames subsampled per video.
        class_sampling (int | None): If given, only the first
            ``class_sampling`` classes (in sorted order) are used.
        root_dir (str): Directory that contains the split folders.
    """
    # 1. os.listdir order is arbitrary; sorting makes the label ids
    #    deterministic and identical across the train and test runs. Plain
    #    files and hidden folders (e.g. '.ipynb_checkpoints') are not classes.
    video_dir = os.path.join(root_dir, dataset_path)
    class_folders = sorted(
        name for name in os.listdir(video_dir)
        if not name.startswith('.')
        and os.path.isdir(os.path.join(video_dir, name))
    )
    if class_sampling is not None:
        class_folders = class_folders[:class_sampling]
    class_folders_map = {label: i for i, label in enumerate(class_folders)}

    # 2. Persist the label mapping next to the outputs.
    output_dir = 'tfoutput'
    label_map_path = 'label_map.json'
    os.makedirs(output_dir, exist_ok=True)
    with open(label_map_path, 'w') as f:
        json.dump(class_folders_map, f)
    label_map = class_folders_map
    print('Found labels \n', label_map)

    # 3. Video files anywhere below the split folder; the extension check is
    #    case-insensitive so e.g. '.MP4' files are not silently skipped.
    formats = ('.mp4', '.avi', '.mkv', '.webm', '.mov')
    files = sorted(
        path
        for path in glob.glob(os.path.join(video_dir, '**', '*'), recursive=True)
        if os.path.isfile(path) and os.path.splitext(path)[1].lower() in formats
    )
    print('total files count ', len(files))

    # 4. Match on the parent folder exactly (the same rule write_tfrecord uses
    #    to look up the label), not on a substring of the whole path.
    selected_classes = set(class_folders)
    files = [
        f for f in files
        if os.path.basename(os.path.dirname(f)) in selected_classes
    ]
    print('total filtered_files count ', len(files))

    # 5
    if dataset_path == 'train':
        print('data set is for training - shuffle')
        np.random.shuffle(files)

    # 6
    write_tfrecord(
        files,
        label_map,
        output_dir,
        dataset_path,
        num_frames=num_frames,
    )
# Encode the training split first, then the test split, both with 16 frames
# per video and all classes.
for split in ('train', 'test'):
    run_process(dataset_path=split, num_frames=16, class_sampling=None)

image

Parsing TFRecord

After encoding, we can parse the .tfrec file in the following way.

# Shape parse_record reshapes every decoded clip to:
# [num_frames, input_size, input_size, 3].
num_frames, input_size = 16, 224


def parse_record(example_proto):
    """Decode one serialized SequenceExample produced by ``process_record``.

    Args:
        example_proto: Scalar string tensor holding one serialized
            ``tf.train.SequenceExample`` (e.g. an element of a
            ``tf.data.TFRecordDataset``).

    Returns:
        A ``(video_frames, video_labels)`` pair. ``video_frames`` is float32,
        reshaped to ``[num_frames, input_size, input_size, 3]`` using the
        module-level values, which must match those used when writing.
        ``video_labels`` is the class id as a float32 scalar.
    """
    # Clip-level metadata stored in the context.
    context_features = {
        'video/frame/height': tf.io.FixedLenFeature([], tf.int64),
        'video/frame/width': tf.io.FixedLenFeature([], tf.int64),
        'video/num_frames': tf.io.FixedLenFeature([], tf.int64),
        'video/class/label': tf.io.FixedLenFeature([], tf.int64),
    }
    # One JPEG-encoded string per frame.
    sequence_features = {
        'video_frames': tf.io.FixedLenSequenceFeature([], dtype=tf.string)
    }

    # Parse the serialized tf.train.SequenceExample.
    context_parsed, sequence_parsed = tf.io.parse_single_sequence_example(
        serialized=example_proto,
        context_features=context_features,
        sequence_features=sequence_features
    )

    # Decode every frame. fn_output_signature replaces tf.map_fn's
    # deprecated `dtype` argument.
    video_frames = tf.map_fn(
        lambda x: tf.io.decode_jpeg(x, channels=3),
        sequence_parsed['video_frames'],
        fn_output_signature=tf.uint8
    )
    # The label is already a scalar int64 tensor, so it is used directly.
    video_labels = context_parsed['video/class/label']
    # Fails at runtime if the record was written with a different frame
    # count or frame size.
    video_frames = tf.reshape(
        video_frames, [num_frames, input_size, input_size, 3]
    )
    video_frames = tf.cast(video_frames, dtype='float32')
    video_labels = tf.cast(video_labels, dtype='float32')

    return video_frames, video_labels
# run_process writes the training split to tfoutput/train.tfrec
# (output_dir='tfoutput'), so read it from that location.
raw_dataset = tf.data.TFRecordDataset(
    "tfoutput/train.tfrec", compression_type='GZIP'
)
parsed_dataset = raw_dataset.map(parse_record).shuffle(8).batch(4)
features, labels = next(iter(parsed_dataset))
features.shape, labels.shape
# (TensorShape([4, 16, 224, 224, 3]), TensorShape([4]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment