Last active
July 18, 2024 12:16
-
-
Save nanguoyu/7e0f6a06e2e44136ef9d2b447ab3128e to your computer and use it in GitHub Desktop.
Fast image dataloader for Pytorch models
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
@File : fast_torchvision_dataloader.py | |
@Author: Dong Wang | |
@Date : 2024/06/25 | |
@Description : a fast image dataloader for Pytorch models. It tries to use FFCV to speed up your dataloader for vision tasks. | |
You need first install FFCV in your Python ENV and run prepare_ffcv_dataset.py to prepare datasets in FFCV. | |
""" | |
import os | |
from torch.utils.data import DataLoader | |
import torchvision.transforms as T | |
from torchvision import transforms, datasets | |
import numpy as np | |
import torch | |
import torchvision | |
# Per-dataset metadata: number of target classes and input image channels.
# NOTE(review): SVHN is loadable below but has no entry here — confirm whether
# callers of this table need SVHN ({'num_classes': 10, 'num_channels': 3}).
dataset_infor = {
    'FashionMNIST': {'num_classes': 10, 'num_channels': 1},
    'MNIST': {'num_classes': 10, 'num_channels': 1},
    'ImageNet': {'num_classes': 1000, 'num_channels': 3},
    'CIFAR10': {'num_classes': 10, 'num_channels': 3},
    'CIFAR100': {'num_classes': 100, 'num_channels': 3},
}
def ffcv_data(dataset_name, split, num_workers, batch_size, image_pipeline, label_pipeline):
    """Build an FFCV Loader over a pre-written ``.beton`` dataset file.

    Args:
        dataset_name: dataset identifier used in the ``.beton`` filename.
        split: 'train' selects the train file and random ordering; any other
            value selects the test file with sequential ordering.
        num_workers: worker count forwarded to the FFCV Loader.
        batch_size: batch size forwarded to the FFCV Loader.
        image_pipeline: FFCV transform pipeline for the 'image' field.
        label_pipeline: FFCV transform pipeline for the 'label' field.

    Returns:
        An ``ffcv.loader.Loader`` yielding (image, label) batches.

    Raises:
        FileNotFoundError: if the ``.beton`` file has not been generated yet
            (run prepare_ffcv_dataset.py first).
    """
    output_dir = './ffcv_datasets'
    sub = 'train' if split == 'train' else 'test'
    output_file = f'{output_dir}/{dataset_name}_{sub}_ffcv.beton'
    # Fail early with a clear message instead of an opaque FFCV read error
    # (resolves the old "todo: check file exists").
    if not os.path.exists(output_file):
        raise FileNotFoundError(
            f"{output_file} not found; run prepare_ffcv_dataset.py to generate it."
        )
    from ffcv.loader import Loader, OrderOption
    return Loader(
        output_file,
        batch_size=batch_size,
        num_workers=num_workers,
        # Shuffle only the training split; keep evaluation deterministic.
        order=OrderOption.RANDOM if split == 'train' else OrderOption.SEQUENTIAL,
        distributed=False,
        pipelines={
            'image': image_pipeline,
            'label': label_pipeline,
        },
    )
def _ffcv_label_pipeline(device):
    """Standard FFCV label pipeline: int decode -> tensor -> squeeze -> device."""
    from ffcv.transforms import Squeeze, ToDevice, ToTensor
    from ffcv.fields.basics import IntDecoder
    return [
        IntDecoder(),
        ToTensor(),
        Squeeze(),
        ToDevice(device, non_blocking=True),
    ]


def _ffcv_32px_image_pipeline(mean, std, device, augment):
    """FFCV image pipeline for 32x32 RGB datasets (SVHN/CIFAR10/CIFAR100).

    ``augment=True`` adds the standard CIFAR-style train-time augmentation
    (random crop with padding, horizontal flip, small rotation).
    """
    from ffcv.transforms import Convert, ToDevice, ToTensor, ToTorchImage
    from ffcv.fields.rgb_image import SimpleRGBImageDecoder
    pipeline = [
        SimpleRGBImageDecoder(),
        ToTensor(),
        ToDevice(device, non_blocking=True),
        ToTorchImage(),
        Convert(torch.float32),
    ]
    if augment:
        pipeline += [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
        ]
    # FFCV decodes images as uint8 in [0, 255], so the 0-1 channel statistics
    # must be scaled by 255 before normalization.
    pipeline.append(
        torchvision.transforms.Normalize(np.array(mean) * 255, np.array(std) * 255)
    )
    return pipeline


def _torch_loader(dataset, batch_size, shuffle, num_workers):
    """Wrap a torchvision dataset in a plain PyTorch DataLoader."""
    print("Using PyTorch dataset.")
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)


def load_data(split, dataset_name, datadir, batch_size, shuffle, device, num_workers=4):
    """Return a data loader for ``dataset_name``.

    Tries FFCV first for SVHN/CIFAR10/CIFAR100/ImageNet (requires the
    ``.beton`` files produced by prepare_ffcv_dataset.py) and falls back to a
    plain torchvision pipeline when FFCV is not installed.

    Args:
        split: 'train' selects the training split with augmentation; any
            other value selects the evaluation split (no augmentation).
        dataset_name: one of MNIST, FashionMNIST, SVHN, CIFAR10, CIFAR100,
            ImageNet (must be an attribute of ``torchvision.datasets``).
        datadir: root directory for the raw torchvision datasets.
        batch_size: batch size of the returned loader.
        shuffle: whether the torchvision DataLoader shuffles (the FFCV path
            shuffles the train split via OrderOption.RANDOM instead).
        device: torch device FFCV pipelines move batches to.
        num_workers: worker processes for data loading.

    Raises:
        NotImplementedError: for unsupported dataset names.
    """
    ## https://gist.github.com/weiaicunzai/e623931921efefd4c331622c344d8151
    # todo: support `nchannels`
    get_dataset = getattr(datasets, dataset_name)
    is_train = split == 'train'

    if dataset_name in ('MNIST', 'FashionMNIST'):
        # Grayscale datasets: torchvision only (no FFCV path here).
        mean, std = ([0.1307], [0.3081]) if dataset_name == 'MNIST' else ([0.5], [0.5])
        transform = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        dataset = get_dataset(root=datadir, train=is_train, download=True, transform=transform)
        return _torch_loader(dataset, batch_size, shuffle, num_workers)

    if dataset_name in ('SVHN', 'CIFAR10', 'CIFAR100'):
        mean, std = {
            'SVHN': ([0.4377, 0.4438, 0.4728], [0.1980, 0.2010, 0.1970]),
            'CIFAR10': ([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
            'CIFAR100': ([0.5071, 0.4865, 0.4409], [0.2673, 0.2564, 0.2762]),
        }[dataset_name]
        try:
            # Fix: augmentation is applied to the train split only. (The
            # original CIFAR100 branch had the train/test pipelines swapped,
            # augmenting the test set and training without augmentation.)
            image_pipeline = _ffcv_32px_image_pipeline(mean, std, device, augment=is_train)
            data_loader = ffcv_data(dataset_name, split, num_workers, batch_size,
                                    image_pipeline, _ffcv_label_pipeline(device))
            print("Using FFCV dataset.")
            return data_loader
        except ImportError:
            normalize = transforms.Normalize(mean=mean, std=std)
            if dataset_name == 'SVHN':
                # SVHN's torchvision class selects splits via `split=`, not `train=`.
                transform = transforms.Compose(
                    [transforms.Resize(32), transforms.ToTensor(), normalize])
                dataset = get_dataset(root=datadir, split='train' if is_train else 'test',
                                      download=True, transform=transform)
            else:
                if is_train:
                    transform = transforms.Compose([
                        transforms.RandomCrop(32, padding=4),
                        transforms.RandomHorizontalFlip(),
                        transforms.RandomRotation(15),
                        transforms.ToTensor(),
                        normalize,
                    ])
                else:
                    transform = transforms.Compose([transforms.ToTensor(), normalize])
                dataset = get_dataset(root=datadir, train=is_train, download=True,
                                      transform=transform)
            return _torch_loader(dataset, batch_size, shuffle, num_workers)

    if dataset_name == 'ImageNet':
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        try:
            from ffcv.transforms import (NormalizeImage, RandomHorizontalFlip,
                                         ToDevice, ToTensor, ToTorchImage)
            from ffcv.fields.rgb_image import (CenterCropRGBImageDecoder,
                                               RandomResizedCropRGBImageDecoder)
            if is_train:
                head = [RandomResizedCropRGBImageDecoder((224, 224)), RandomHorizontalFlip()]
            else:
                # Standard eval protocol: center-crop keeping a 224/256 ratio.
                head = [CenterCropRGBImageDecoder((224, 224), ratio=224 / 256)]
            image_pipeline = head + [
                ToTensor(),
                ToDevice(device, non_blocking=True),
                ToTorchImage(),
                NormalizeImage(np.array(mean) * 255, np.array(std) * 255, np.float16),
            ]
            data_loader = ffcv_data(dataset_name, split, num_workers, batch_size,
                                    image_pipeline, _ffcv_label_pipeline(device))
            print("Using FFCV dataset.")
            return data_loader
        except ImportError:
            normalize = transforms.Normalize(mean=mean, std=std)
            # https://github.com/pytorch/examples/blob/a38cbfc6f817d9015fc67a6309d3a0be9ff94ab6/imagenet/main.py#L239-L255
            if is_train:
                transform = transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ])
            else:
                # Fix: no random flip at evaluation time (the original applied
                # RandomHorizontalFlip to the validation transform too).
                transform = transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ])
            dataset = get_dataset(root=datadir, split='train' if is_train else 'val',
                                  transform=transform)
            return _torch_loader(dataset, batch_size, shuffle, num_workers)

    raise NotImplementedError(f"Non-supported dataset: {dataset_name}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
@File : prepare_ffcv_dataset.py | |
@Author: Dong Wang | |
@Date : 2024/06/25 | |
@Description : prepare FFCV dataset files. You need first install FFCV in your environment: https://github.com/libffcv/ffcv | |
""" | |
import os | |
from torch.utils.data import DataLoader | |
import torchvision.transforms as T | |
from torchvision import transforms, datasets | |
import numpy as np | |
import torch | |
import torchvision | |
# Per-dataset metadata: number of target classes and input image channels.
# NOTE(review): duplicated from fast_torchvision_dataloader.py and unused in
# this script's prepare_data path — confirm whether it can be removed here.
dataset_infor = {
    'FashionMNIST': {'num_classes': 10, 'num_channels': 1},
    'MNIST': {'num_classes': 10, 'num_channels': 1},
    'ImageNet': {'num_classes': 1000, 'num_channels': 3},
    'CIFAR10': {'num_classes': 10, 'num_channels': 3},
    'CIFAR100': {'num_classes': 100, 'num_channels': 3},
}
def prepare_data(split, dataset_name, datadir):
    """Serialize one split of ``dataset_name`` into an FFCV ``.beton`` file.

    The output is written to ``./ffcv_datasets/{dataset_name}_{train|test}_ffcv.beton``
    and is consumed by the FFCV Loader in fast_torchvision_dataloader.py.

    Args:
        split: 'train' writes the training split; any other value writes the
            test/validation split.
        dataset_name: an attribute name of ``torchvision.datasets``
            (e.g. 'CIFAR10', 'SVHN', 'ImageNet').
        datadir: root directory of the raw torchvision dataset. For ImageNet
            this must already contain the ILSVRC2012 archives (no download).
    """
    from ffcv.fields import IntField, RGBImageField
    from ffcv.writer import DatasetWriter

    get_dataset = getattr(datasets, dataset_name)
    output_dir = './ffcv_datasets'
    os.makedirs(output_dir, exist_ok=True)
    sub = 'train' if split == 'train' else 'test'
    output_file = f'{output_dir}/{dataset_name}_{sub}_ffcv.beton'

    # NOTE(review): RGBImageField is used for every dataset, but MNIST /
    # FashionMNIST are single-channel — confirm grayscale datasets are
    # converted to RGB (or use a different field) before writing them.
    if dataset_name == 'SVHN':
        # SVHN selects splits via `split=`, not `train=`.
        image_field = RGBImageField(write_mode='smart', max_resolution=32, jpeg_quality=90)
        dataset = get_dataset(root=datadir, split='train' if split == 'train' else 'test',
                              download=True)
    elif dataset_name == 'ImageNet':
        # ImageNet cannot be auto-downloaded; torchvision names its eval split 'val'.
        image_field = RGBImageField(write_mode='smart', max_resolution=256, jpeg_quality=90)
        dataset = get_dataset(root=datadir, split='train' if split == 'train' else 'val')
    else:
        image_field = RGBImageField(write_mode='smart', max_resolution=32, jpeg_quality=90)
        dataset = get_dataset(root=datadir, train=split == 'train', download=True)

    write_config = {
        'image': image_field,
        'label': IntField(),
    }
    writer = DatasetWriter(output_file, write_config)
    writer.from_indexed_dataset(dataset)
# Now you can generate FFCV dataset before use it for training. | |
# CIFAR10 | |
# prepare_data(split="train", dataset_name="CIFAR10", datadir="datasets") | |
# prepare_data(split="test", dataset_name="CIFAR10", datadir="datasets") | |
# CIFAR100 | |
# prepare_data(split="train", dataset_name="CIFAR100", datadir="datasets") | |
# prepare_data(split="test", dataset_name="CIFAR100", datadir="datasets") | |
# For ImageNet, `~/data/ImageNet` should be a folder containing files ILSVRC2012_devkit_t12.tar.gz, ILSVRC2012_img_train.tar, ILSVRC2012_img_val.tar | |
# prepare_data(split="train", dataset_name="ImageNet", datadir="~/data/ImageNet") | |
# prepare_data(split="test", dataset_name="ImageNet", datadir="~/data/ImageNet") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Install Mamba
Install Packages
Note: FFCV is only available for Python < 3.12