Skip to content

Instantly share code, notes, and snippets.

@Lexie88rus
Created August 8, 2019 18:57
Show Gist options
  • Save Lexie88rus/5983e17104e1e849f12fbd5fab8d0dc2 to your computer and use it in GitHub Desktop.
Save Lexie88rus/5983e17104e1e849f12fbd5fab8d0dc2 to your computer and use it in GitHub Desktop.
Example of using the Augmentor package with PyTorch
# Define the demo dataset
class DogDataset3(Dataset):
    '''
    Sample dataset for Augmentor demonstration.
    The dataset will consist of just one sample image.

    Every access returns a freshly augmented copy of the stored image, so
    repeated reads of index 0 yield different tensors.
    '''

    def __init__(self, image):
        '''
        Store the source image and build the Augmentor pipeline once.

        :param image: the single source image. It is converted with
            ``np.array`` (which copies it), so later mutation of the caller's
            object does not affect the pipeline. Presumably a PIL image or
            an HxW(xC) uint8 array -- TODO confirm against the caller.
        '''
        self.image = image
        # Build the pipeline once here rather than on every __getitem__ call.
        # DataPipeline receives a list of image groups; here there is a single
        # group holding the single image.
        self.pipeline = Augmentor.DataPipeline([[np.array(image)]])
        # Each operation is applied with 50% probability on every sample.
        self.pipeline.rotate(0.5, max_left_rotation=10, max_right_rotation=10)
        self.pipeline.shear(0.5, max_shear_left=10, max_shear_right=10)
        self.pipeline.zoom_random(0.5, percentage_area=0.7)

    def __len__(self):
        return 1  # return 1 as we have only one image

    def __getitem__(self, idx):
        '''
        Return a randomly augmented copy of the image as a tensor.

        :param idx: sample index; only 0 (or -1) is valid because the
            dataset holds exactly one image.
        :raises IndexError: if ``idx`` is out of range. This also lets plain
            ``for sample in dataset`` iteration terminate, since that falls
            back to calling __getitem__ with 0, 1, 2, ... until IndexError.
        '''
        size = len(self)
        if not -size <= idx < size:
            raise IndexError(
                f"index {idx} is out of range for a dataset of size {size}")
        # sample(1) returns a list holding one image group; take the only
        # image of that group.
        augmented_image = self.pipeline.sample(1)[0][0]
        # convert to tensor and return the result
        return TF.to_tensor(augmented_image)
# Initialize the dataset: pass the source image (not an augmentation pipeline)
# to the init function -- the augmentation is set up inside DogDataset3.
train_ds = DogDataset3(image)
# Initialize the dataloader. The dataset holds a single sample, so shuffling
# has no practical effect; num_workers=0 loads data in the main process.
trainloader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment