Skip to content

Instantly share code, notes, and snippets.

@MartinWeiss12
Last active October 23, 2024 17:36
Show Gist options
  • Save MartinWeiss12/ce1d948a9d5f482f798010f2431ac20b to your computer and use it in GitHub Desktop.
Save MartinWeiss12/ce1d948a9d5f482f798010f2431ac20b to your computer and use it in GitHub Desktop.
Image Prep
# Every image is brought to this square size before being fed to the model.
image_height = image_width = 224

# Per-split preprocessing pipelines.
# Both splits resize, convert to a tensor and normalise each of the three
# channels with mean 0.5 / std 0.5; only 'train' adds random augmentation.
data_transforms = {
    'train': transforms.Compose(
        [
            transforms.Resize(size=(image_height, image_width)),
            # Augmentation: random mirror, small rotation, mild random crop.
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(degrees=15),
            transforms.RandomResizedCrop(
                size=(image_height, image_width),
                scale=(0.8, 1.0),
            ),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
        ]
    ),
    'test': transforms.Compose(
        [
            # Deterministic pipeline: no augmentation at evaluation time.
            transforms.Resize(size=(image_height, image_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
        ]
    ),
}
class JPGImageFolder(datasets.ImageFolder):
    """``ImageFolder`` variant that keeps only ``.jpg`` files.

    Any file or class directory whose path contains ``'.ipynb'`` (such as
    Jupyter's ``.ipynb_checkpoints`` folders) is ignored, so notebook
    artifacts never show up as samples or as class labels.
    """

    def __init__(self, root, transform=None):
        """Scan ``root`` and retain only the ``.jpg`` samples.

        Args:
            root: Directory containing one sub-directory per class.
            transform: Optional callable applied to each loaded image.
        """
        super().__init__(root, transform)
        kept = [
            (path, label)
            for path, label in self.samples
            if path.lower().endswith('.jpg') and '.ipynb' not in path
        ]
        # ``imgs`` is the legacy alias of ``samples``; keep both pointing at
        # the same filtered list.
        self.samples = kept
        self.imgs = kept
        # The base class builds ``targets`` from the unfiltered sample list.
        # Rebuild it so ``targets[i]`` always matches ``samples[i][1]`` and
        # ``len(targets) == len(self)``.
        self.targets = [label for _, label in kept]

    def find_classes(self, directory):
        """Return the class names under ``directory`` and their indices.

        Args:
            directory: Root directory whose immediate sub-directories are
                the classes.

        Returns:
            A tuple ``(classes, class_to_idx)`` where ``classes`` is the
            sorted list of sub-directory names not containing ``'.ipynb'``
            and ``class_to_idx`` maps each name to its position in that list.
        """
        classes = sorted(
            entry
            for entry in os.listdir(directory)
            if os.path.isdir(os.path.join(directory, entry))
            and '.ipynb' not in entry
        )
        return classes, {cls_name: idx for idx, cls_name in enumerate(classes)}
# Root folder expected to contain 'train' and 'test' sub-directories.
data_dir = 'images'
batch_size = 32

# One dataset per split, each paired with the transform pipeline of that split.
image_datasets = {
    split: JPGImageFolder(
        root=os.path.join(data_dir, split),
        transform=data_transforms[split],
    )
    for split in ('train', 'test')
}

# Only the training loader shuffles; the test loader keeps dataset order.
dataloaders = {
    split: DataLoader(
        image_datasets[split],
        batch_size=batch_size,
        shuffle=(split == 'train'),
    )
    for split in ('train', 'test')
}

# Class labels as discovered from the training directory.
class_names = image_datasets['train'].classes
print(class_names)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment