Skip to content

Instantly share code, notes, and snippets.

@nanguoyu
Last active July 18, 2024 12:16
Show Gist options
  • Save nanguoyu/7e0f6a06e2e44136ef9d2b447ab3128e to your computer and use it in GitHub Desktop.
Save nanguoyu/7e0f6a06e2e44136ef9d2b447ab3128e to your computer and use it in GitHub Desktop.
Fast image dataloader for Pytorch models
"""
@File : fast_torchvision_dataloader.py
@Author: Dong Wang
@Date : 2024/06/25
@Description : a fast image dataloader for PyTorch models. It tries to use FFCV to speed up data loading for vision tasks.
You first need to install FFCV in your Python environment and run prepare_ffcv_dataset.py to convert the datasets to FFCV format. If FFCV is not installed, load_data falls back to standard torchvision datasets.
"""
import os
from torch.utils.data import DataLoader
import torchvision.transforms as T
from torchvision import transforms, datasets
import numpy as np
import torch
import torchvision
# Per-dataset metadata: number of classes and number of input image channels.
# Not read in this module; presumably consumed by model-construction code elsewhere.
# Every dataset accepted by load_data (including SVHN) has an entry.
dataset_infor = {
    'FashionMNIST': {'num_classes': 10, 'num_channels': 1},
    'MNIST': {'num_classes': 10, 'num_channels': 1},
    'ImageNet': {'num_classes': 1000, 'num_channels': 3},
    'CIFAR10': {'num_classes': 10, 'num_channels': 3},
    'CIFAR100': {'num_classes': 100, 'num_channels': 3},
    'SVHN': {'num_classes': 10, 'num_channels': 3},
}
def ffcv_data(dataset_name, split, num_workers, batch_size, image_pipeline, label_pipeline,
              output_dir='./ffcv_datasets'):
    """Build an FFCV ``Loader`` over a pre-generated ``.beton`` file.

    The file is looked up at ``{output_dir}/{dataset_name}_{train|test}_ffcv.beton``,
    the same naming scheme that ``prepare_data`` in prepare_ffcv_dataset.py writes.

    Args:
        dataset_name: Dataset name used in the .beton file name (e.g. 'CIFAR10').
        split: 'train' selects the train file and random sample order; any other
            value selects the test file and sequential order.
        num_workers: Number of FFCV loader workers.
        batch_size: Number of samples per batch.
        image_pipeline: FFCV pipeline applied to the 'image' field.
        label_pipeline: FFCV pipeline applied to the 'label' field.
        output_dir: Directory containing the .beton files.

    Returns:
        An ``ffcv.loader.Loader`` yielding (image, label) batches.

    Raises:
        FileNotFoundError: If the expected .beton file does not exist.
        ImportError: If FFCV is not installed.
    """
    import os

    is_train = split == "train"
    sub = "train" if is_train else "test"
    beton_path = os.path.join(output_dir, f'{dataset_name}_{sub}_ffcv.beton')
    # Fail early with a message that points at the dataset-preparation step.
    if not os.path.isfile(beton_path):
        raise FileNotFoundError(
            f"FFCV dataset file not found: {beton_path}. "
            f"Generate it first with prepare_data(split={split!r}, dataset_name={dataset_name!r}, ...)."
        )

    from ffcv.loader import Loader, OrderOption

    return Loader(
        beton_path,
        batch_size=batch_size,
        num_workers=num_workers,
        # Training data is shuffled; evaluation data is read in a fixed order.
        order=OrderOption.RANDOM if is_train else OrderOption.SEQUENTIAL,
        distributed=False,
        pipelines={
            'image': image_pipeline,
            'label': label_pipeline,
        },
    )
# Per-channel normalization statistics (mean, std) on the [0, 1] pixel scale,
# for every dataset load_data supports.
_NORM_STATS = {
    'MNIST': ([0.1307], [0.3081]),
    'FashionMNIST': ([0.5], [0.5]),
    'SVHN': ([0.4377, 0.4438, 0.4728], [0.1980, 0.2010, 0.1970]),
    'CIFAR10': ([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
    'CIFAR100': ([0.5071, 0.4865, 0.4409], [0.2673, 0.2564, 0.2762]),
    'ImageNet': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
}

# Datasets for which load_data tries the FFCV path first; the others are PyTorch-only.
_FFCV_DATASETS = frozenset({'SVHN', 'CIFAR10', 'CIFAR100', 'ImageNet'})


def _ffcv_pipelines(dataset_name, split, device, mean, std):
    """Return (image_pipeline, label_pipeline) for an FFCV Loader.

    Raises ImportError when FFCV is not installed, which load_data uses as the
    signal to fall back to plain torchvision datasets.
    """
    from ffcv.transforms import (RandomHorizontalFlip, NormalizeImage, Squeeze,
                                 ToTorchImage, ToDevice, Convert, ToTensor)
    from ffcv.fields.rgb_image import (CenterCropRGBImageDecoder,
                                       RandomResizedCropRGBImageDecoder,
                                       SimpleRGBImageDecoder)
    from ffcv.fields.basics import IntDecoder

    label_pipeline = [
        IntDecoder(),
        ToTensor(),
        Squeeze(),
        ToDevice(device, non_blocking=True),
    ]
    is_train = split == 'train'
    # Images arrive as uint8 pixels, so the [0, 1]-scale statistics are rescaled to [0, 255].
    mean255 = np.array(mean) * 255
    std255 = np.array(std) * 255

    if dataset_name == 'ImageNet':
        if is_train:
            image_pipeline = [RandomResizedCropRGBImageDecoder((224, 224)), RandomHorizontalFlip()]
        else:
            image_pipeline = [CenterCropRGBImageDecoder((224, 224), ratio=224 / 256)]
        image_pipeline += [
            ToTensor(),
            ToDevice(device, non_blocking=True),
            ToTorchImage(),
            # Output dtype is float16 (presumably intended for mixed-precision training).
            NormalizeImage(mean255, std255, np.float16),
        ]
        return image_pipeline, label_pipeline

    # 32x32 datasets (SVHN / CIFAR10 / CIFAR100): decode, move to device, convert to
    # float, then apply torchvision ops to the batched tensor.
    image_pipeline = [
        SimpleRGBImageDecoder(),
        ToTensor(),
        ToDevice(device, non_blocking=True),
        ToTorchImage(),
        Convert(torch.float32),
    ]
    if is_train:
        # Augmentation only on the training split.
        # NOTE(review): these torchvision ops receive the whole batch tensor, so they
        # likely draw one random crop/flip/rotation per batch rather than per image.
        image_pipeline += [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
        ]
    image_pipeline.append(transforms.Normalize(mean255, std255))
    return image_pipeline, label_pipeline


def _torchvision_transforms(dataset_name, mean, std):
    """Return (train_transform, eval_transform) for the plain PyTorch path."""
    def normalize():
        return transforms.Normalize(mean=mean, std=std)

    if dataset_name in ('MNIST', 'FashionMNIST', 'SVHN'):
        # No augmentation; images are resized to 32x32.
        return (transforms.Compose([transforms.Resize(32), transforms.ToTensor(), normalize()]),
                transforms.Compose([transforms.Resize(32), transforms.ToTensor(), normalize()]))
    if dataset_name in ('CIFAR10', 'CIFAR100'):
        train_tf = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            normalize(),
        ])
        return train_tf, transforms.Compose([transforms.ToTensor(), normalize()])
    # ImageNet, following
    # https://github.com/pytorch/examples/blob/a38cbfc6f817d9015fc67a6309d3a0be9ff94ab6/imagenet/main.py#L239-L255
    # Training uses random-resized crops + flips (like the FFCV path); evaluation is deterministic.
    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize(),
    ])
    eval_tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize(),
    ])
    return train_tf, eval_tf


def _torchvision_dataset(dataset_name, split, datadir, transform):
    """Instantiate the torchvision dataset for ``split`` using each class's constructor convention."""
    get_dataset = getattr(datasets, dataset_name)
    is_train = split == 'train'
    if dataset_name == 'SVHN':
        return get_dataset(root=datadir, split='train' if is_train else 'test',
                           download=True, transform=transform)
    if dataset_name == 'ImageNet':
        # Called without download=True: datadir must already contain the ImageNet files.
        return get_dataset(root=datadir, split='train' if is_train else 'val', transform=transform)
    return get_dataset(root=datadir, train=is_train, download=True, transform=transform)


def load_data(split, dataset_name, datadir, batch_size, shuffle, device, num_workers=4):
    """Return a data loader for ``dataset_name``, using FFCV when it is available.

    For SVHN, CIFAR10, CIFAR100 and ImageNet an FFCV loader is built first; if FFCV
    cannot be imported, a torchvision dataset wrapped in a ``DataLoader`` is used
    instead. MNIST and FashionMNIST always use the PyTorch path.
    See https://gist.github.com/weiaicunzai/e623931921efefd4c331622c344d8151

    Args:
        split: 'train' selects the training split with augmentation; any other value
            selects the evaluation split ('test', or 'val' for ImageNet).
        dataset_name: One of 'MNIST', 'FashionMNIST', 'SVHN', 'CIFAR10', 'CIFAR100', 'ImageNet'.
        datadir: Root directory for torchvision datasets (PyTorch path only).
        batch_size: Number of samples per batch.
        shuffle: Whether to shuffle (PyTorch path only; on the FFCV path the order is
            fixed by ``split`` inside ffcv_data).
        device: Device FFCV moves batches to (FFCV path only).
        num_workers: Number of loader workers.

    Returns:
        An FFCV ``Loader`` or a ``torch.utils.data.DataLoader``.

    Raises:
        NotImplementedError: If ``dataset_name`` is not supported.
    """
    # TODO: support `nchannels`
    if dataset_name not in _NORM_STATS:
        raise NotImplementedError(f"Non-supported dataset: {dataset_name}")
    mean, std = _NORM_STATS[dataset_name]

    if dataset_name in _FFCV_DATASETS:
        # Keep the try minimal so only a missing FFCV install triggers the fallback.
        try:
            image_pipeline, label_pipeline = _ffcv_pipelines(dataset_name, split, device, mean, std)
        except ImportError:
            pass  # FFCV not installed: fall through to the PyTorch loader below.
        else:
            data_loader = ffcv_data(dataset_name, split, num_workers, batch_size,
                                    image_pipeline, label_pipeline)
            print("Using FFCV dataset.")
            return data_loader

    train_tf, eval_tf = _torchvision_transforms(dataset_name, mean, std)
    transform = train_tf if split == 'train' else eval_tf
    dataset = _torchvision_dataset(dataset_name, split, datadir, transform)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    print("Using PyTorch dataset.")
    return data_loader
"""
@File : prepare_ffcv_dataset.py
@Author: Dong Wang
@Date : 2024/06/25
@Description : prepare FFCV dataset files. You first need to install FFCV in your environment: https://github.com/libffcv/ffcv
"""
import os
from torch.utils.data import DataLoader
import torchvision.transforms as T
from torchvision import transforms, datasets
import numpy as np
import torch
import torchvision
# Per-dataset metadata: number of classes and number of input image channels.
# Not read in this module; kept in sync with the loader module's table.
# SVHN is included because prepare_data supports it.
dataset_infor = {
    'FashionMNIST': {'num_classes': 10, 'num_channels': 1},
    'MNIST': {'num_classes': 10, 'num_channels': 1},
    'ImageNet': {'num_classes': 1000, 'num_channels': 3},
    'CIFAR10': {'num_classes': 10, 'num_channels': 3},
    'CIFAR100': {'num_classes': 100, 'num_channels': 3},
    'SVHN': {'num_classes': 10, 'num_channels': 3},
}
def prepare_data(split, dataset_name, datadir, output_dir='./ffcv_datasets'):
    """Convert one split of a torchvision dataset into an FFCV ``.beton`` file.

    The file is written to ``{output_dir}/{dataset_name}_{train|test}_ffcv.beton``
    with an 'image' field (RGBImageField, 'smart' write mode, JPEG quality 90) and
    a 'label' field (IntField).

    Args:
        split: 'train' converts the training split; any other value converts the
            evaluation split ('test', or 'val' for ImageNet).
        dataset_name: Name of a ``torchvision.datasets`` class, e.g. 'CIFAR10',
            'CIFAR100', 'SVHN' or 'ImageNet'. Datasets other than SVHN/ImageNet
            are constructed with the ``train=`` keyword.
        datadir: Dataset root directory. SVHN and ``train=``-style datasets are
            downloaded there if missing; ImageNet must already be present.
        output_dir: Directory where the .beton file is written (created if needed).

    Returns:
        The path of the written .beton file.
    """
    import os
    from ffcv.fields import IntField, RGBImageField
    from ffcv.writer import DatasetWriter

    get_dataset = getattr(datasets, dataset_name)
    os.makedirs(output_dir, exist_ok=True)
    is_train = split == "train"
    sub = "train" if is_train else "test"
    output_file = os.path.join(output_dir, f'{dataset_name}_{sub}_ffcv.beton')

    if dataset_name == 'SVHN':
        max_resolution = 32
        dataset = get_dataset(root=datadir, split='train' if is_train else 'test', download=True)
    elif dataset_name == 'ImageNet':
        # Larger stored resolution so the loader can take 224px crops.
        max_resolution = 256
        dataset = get_dataset(root=datadir, split='train' if is_train else 'val')
    else:
        max_resolution = 32
        dataset = get_dataset(root=datadir, train=is_train, download=True)

    write_config = {
        'image': RGBImageField(write_mode='smart', max_resolution=max_resolution, jpeg_quality=90),
        'label': IntField(),
    }
    writer = DatasetWriter(output_file, write_config)
    writer.from_indexed_dataset(dataset)
    return output_file
# Uncomment the calls below to generate the FFCV datasets before using them for training.
# CIFAR10
# prepare_data(split="train", dataset_name="CIFAR10", datadir="datasets")
# prepare_data(split="test", dataset_name="CIFAR10", datadir="datasets")
# CIFAR100
# prepare_data(split="train", dataset_name="CIFAR100", datadir="datasets")
# prepare_data(split="test", dataset_name="CIFAR100", datadir="datasets")
# For ImageNet, `~/data/ImageNet` should be a folder containing files ILSVRC2012_devkit_t12.tar.gz, ILSVRC2012_img_train.tar, ILSVRC2012_img_val.tar
# prepare_data(split="train", dataset_name="ImageNet", datadir="~/data/ImageNet")
# prepare_data(split="test", dataset_name="ImageNet", datadir="~/data/ImageNet")
@nanguoyu
Copy link
Author

nanguoyu commented Jun 25, 2024

How to use

0. Install FFCV

Please refer to https://github.com/libffcv/ffcv

1. Prepare dataset for FFCV

In the prepare_ffcv_dataset.py

Uncomment the line for your dataset and split (train or test), for example:

prepare_data(split="train", dataset_name="ImageNet", datadir="~/data/ImageNet")

Then run the data-preparation script once for your dataset; afterwards the generated FFCV file can be used for training.

python prepare_ffcv_dataset.py

2. Use it in your training/testing code.

from fast_torchvision_dataloader import load_data

data_loader = load_data("train", args.dataset, datadir=args.datadir, batch_size=args.batch_size, shuffle=True,device=device,num_workers=4)

# rest of training / validation proceeds identically
for epoch in range(epochs):
    for img, label in data_loader:
          #train or test your model
          ...

@nanguoyu
Copy link
Author

Install Mamba

Mamba is a faster drop-in replacement for conda; the Miniforge installer below ships with it.

curl -L -O "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
bash Miniforge3-$(uname)-$(uname -m).sh
conda init zsh

Install Packages

Note: FFCV is only available for Python < 3.12.

mamba create -y -n myenv python=3.11 opencv pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
mamba activate myenv
mamba install  pandas scipy tqdm cupy pkg-config libjpeg-turbo numba -y
pip3 install torchopt ultralytics-thop wandb timm==0.9.16 torchist==0.2.3 matplotlib
pip3 install ffcv

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment