Last active
February 11, 2020 17:26
-
-
Save kiyoon/ae84ee3736c1350b20901bfb4a60d621 to your computer and use it in GitHub Desktop.
PyTorch video loader utilising GPU (CUDA) using NVIDIA DALI > 0.18.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from nvidia.dali.pipeline import Pipeline | |
from nvidia.dali.plugin import pytorch | |
import nvidia.dali.ops as ops | |
import nvidia.dali.types as types | |
import argparse | |
parser = argparse.ArgumentParser()

# Command-line options, registered from a spec table: (flag, kwargs).
_ARG_SPECS = [
    ('--file_list', dict(type=str, default='file_list.txt',
                         help='DALI file_list for VideoReader')),
    ('--frames', dict(type=int, default=3,
                      help='num frames in input sequence')),
    ('--crop_size', dict(type=int, nargs='+', default=[224, 224],
                         help='[height, width] for input crop')),
    ('--batchsize', dict(type=int, default=1,
                         help='per rank batch size')),
]
for _flag, _kwargs in _ARG_SPECS:
    parser.add_argument(_flag, **_kwargs)

args = parser.parse_args()
class VideoReaderPipeline(Pipeline):
    """DALI pipeline that decodes video sequences on GPU and applies a
    random crop plus a random horizontal flip.

    Per sample it produces: the transposed frame tensor, the label, the
    starting frame number, the crop positions, and the flip coin value
    (see ``define_graph`` return tuple).
    """

    def __init__(self, batch_size, sequence_length, num_threads, device_id, file_list, crop_size):
        super(VideoReaderPipeline, self).__init__(batch_size, num_threads, device_id, seed=12)
        # enable_frame_num=True makes the reader also emit the starting
        # frame index of each sequence (consumed as input[2] below).
        self.reader = ops.VideoReader(device="gpu", file_list=file_list,
                                      sequence_length=sequence_length, normalized=False,
                                      random_shuffle=True, image_type=types.RGB,
                                      dtype=types.UINT8, initial_fill=16,
                                      enable_frame_num=True)
        self.crop = ops.Crop(device="gpu", crop=crop_size, output_dtype=types.FLOAT)
        # Uniform [0, 1) samples drive the random crop anchor.
        self.uniform = ops.Uniform(range=(0.0, 1.0))
        self.coin = ops.CoinFlip(probability=0.5)
        # BUG FIX: define_graph() calls self.flip, but the original
        # __init__ never created it, raising AttributeError at build time.
        self.flip = ops.Flip(device="gpu")
        # Permute frame layout [F, H, W, C] -> [C, F, H, W] for PyTorch
        # 3D convolutions (assumed reader layout — TODO confirm for your
        # DALI version).
        self.transpose = ops.Transpose(device="gpu", perm=[3, 0, 1, 2])

    def define_graph(self):
        """Build the DALI graph; returns the tuple consumed by the iterator."""
        input = self.reader(name="Reader")
        crop_pos_x = self.uniform()
        crop_pos_y = self.uniform()
        cropped = self.crop(input[0], crop_pos_x=crop_pos_x, crop_pos_y=crop_pos_y)
        is_flipped = self.coin()
        flipped = self.flip(cropped, horizontal=is_flipped)
        output = self.transpose(flipped)
        # Change what you want from the dataloader.
        # input[1]: label, input[2]: starting frame number indexed from zero
        return output, input[1], input[2], crop_pos_x, crop_pos_y, is_flipped
class DALILoader():
    """Iterable wrapper exposing a VideoReaderPipeline through
    DALIGenericIterator with a PyTorch-friendly dict-per-batch interface.
    """

    def __init__(self, batch_size, file_list, sequence_length, crop_size):
        self.pipeline = VideoReaderPipeline(batch_size=batch_size,
                                            sequence_length=sequence_length,
                                            num_threads=2,
                                            device_id=0,
                                            file_list=file_list,
                                            crop_size=crop_size)
        self.pipeline.build()
        # Number of sequences one full pass over the reader yields.
        self.epoch_size = self.pipeline.epoch_size("Reader")
        # BUG FIX: VideoReaderPipeline.define_graph() returns SIX outputs
        # but the original output_map listed only five names, so the
        # iterator's mapping did not match the pipeline outputs.
        # "is_flipped" is appended to keep the map one-to-one.
        self.dali_iterator = pytorch.DALIGenericIterator(self.pipeline,
                                                         ["data", "label", "frame_num",
                                                          "crop_pos_x", "crop_pos_y",
                                                          "is_flipped"],
                                                         self.epoch_size,
                                                         auto_reset=True)

    def __len__(self):
        return int(self.epoch_size)

    def __iter__(self):
        return self.dali_iterator.__iter__()

    def __next__(self):
        return self.dali_iterator.__next__()
if __name__ == "__main__":
    # Smoke test: build a loader from the CLI args and inspect one batch.
    loader = DALILoader(args.batchsize,
                        args.file_list,
                        args.frames,
                        args.crop_size)
    batches = len(loader)
    batch = next(loader)

    sample = batch[0]
    print(sample['data'].shape)
    for field in ('label', 'frame_num', 'crop_pos_x', 'crop_pos_y'):
        print(sample[field])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment