Skip to content

Instantly share code, notes, and snippets.

@calebrob6
Created May 20, 2023 03:39
Show Gist options
  • Save calebrob6/f28adb7561ff7525d3e392d56c14855c to your computer and use it in GitHub Desktop.
Save calebrob6/f28adb7561ff7525d3e392d56c14855c to your computer and use it in GitHub Desktop.
Simple PyTorch Lightning DataModule for ImageNet
import torchvision
from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader
from torchvision import transforms
class ImagenetDataModule(LightningDataModule):
    """LightningDataModule serving the ImageNet train and val splits.

    Datasets are built with ``torchvision.datasets.ImageNet`` from ``root``, so
    ``root`` must hold whatever that class expects (the devkit / split archives
    or their already-extracted folders).

    torchvision's ImageNet only provides the ``"train"`` and ``"val"`` splits,
    so the val split also serves as the test split: ``test_dataloader`` returns
    a loader over ``val_dataset``.

    Args:
        root: Directory passed to ``torchvision.datasets.ImageNet``.
        batch_size: Samples per batch for every dataloader.
        num_workers: Worker processes per dataloader (0 loads in the main process).
        pin_memory: Forwarded to ``DataLoader(pin_memory=...)``.
        persistent_workers: Keep dataloader worker processes alive between
            epochs instead of re-spawning them. Ignored when ``num_workers`` is
            0, because ``DataLoader`` rejects that combination.
    """

    # Per-channel normalization constants (the values torchvision documents
    # for its ImageNet models); shared by the train and eval pipelines.
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    # Training pipeline: random resized crop to 224 plus horizontal flip.
    train_transforms = transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
    )
    # Deterministic evaluation pipeline: resize to 256, then center-crop 224.
    test_transforms = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
    )

    def __init__(
        self,
        root,
        batch_size=64,
        num_workers=4,
        pin_memory=True,
        persistent_workers=False,
    ):
        super().__init__()
        self.root = root
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers

    def setup(self, stage=None):
        """Build the datasets needed for ``stage``.

        Args:
            stage: Lightning stage name (e.g. ``"fit"``, ``"validate"``,
                ``"test"``) or ``None`` to build everything. ``train_dataset``
                is only built for ``"fit"`` / ``None``; ``val_dataset`` is
                built for every stage since validation and test both use it.
        """
        # Constructing the train split scans the whole train directory
        # (torchvision's ImageNet is an ImageFolder); skip it when not fitting.
        if stage is None or stage == "fit":
            self.train_dataset = torchvision.datasets.ImageNet(
                root=self.root, split="train", transform=self.train_transforms
            )
        self.val_dataset = torchvision.datasets.ImageNet(
            root=self.root, split="val", transform=self.test_transforms
        )

    def _make_dataloader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using this module's loader settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
            # DataLoader raises ValueError for persistent workers with num_workers=0.
            persistent_workers=self.persistent_workers and self.num_workers > 0,
        )

    def train_dataloader(self):
        """Return a shuffled loader over the train split (needs ``setup("fit")``)."""
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled loader over the val split."""
        return self._make_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return an unshuffled loader over the val split (no separate test split)."""
        return self._make_dataloader(self.val_dataset, shuffle=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment