Skip to content

Instantly share code, notes, and snippets.

@chenyaofo
Created November 14, 2022 13:16
Show Gist options
  • Save chenyaofo/bd4e830facb5d4db3444c741ef2e269f to your computer and use it in GitHub Desktop.
Save chenyaofo/bd4e830facb5d4db3444c741ef2e269f to your computer and use it in GitHub Desktop.
A benchmark for image preprocessing.
import time
import os
import warnings
import pathlib
try:
import nvidia.dali.types as types
import nvidia.dali.fn as fn
from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.tfrecord as tfrec
except ImportError:
warnings.warn("NVIDIA DALI library is unavailable, cannot load and preprocess dataset with DALI.")
try:
import webdataset as wds
except ImportError:
warnings.warn("Webdataset library is unavailable, cannot load dataset with webdataset.")
import torchvision.transforms as T
def glob_by_suffix(path, pattern):
    """Return the entries of *path* matching glob *pattern* as a sorted list of strings."""
    matches = [str(entry) for entry in pathlib.Path(path).glob(pattern)]
    # Sort lexicographically so the shard order is deterministic across runs.
    matches.sort()
    return matches
def get_train_transforms(crop_size=224, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), is_training=True):
    """Build the torchvision preprocessing pipeline for one image.

    Args:
        crop_size: Side length of the final crop.
        mean: Per-channel mean passed to ``T.Normalize``.
        std: Per-channel standard deviation passed to ``T.Normalize``.
        is_training: When true, use random-resized-crop plus horizontal flip;
            otherwise resize to ``int(crop_size / 7 * 8)`` (256 for the default
            224) and take a center crop.

    Returns:
        A ``T.Compose`` that ends with ``ToTensor`` and ``Normalize``.
    """
    # The defaults are tuples rather than lists so that no mutable object is
    # shared between calls.
    if is_training:
        steps = [
            T.RandomResizedCrop(crop_size),
            T.RandomHorizontalFlip(),
        ]
    else:
        steps = [
            T.Resize(int(crop_size / 7 * 8)),
            T.CenterCrop(crop_size),
        ]
    steps.append(T.ToTensor())
    steps.append(T.Normalize(mean=mean, std=std))
    return T.Compose(steps)
def build_wds_imagenet_loader(root, batch_size, num_workers):
    """Create a WebDataset loader over the ``*.tar`` shards found in *root*.

    Each sample is decoded to a PIL image, reduced to an ``("jpg", "cls")``
    tuple, and the image is run through ``get_train_transforms()`` while the
    label is passed through unchanged.
    """
    shard_paths = glob_by_suffix(pathlib.Path(root), "*.tar")
    image_transform = get_train_transforms()

    def keep_label(label):
        # Labels need no preprocessing.
        return label

    # No sample shuffling and no explicit dataset length are configured here.
    samples = wds.WebDataset(shard_paths)
    samples = samples.decode("pil")
    samples = samples.to_tuple("jpg", "cls")
    samples = samples.map_tuple(image_transform, keep_label)

    return wds.WebLoader(
        samples,
        batch_size=batch_size,
        num_workers=num_workers,
        persistent_workers=False,
        drop_last=False,
    )
def create_dali_pipeline(root, batch_size, num_workers,
                         image_size=224, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
                         dali_cpu=False):
    """Build a DALI pipeline that reads WebDataset shards and augments images.

    The pipeline reads ``jpg``/``cls`` pairs from every ``*.tar`` in *root*,
    decodes a random crop, resizes it to ``image_size`` x ``image_size``,
    randomly mirrors it and normalizes it to a CHW float tensor.

    Args:
        root: Directory containing the ``*.tar`` shards.
        batch_size: Pipeline batch size.
        num_workers: Passed to ``Pipeline`` as its second positional argument.
        image_size: Output height and width.
        mean: Per-channel mean on a 0-1 scale; multiplied by 255 before use.
        std: Per-channel standard deviation on a 0-1 scale; multiplied by 255
            before use.
        dali_cpu: When true, decode and resize on the CPU; otherwise use the
            ``mixed`` decoder and resize on the GPU. The final normalization
            always receives GPU data.

    Returns:
        The configured (not yet built) ``Pipeline`` whose outputs are
        ``(processed_images, labels)``.
    """
    pipe = Pipeline(batch_size, num_workers, device_id=0)
    with pipe:
        tars = glob_by_suffix(pathlib.Path(root), "*.tar")
        print(tars)
        images, labels = fn.readers.webdataset(
            paths=tars,
            ext=["jpg", "cls"],
            missing_component_behavior="error",
            shard_id=0,
            num_shards=1,
            random_shuffle=True,
            initial_fill=100,
            pad_last_batch=True,
            # If set to True, the Loader will use plain file I/O instead of
            # trying to map the file in memory. Mapping provides a small
            # performance benefit when accessing a local file system, but most
            # network file systems do not provide optimum performance.
            dont_use_mmap=True,
            prefetch_queue_depth=2,
            read_ahead=True,
            name="Reader")

        dali_device = 'cpu' if dali_cpu else 'gpu'
        decoder_device = 'cpu' if dali_cpu else 'mixed'
        use_mixed_decoder = decoder_device == 'mixed'

        # Ask nvJPEG to preallocate memory for the biggest sample in ImageNet
        # for CPU and GPU to avoid reallocations at runtime.
        device_memory_padding = 211025920 if use_mixed_decoder else 0
        host_memory_padding = 140544512 if use_mixed_decoder else 0
        # Ask HW NVJPEG to allocate memory ahead for the biggest image in the
        # data set to avoid reallocations at runtime.
        preallocate_width_hint = 5980 if use_mixed_decoder else 0
        preallocate_height_hint = 6430 if use_mixed_decoder else 0

        crops = fn.decoders.image_random_crop(
            images,
            device=decoder_device, output_type=types.RGB,
            device_memory_padding=device_memory_padding,
            host_memory_padding=host_memory_padding,
            preallocate_width_hint=preallocate_width_hint,
            preallocate_height_hint=preallocate_height_hint,
            random_aspect_ratio=[0.8, 1.25],
            random_area=[0.1, 1.0],
            num_attempts=100)
        resize_images = fn.resize(
            crops,
            device=dali_device,
            resize_x=image_size,
            resize_y=image_size,
            interp_type=types.INTERP_TRIANGULAR)
        mirror = fn.random.coin_flip(probability=0.5)

        # Honor the caller-supplied mean/std instead of hard-coded constants.
        # They are scaled by 255 exactly as the ImageNet defaults were.
        scaled_mean = [channel * 255 for channel in mean]
        scaled_std = [channel * 255 for channel in std]
        processed_images = fn.crop_mirror_normalize(
            resize_images.gpu(),
            dtype=types.FLOAT,
            output_layout="CHW",
            crop=(image_size, image_size),
            mean=scaled_mean,
            std=scaled_std,
            mirror=mirror)
        labels = labels.gpu()
        pipe.set_outputs(processed_images, labels)
    return pipe
class DALIWrapper:
    """Adapt a DALI iterator so it yields ``(inputs, targets)`` tuples.

    Each item of the wrapped iterator is expected to be a sequence whose first
    element is a mapping with ``"data"`` and ``"label"`` entries. The label has
    its trailing dimension squeezed and is converted with ``.long()``.
    """

    def __init__(self, daliiterator):
        self.daliiterator = daliiterator
        # Created lazily so that next() also works before iter() was called.
        self._iter = None

    def __iter__(self):
        # Start a fresh pass over the wrapped iterator on every iter() call.
        self._iter = iter(self.daliiterator)
        return self

    def __next__(self):
        if self._iter is None:
            self._iter = iter(self.daliiterator)
        batch = next(self._iter)[0]
        inputs = batch["data"]
        targets = batch["label"].squeeze(-1).long()
        return inputs, targets

    def __len__(self):
        return len(self.daliiterator)
def build_imagenet_dali_loader(root, batch_size, num_workers, dali_cpu):
    """Build the DALI pipeline for *root* and wrap it as an ``(inputs, targets)`` loader.

    Args:
        root: Directory containing the ``*.tar`` shards.
        batch_size: Batch size forwarded to the pipeline.
        num_workers: Worker count forwarded to the pipeline.
        dali_cpu: Forwarded to ``create_dali_pipeline`` to select CPU or GPU
            decoding.

    Returns:
        A ``DALIWrapper`` around a ``DALIGenericIterator`` that drops the last
        partial batch and resets itself automatically.
    """
    # dali_cpu must be passed by keyword: the fourth positional parameter of
    # create_dali_pipeline is image_size, so a positional pass would set the
    # image size to a bool and leave dali_cpu at its default.
    pipe = create_dali_pipeline(root, batch_size, num_workers, dali_cpu=dali_cpu)
    dali_iterator = DALIGenericIterator(
        pipe,
        output_map=["data", "label"],
        auto_reset=True,
        last_batch_policy=LastBatchPolicy.DROP,
        reader_name="Reader")
    return DALIWrapper(dali_iterator)
# Benchmark configuration, read from environment variables:
#   TYPE       - which loader to benchmark (torchvision, dali_cpu or dali_gpu)
#   ROOT       - directory containing the *.tar shards
#   NUMWORKERS - worker count handed to the selected loader
TYPE = os.environ.get("TYPE")
# Validate explicitly instead of with assert, which is stripped under -O.
if TYPE not in ("torchvision", "dali_cpu", "dali_gpu"):
    raise ValueError(f"TYPE must be one of torchvision, dali_cpu, dali_gpu; got {TYPE!r}")
ROOT = os.environ.get("ROOT")
if ROOT is None:
    raise ValueError("ROOT must be set to the directory containing the .tar shards")
BATCHSIZE = 256
if os.environ.get("NUMWORKERS") is None:
    raise ValueError("NUMWORKERS must be set to an integer worker count")
NUMWORKERS = int(os.environ["NUMWORKERS"])

if TYPE == "torchvision":
    dataloader = build_wds_imagenet_loader(ROOT, BATCHSIZE, NUMWORKERS)
elif TYPE == "dali_cpu":
    dataloader = build_imagenet_dali_loader(ROOT, BATCHSIZE, NUMWORKERS, dali_cpu=True)
elif TYPE == "dali_gpu":
    dataloader = build_imagenet_dali_loader(ROOT, BATCHSIZE, NUMWORKERS, dali_cpu=False)

device = "cuda:0"
# The first WARMUP_ITERS batches are loaded but excluded from the measurement.
WARMUP_ITERS = 10
total = 0
start = None
for iter_, (inputs, targets) in enumerate(dataloader, start=1):
    batch_size = inputs.shape[0]
    inputs = inputs.to(device=device, non_blocking=False)
    targets = targets.to(device=device, non_blocking=False)
    if iter_ < WARMUP_ITERS:
        continue
    if iter_ == WARMUP_ITERS:
        # Start timing once warm-up is finished, so that the time spent on
        # every counted batch (from WARMUP_ITERS + 1 on) is inside the window.
        start = time.perf_counter()
        continue
    total += batch_size
    print(f"iter={iter_:05d}, batch_size={batch_size:04d}, total={total:06d}")

if start is None or total == 0:
    raise RuntimeError(
        f"The loader produced no batches after the {WARMUP_ITERS} warm-up iterations; nothing to measure"
    )
end = time.perf_counter()
duration = end - start
print(f"Test end, {TYPE} with {NUMWORKERS} workers cost {duration:.2f}s to pre-process {total} images")
print(f"fps={(total/duration):.2f}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment