Created
April 25, 2019 11:36
-
-
Save azkalot1/f00f6c7d34137a8170cb4517718f8f50 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.optim import Optimizer
from torch.utils import data
class DataGenerator(data.Dataset):
    """Dataset that loads images from disk by id and pairs them with labels.

    Each item is ``(X, y)``:

    - ``X`` is the image divided by 255, cast to float32, with the channel
      axis moved first.
    - ``y`` is the label wrapped by ``np.expand_dims(label, 0)``; for a
      scalar label this has shape ``(1,)``.

    Args:
        ids: image ids. The file for id ``i`` is
            ``os.path.join(imdir, str(i) + ext)``.
        labels: labels of images (1/0), indexed in parallel with ``ids``.
        augment: optional image augmentation from albumentations. Any
            callable that accepts ``image=`` and returns a dict with an
            ``'image'`` entry works. ``None`` disables augmentation.
        imdir: path to the folder with images.
        reader: optional callable ``path -> array`` used to read an image
            file. When ``None`` the module-level ``imread`` is used.
        ext: file extension appended to each id (default ``'.tif'``).
    """

    def __init__(self, ids, labels, augment, imdir, reader=None, ext='.tif'):
        """Store ids, labels, augmentation, image folder and reader settings."""
        self.ids, self.labels = ids, labels
        self.augment = augment
        self.imdir = imdir
        self.reader = reader
        self.ext = ext

    def __len__(self):
        """Return the number of images (the length of ``ids``)."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``."""
        imid = self.ids[idx]
        y = self.labels[idx]
        X = self.__load_image(imid)
        return X, np.expand_dims(y, 0)

    def __load_image(self, imid):
        """Read, optionally augment, rescale and transpose one image.

        Returns a float32 array divided by 255 with the channel axis first:
        an ``(H, W, C)`` image becomes ``(C, H, W)``, and a 2-D ``(H, W)``
        image becomes ``(1, H, W)``.
        """
        # str() so that non-string ids (e.g. ints) still build a valid name.
        path = os.path.join(self.imdir, str(imid) + self.ext)
        # NOTE(review): ``imread`` is not imported in the visible file;
        # presumably skimage.io.imread or similar -- confirm.
        read = self.reader if self.reader is not None else imread
        im = read(path)
        if self.augment is not None:
            im = self.augment(image=im)['image']
        # float32 matches PyTorch's default parameter dtype. A float64 array
        # would collate into a DoubleTensor and clash with float32 weights.
        im = np.asarray(im, dtype=np.float32) / np.float32(255.0)
        if im.ndim == 2:
            # Grayscale: add a channel axis rather than transposing H and W.
            return im[np.newaxis, ...]
        return np.moveaxis(im, -1, 0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment