Last active
July 18, 2024 12:16
-
-
Save nanguoyu/7e0f6a06e2e44136ef9d2b447ab3128e to your computer and use it in GitHub Desktop.
Fast image dataloader for Pytorch models
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
@File : fast_torchvision_dataloader.py | |
@Author: Dong Wang | |
@Date : 2024/06/25 | |
@Description : a fast image dataloader for Pytorch models. It tries to use FFCV to speed up your dataloader for vision tasks. | |
You need first install FFCV in your Python ENV and run prepare_ffcv_dataset.py to prepare datasets in FFCV. | |
""" | |
import os | |
from torch.utils.data import DataLoader | |
import torchvision.transforms as T | |
from torchvision import transforms, datasets | |
import numpy as np | |
import torch | |
import torchvision | |
# Per-dataset metadata used by model builders: number of target classes and
# input image channels. SVHN was missing even though load_data supports it.
dataset_infor = {
    'FashionMNIST': {'num_classes': 10,   'num_channels': 1},
    'MNIST':        {'num_classes': 10,   'num_channels': 1},
    'ImageNet':     {'num_classes': 1000, 'num_channels': 3},
    'CIFAR10':      {'num_classes': 10,   'num_channels': 3},
    'CIFAR100':     {'num_classes': 100,  'num_channels': 3},
    'SVHN':         {'num_classes': 10,   'num_channels': 3},
}
def ffcv_data(dataset_name, split, num_workers, batch_size, image_pipeline, label_pipeline):
    """Build an FFCV ``Loader`` over a pre-generated ``.beton`` dataset file.

    The file must have been produced beforehand by ``prepare_ffcv_dataset.py``;
    it is looked up at ``./ffcv_datasets/{dataset_name}_{train|test}_ffcv.beton``.

    Args:
        dataset_name: torchvision-style dataset name, e.g. 'CIFAR10'.
        split: 'train' selects the train file and random ordering; anything
            else selects the test file and sequential ordering.
        num_workers: number of loader worker processes.
        batch_size: samples per batch.
        image_pipeline: list of FFCV operations applied to the 'image' field.
        label_pipeline: list of FFCV operations applied to the 'label' field.

    Returns:
        ffcv.loader.Loader yielding (image, label) batches.

    Raises:
        FileNotFoundError: if the expected .beton file does not exist.
    """
    import os

    output_dir = './ffcv_datasets'
    sub = "train" if split == "train" else "test"
    output_file = f'{output_dir}/{dataset_name}_{sub}_ffcv.beton'
    # Fail early with an actionable message instead of an opaque FFCV error
    # (this resolves the old "todo: check file exists").
    if not os.path.exists(output_file):
        raise FileNotFoundError(
            f"FFCV dataset file '{output_file}' not found. "
            f"Run prepare_ffcv_dataset.py for {dataset_name}/{sub} first.")

    from ffcv.loader import Loader, OrderOption
    data_loader = Loader(
        output_file,
        batch_size=batch_size,
        num_workers=num_workers,
        # Shuffle only during training; keep evaluation deterministic.
        order=OrderOption.RANDOM if split == "train" else OrderOption.SEQUENTIAL,
        distributed=False,
        pipelines={
            'image': image_pipeline,
            'label': label_pipeline,
        },
    )
    return data_loader
def load_data(split, dataset_name, datadir, batch_size, shuffle, device, num_workers=4):
    """Build a data loader for ``dataset_name``, preferring FFCV when installed.

    For SVHN / CIFAR10 / CIFAR100 / ImageNet this first tries an FFCV loader
    that reads pre-generated ``.beton`` files (see prepare_ffcv_dataset.py);
    if FFCV is not importable it falls back to a plain torchvision dataset
    wrapped in a PyTorch ``DataLoader``. MNIST and FashionMNIST always use
    the torchvision path.

    Normalization statistics follow
    https://gist.github.com/weiaicunzai/e623931921efefd4c331622c344d8151

    Args:
        split: 'train' or 'test'.
        dataset_name: one of MNIST, FashionMNIST, SVHN, CIFAR10, CIFAR100,
            ImageNet.
        datadir: root directory for torchvision datasets (fallback path only).
        batch_size: samples per batch.
        shuffle: shuffle flag for the torchvision fallback ``DataLoader``
            (the FFCV path shuffles via ``OrderOption.RANDOM`` on 'train').
        device: torch device that FFCV pipelines move batches to.
        num_workers: loader worker processes.

    Returns:
        An iterable loader yielding (image, label) batches.

    Raises:
        NotImplementedError: for unsupported dataset names.
    """
    # todo: support `nchannels`
    get_dataset = getattr(datasets, dataset_name)
    if dataset_name == 'MNIST':
        mean, std = [0.1307], [0.3081]
        normalize = transforms.Normalize(mean=mean, std=std)
        # Resize 28x28 digits to 32x32 so common CNN architectures apply.
        tr_transform = transforms.Compose([transforms.Resize(32), transforms.ToTensor(), normalize])
        val_transform = transforms.Compose([transforms.Resize(32), transforms.ToTensor(), normalize])
        if split == 'train':
            dataset = get_dataset(root=datadir, train=True, download=True, transform=tr_transform)
        else:
            dataset = get_dataset(root=datadir, train=False, download=True, transform=val_transform)
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
        print("Using PyTorch dataset.")
    elif dataset_name == 'SVHN':
        mean = [0.4377, 0.4438, 0.4728]
        std = [0.1980, 0.2010, 0.1970]
        try:
            from ffcv.transforms import Squeeze, ToTorchImage, ToDevice, Convert, ToTensor
            from ffcv.fields.rgb_image import SimpleRGBImageDecoder
            from ffcv.fields.basics import IntDecoder
            label_pipeline = [
                IntDecoder(),
                ToTensor(),
                Squeeze(),
                ToDevice(device, non_blocking=True),
            ]
            # FFCV stores uint8 images; normalization constants are scaled by
            # 255 because Convert produces float values in [0, 255].
            if split == 'train':
                image_pipeline = [
                    SimpleRGBImageDecoder(),
                    ToTensor(),
                    ToDevice(device, non_blocking=True),
                    ToTorchImage(),
                    Convert(torch.float),
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomRotation(15),
                    torchvision.transforms.Normalize(np.array(mean) * 255, np.array(std) * 255),
                ]
            elif split == 'test':
                image_pipeline = [
                    SimpleRGBImageDecoder(),
                    ToTensor(),
                    ToDevice(device, non_blocking=True),
                    ToTorchImage(),
                    Convert(torch.float),
                    torchvision.transforms.Normalize(np.array(mean) * 255, np.array(std) * 255),
                ]
            data_loader = ffcv_data(dataset_name, split, num_workers, batch_size, image_pipeline, label_pipeline)
            print("Using FFCV dataset.")
        except ImportError:
            normalize = transforms.Normalize(mean=mean, std=std)
            tr_transform = transforms.Compose([transforms.Resize(32), transforms.ToTensor(), normalize])
            val_transform = transforms.Compose([transforms.Resize(32), transforms.ToTensor(), normalize])
            # SVHN uses split= rather than train= in torchvision.
            if split == 'train':
                dataset = get_dataset(root=datadir, split='train', download=True, transform=tr_transform)
            else:
                dataset = get_dataset(root=datadir, split='test', download=True, transform=val_transform)
            data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
            print("Using PyTorch dataset.")
    elif dataset_name == 'CIFAR10':
        mean = [0.4914, 0.4822, 0.4465]
        std = [0.2470, 0.2435, 0.2616]
        try:
            from ffcv.transforms import Squeeze, ToTorchImage, ToDevice, Convert, ToTensor
            from ffcv.fields.rgb_image import SimpleRGBImageDecoder
            from ffcv.fields.basics import IntDecoder
            label_pipeline = [
                IntDecoder(),
                ToTensor(),
                Squeeze(),
                ToDevice(device, non_blocking=True),
            ]
            if split == 'train':
                # Augmentation (crop / flip / rotation) only on the train split.
                image_pipeline = [
                    SimpleRGBImageDecoder(),
                    ToTensor(),
                    ToDevice(device, non_blocking=True),
                    ToTorchImage(),
                    Convert(torch.float32),
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomRotation(15),
                    torchvision.transforms.Normalize(np.array(mean) * 255, np.array(std) * 255),
                ]
            elif split == 'test':
                image_pipeline = [
                    SimpleRGBImageDecoder(),
                    ToTensor(),
                    ToDevice(device, non_blocking=True),
                    ToTorchImage(),
                    Convert(torch.float32),
                    torchvision.transforms.Normalize(np.array(mean) * 255, np.array(std) * 255),
                ]
            data_loader = ffcv_data(dataset_name, split, num_workers, batch_size, image_pipeline, label_pipeline)
            print("Using FFCV dataset.")
        except ImportError:
            normalize = transforms.Normalize(mean=mean, std=std)
            tr_transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(15),
                transforms.ToTensor(),
                normalize,
            ])
            val_transform = transforms.Compose([transforms.ToTensor(), normalize])
            if split == 'train':
                dataset = get_dataset(root=datadir, train=True, download=True, transform=tr_transform)
            else:
                dataset = get_dataset(root=datadir, train=False, download=True, transform=val_transform)
            data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
            print("Using PyTorch dataset.")
    elif dataset_name == 'CIFAR100':
        mean = [0.5071, 0.4865, 0.4409]
        std = [0.2673, 0.2564, 0.2762]
        try:
            from ffcv.transforms import Squeeze, ToTorchImage, ToDevice, Convert, ToTensor
            from ffcv.fields.rgb_image import SimpleRGBImageDecoder
            from ffcv.fields.basics import IntDecoder
            label_pipeline = [
                IntDecoder(),
                ToTensor(),
                Squeeze(),
                ToDevice(device, non_blocking=True),
            ]
            # BUGFIX: the random crop/flip/rotation augmentation was previously
            # applied to the *test* pipeline and omitted from *train*; it now
            # matches the CIFAR10 branch (augment train, deterministic test).
            if split == 'train':
                image_pipeline = [
                    SimpleRGBImageDecoder(),
                    ToTensor(),
                    ToDevice(device, non_blocking=True),
                    ToTorchImage(),
                    Convert(torch.float),
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomRotation(15),
                    torchvision.transforms.Normalize(np.array(mean) * 255, np.array(std) * 255),
                ]
            elif split == 'test':
                image_pipeline = [
                    SimpleRGBImageDecoder(),
                    ToTensor(),
                    ToDevice(device, non_blocking=True),
                    ToTorchImage(),
                    Convert(torch.float),
                    torchvision.transforms.Normalize(np.array(mean) * 255, np.array(std) * 255),
                ]
            data_loader = ffcv_data(dataset_name, split, num_workers, batch_size, image_pipeline, label_pipeline)
            print("Using FFCV dataset.")
        except ImportError:
            normalize = transforms.Normalize(mean=mean, std=std)
            # https://github.com/pytorch/examples/blob/a38cbfc6f817d9015fc67a6309d3a0be9ff94ab6/imagenet/main.py#L239-L255
            tr_transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(15),
                transforms.ToTensor(),
                normalize,
            ])
            val_transform = transforms.Compose([transforms.ToTensor(), normalize])
            if split == 'train':
                dataset = get_dataset(root=datadir, train=True, download=True, transform=tr_transform)
            else:
                dataset = get_dataset(root=datadir, train=False, download=True, transform=val_transform)
            data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
            print("Using PyTorch dataset.")
    elif dataset_name == 'ImageNet':
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        try:
            from ffcv.transforms import RandomHorizontalFlip, NormalizeImage, Squeeze, ToTorchImage, ToDevice, ToTensor
            from ffcv.fields.rgb_image import CenterCropRGBImageDecoder, RandomResizedCropRGBImageDecoder
            from ffcv.fields.basics import IntDecoder
            label_pipeline = [
                IntDecoder(),
                ToTensor(),
                Squeeze(),
                ToDevice(device, non_blocking=True),
            ]
            if split == 'train':
                image_pipeline = [
                    RandomResizedCropRGBImageDecoder((224, 224)),
                    RandomHorizontalFlip(),
                    ToTensor(),
                    ToDevice(device, non_blocking=True),
                    ToTorchImage(),
                    # NOTE(review): normalizes to float16 — assumes downstream
                    # models run in half precision; confirm against training code.
                    NormalizeImage(np.array(mean) * 255, np.array(std) * 255, np.float16),
                ]
            elif split == 'test':
                image_pipeline = [
                    # Standard eval protocol: resize so crop/ratio=256, center-crop 224.
                    CenterCropRGBImageDecoder((224, 224), ratio=224 / 256),
                    ToTensor(),
                    ToDevice(device, non_blocking=True),
                    ToTorchImage(),
                    NormalizeImage(np.array(mean) * 255, np.array(std) * 255, np.float16),
                ]
            data_loader = ffcv_data(dataset_name, split, num_workers, batch_size, image_pipeline, label_pipeline)
            print("Using FFCV dataset.")
        except ImportError:
            normalize = transforms.Normalize(mean=mean, std=std)
            # https://github.com/pytorch/examples/blob/a38cbfc6f817d9015fc67a6309d3a0be9ff94ab6/imagenet/main.py#L239-L255
            tr_transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
            # BUGFIX: the validation transform previously included
            # RandomHorizontalFlip; evaluation must be deterministic.
            val_transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])
            # torchvision ImageNet has no download=; data must already exist.
            if split == 'train':
                dataset = get_dataset(root=datadir, split='train', transform=tr_transform)
            else:
                dataset = get_dataset(root=datadir, split='val', transform=val_transform)
            data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
            print("Using PyTorch dataset.")
    elif dataset_name == 'FashionMNIST':
        mean = [0.5]
        std = [0.5]
        normalize = transforms.Normalize(mean=mean, std=std)
        tr_transform = transforms.Compose([transforms.Resize(32), transforms.ToTensor(), normalize])
        val_transform = transforms.Compose([transforms.Resize(32), transforms.ToTensor(), normalize])
        if split == 'train':
            dataset = get_dataset(root=datadir, train=True, download=True, transform=tr_transform)
        else:
            dataset = get_dataset(root=datadir, train=False, download=True, transform=val_transform)
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
        print("Using PyTorch dataset.")
    else:
        # Include the offending name (the old f-string had no placeholder).
        raise NotImplementedError(f"Non-supported dataset: {dataset_name}")
    return data_loader
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
@File : prepare_ffcv_dataset.py | |
@Author: Dong Wang | |
@Date : 2024/06/25 | |
@Description : prepare FFCV dataset files. You need first install FFCV in your environment: https://github.com/libffcv/ffcv | |
""" | |
import os | |
from torch.utils.data import DataLoader | |
import torchvision.transforms as T | |
from torchvision import transforms, datasets | |
import numpy as np | |
import torch | |
import torchvision | |
# Per-dataset metadata (classes / channels), kept in sync with
# fast_torchvision_dataloader.py; SVHN added since prepare_data supports it.
dataset_infor = {
    'FashionMNIST': {'num_classes': 10,   'num_channels': 1},
    'MNIST':        {'num_classes': 10,   'num_channels': 1},
    'ImageNet':     {'num_classes': 1000, 'num_channels': 3},
    'CIFAR10':      {'num_classes': 10,   'num_channels': 3},
    'CIFAR100':     {'num_classes': 100,  'num_channels': 3},
    'SVHN':         {'num_classes': 10,   'num_channels': 3},
}
def prepare_data(split, dataset_name, datadir):
    """Serialize one split of a torchvision dataset into an FFCV .beton file.

    The output is written to ``./ffcv_datasets/{dataset_name}_{train|test}_ffcv.beton``
    and is what fast_torchvision_dataloader.py's FFCV path reads.

    Args:
        split: 'train' writes the train split; anything else writes the
            test/val split (named 'test' in the output file).
        dataset_name: torchvision dataset class name, e.g. 'CIFAR10'.
        datadir: root directory where the torchvision dataset lives or is
            downloaded to.
    """
    get_dataset = getattr(datasets, dataset_name)
    from ffcv.fields import IntField, RGBImageField
    from ffcv.writer import DatasetWriter

    output_dir = './ffcv_datasets'
    os.makedirs(output_dir, exist_ok=True)
    sub = "train" if split == "train" else "test"
    output_file = f'{output_dir}/{dataset_name}_{sub}_ffcv.beton'
    if dataset_name == 'SVHN':
        # SVHN selects splits via split=, not train=.
        image_field = RGBImageField(write_mode='smart', max_resolution=32, jpeg_quality=90)
        if split == 'train':
            dataset = get_dataset(root=datadir, split='train', download=True)
        else:
            dataset = get_dataset(root=datadir, split='test', download=True)
    elif dataset_name == 'ImageNet':
        # ImageNet cannot be auto-downloaded; the eval split is called 'val'
        # but is still written to the *_test_ffcv.beton file the loader expects.
        image_field = RGBImageField(write_mode='smart', max_resolution=256, jpeg_quality=90)
        if split == 'train':
            dataset = get_dataset(root=datadir, split='train')
        else:
            dataset = get_dataset(root=datadir, split='val')
    else:
        # NOTE(review): grayscale datasets (MNIST/FashionMNIST) are also routed
        # through RGBImageField here — confirm FFCV handles 1-channel images
        # as intended before relying on this path.
        image_field = RGBImageField(write_mode='smart', max_resolution=32, jpeg_quality=90)
        if split == 'train':
            dataset = get_dataset(root=datadir, train=True, download=True)
        else:
            dataset = get_dataset(root=datadir, train=False, download=True)
    write_config = {
        'image': image_field,
        'label': IntField(),
    }
    writer = DatasetWriter(output_file, write_config)
    writer.from_indexed_dataset(dataset)
# Now you can generate FFCV dataset before use it for training. | |
# CIFAR10 | |
# prepare_data(split="train", dataset_name="CIFAR10", datadir="datasets") | |
# prepare_data(split="test", dataset_name="CIFAR10", datadir="datasets") | |
# CIFAR100 | |
# prepare_data(split="train", dataset_name="CIFAR100", datadir="datasets") | |
# prepare_data(split="test", dataset_name="CIFAR100", datadir="datasets") | |
# For ImageNet, `~/data/ImageNet` should be a folder containing files ILSVRC2012_devkit_t12.tar.gz, ILSVRC2012_img_train.tar, ILSVRC2012_img_val.tar | |
# prepare_data(split="train", dataset_name="ImageNet", datadir="~/data/ImageNet") | |
# prepare_data(split="test", dataset_name="ImageNet", datadir="~/data/ImageNet") |
Install Mamba
Mamba is a smarter version of Miniconda
curl -L -O "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
bash Miniforge3-$(uname)-$(uname -m).sh
conda init zsh
Install Packages
Note: FFCV is only available for Python < 3.12
mamba create -y -n myenv python=3.11 opencv pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
mamba activate myenv
mamba install pandas scipy tqdm cupy pkg-config libjpeg-turbo numba -y
pip3 install torchopt ultralytics-thop wandb timm==0.9.16 torchist==0.2.3 matplotlib torchopt
pip3 install ffcv
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
How to use
0. Install FFCV
Please ref to https://github.com/libffcv/ffcv
1. Prepare dataset for FFCV
In the
prepare_ffcv_dataset.py
Uncomment a line for your dataset and split (train or test)
Then, run the data-preparation script for your dataset once. Then you can use it for training.
2. Use it in your training/ testing code.