Last active: August 8, 2019 18:44
Save Lexie88rus/cfe4346475ecd7508b16621239e225d3 to your computer and use it in GitHub Desktop.
Using imgaug with PyTorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# PyTorch / torchvision imports
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms.functional as TF

# Training-time augmentations: reuse the imgaug pipeline `aug_pipeline`
# (defined elsewhere; not visible in this snippet) as-is.
AUG_TRAIN = aug_pipeline
# Define the demo dataset
class DogDataset(Dataset):
    """Single-image dataset used to demonstrate imgaug augmentations.

    Every access returns the same source image, optionally passed through an
    augmenter first, converted to a tensor with
    ``torchvision.transforms.functional.to_tensor``.

    Args:
        image: The one sample image. Presumably an HWC numpy array, as
            accepted by imgaug's ``augment_image`` and ``TF.to_tensor``;
            TODO confirm against the caller.
        augmentations: An object exposing ``augment_image(image)`` (e.g. an
            imgaug augmenter or pipeline), or ``None`` to return the image
            without augmentation.
    """

    def __init__(self, image, augmentations=None):
        self.image = image
        # Called on every __getitem__ when not None.
        self.augmentations = augmentations

    def __len__(self):
        # The dataset holds exactly one image.
        return 1

    def __getitem__(self, idx):
        """Return the (optionally augmented) image as a tensor.

        Args:
            idx: Sample index; only 0 (or -1) is valid for this
                length-1 dataset.

        Raises:
            IndexError: If ``idx`` is out of range. Raising here also ends
                Python's fallback ``__getitem__`` iteration protocol, so
                ``for x in dataset`` / ``list(dataset)`` stop after one item
                instead of looping forever.
        """
        if idx not in (0, -1):
            raise IndexError(f"DogDataset index {idx} out of range (length 1)")
        sample = self.image
        # The default augmentations=None means "no augmentation"; previously
        # this path crashed with AttributeError on None.augment_image.
        if self.augmentations is not None:
            # augment_image runs on every access, so a stochastic pipeline
            # can yield a different result each time.
            sample = self.augmentations.augment_image(sample)
        return TF.to_tensor(sample)
# Load the augmented data:
# build the single-image dataset, handing it the imgaug pipeline as its
# training-time augmentations, then wrap it in a DataLoader that serves
# batches of one sample from the main process (num_workers=0).
# Note: with a single sample, shuffle=True has no visible effect.
train_ds = DogDataset(image, augmentations=AUG_TRAIN)
trainloader = DataLoader(
    train_ds,
    batch_size=1,
    shuffle=True,
    num_workers=0,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.