Last active
May 5, 2020 19:08
An example of a PyTorch Data Loader that uses ListDataset and a load_data function.
import os

import numpy as np
from torch.utils.data import DataLoader
# ListDataset is not defined in this file; torchnet's ListDataset is assumed here.
from torchnet.dataset import ListDataset

# DATADIR (root folder holding the train/, valid/, and test/ subfolders) and the
# augmentation helpers used in load_data are assumed to be defined elsewhere.


def get_list(split='train'):
    """
    Helper function for the PyTorch data loader
    @params
        split: string
            Specifies whether the training (train), validation (valid), or
            testing (test) list should be generated
    @return
        mlist: A nested Python list
            A list of input-output pairs where each element is a list of size 2.
            The first element is the path to the .npy input file and the second
            element is the path to the .npy file of the one-hot-encoded
            segmentation map
    """
    if split not in ['train', 'valid', 'test']:
        raise ValueError('split must be train, valid, or test')
    path = os.path.join(DATADIR, split)
    # Each sample is stored as two files (img_i.npy and target_i.npy),
    # so the number of samples is half the number of files in the folder.
    num_items = len(os.listdir(path))
    items = num_items // 2
    mlist = []
    for i in range(items):
        path_to_img = os.path.join(path, 'img_{}.npy'.format(i))
        path_to_target = os.path.join(path, 'target_{}.npy'.format(i))
        mlist.append([path_to_img, path_to_target])
    return mlist


def load_data(line):
    """
    Loads one sample for the PyTorch data loader and performs data augmentation
        1. Randomly translates the input and its target
        2. Draws a random shape on the input
        3. Randomly thins the input
        4. Randomly changes the brightness
    @params
        line: Python list (len == 2)
            line[0]: A path to a numpy array that contains the input HxWx1
            line[1]: A path to a numpy array that contains the segmentation matrix HxWx(num_labels+1)
    @return
        Python dictionary
            key: 'src' (np.ndarray)
                The augmented input 1xHxW (channels first)
            key: 'target' (np.ndarray)
                The augmented target matrix one-hot-encoded HxWx(num_labels+1)
    """
    path_to_image = line[0]
    path_to_target = line[1]
    img = np.load(path_to_image)
    target = np.load(path_to_target)
    smaller = np.squeeze(img)  # drop the singleton channel axis: HxWx1 -> HxW
    img, target = random_translation(smaller, target)  # random translation
    src = draw_random_shape(img)  # draw random shape
    src = randomly_thin(src, p=0.35)  # possibly thin
    src = randomly_change_brightness(src)  # change brightness
    src = np.expand_dims(src, axis=0)  # PyTorch expects channels first (CxHxW)
    return {'src': src, 'target': target}


training_list = get_list('train')
training_dataset = ListDataset(training_list, load_data)
train_loader = DataLoader(dataset=training_dataset, batch_size=32)
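
`ListDataset` is not defined in this file. If it is not available from a library such as torchnet, a minimal stand-in might look like the sketch below. It assumes only what the usage above implies: the dataset wraps a list of elements and calls a load function on the selected one. The class body and the toy load function are illustrative, not the gist author's code.

```python
class ListDataset:
    """Map-style dataset: applies `load` to the idx-th element of `elem_list`."""

    def __init__(self, elem_list, load):
        self.elem_list = list(elem_list)
        self.load = load

    def __len__(self):
        return len(self.elem_list)

    def __getitem__(self, idx):
        # Raising IndexError past the end lets plain iteration terminate.
        if idx < 0 or idx >= len(self.elem_list):
            raise IndexError('index out of range')
        return self.load(self.elem_list[idx])


# Toy usage with a hypothetical load function that only echoes the two paths
# instead of reading and augmenting .npy files.
pairs = [['img_0.npy', 'target_0.npy'], ['img_1.npy', 'target_1.npy']]
toy = ListDataset(pairs, lambda line: {'src': line[0], 'target': line[1]})
print(len(toy))
print(toy[1]['target'])
```

Because the class implements `__len__` and `__getitem__`, `torch.utils.data.DataLoader` can treat it as a map-style dataset, and its default collate function should batch the `src` and `target` arrays returned by `load_data` into tensors, provided every sample has the same shape.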