Last active
February 13, 2019 16:04
-
-
Save iKrishneel/cdcf82ec93b8e2524e722c510b08a5ea to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
#! -*- coding: utf-8 -*- | |
import os | |
import sys | |
import random | |
import numpy as np | |
import cv2 as cv | |
from imgaug import augmenters | |
from sklearn.model_selection import train_test_split | |
class Dataloader(object):
    """Sample random pairs of images together with a same/different label.

    The list file ``list_fname`` (resolved inside ``dataset_dir``) names one
    sub-directory per line; the ``.jpg``/``.png`` files found in
    ``<dataset_dir>/<name>/face`` form one class (presumably one person,
    given the ``face`` folder name -- confirm against the dataset).
    ``self.dataset`` maps consecutive integer ids ``0..n-1`` to the list of
    image paths of each non-empty class.
    """

    # Compared against the lower-cased file extension.
    _IMAGE_EXTENSIONS = ('.jpg', '.png')

    def __init__(self, dataset_dir, list_fname, input_shape):
        """Index the dataset listed in ``list_fname``.

        Args:
            dataset_dir: root directory holding the class folders and the list file.
            list_fname: name of the list file, relative to ``dataset_dir``.
            input_shape: 3-element network input shape; only the first two
                entries are kept and used as the resize target. Assumed to be
                channels-last ``(height, width, channels)`` -- confirm with callers.

        Raises:
            AssertionError: on a bad shape, a missing directory or list file,
                or when no class folder contains any image.
        """
        assert len(input_shape) == 3, 'Input shape should be of length 3'
        assert os.path.isdir(dataset_dir), \
            'Dataset directory not found at {}'.format(dataset_dir)
        filename = os.path.join(dataset_dir, list_fname)
        assert os.path.isfile(filename), \
            'Data list file not found at {}'.format(filename)

        self.read_textfile(dataset_dir, filename)
        assert self.dataset, 'No images found'

        # Uniform random class id; relies on the ids being 0..n-1.
        self.random_index = lambda: random.randint(0, len(self.dataset) - 1)
        self.input_shape = input_shape[:2]
        # imgaug pipeline, built on first use (see color_space_argumentation).
        self._augmenter = None

    def load(self, visualize=False):
        """Draw one random image pair.

        With probability 0.5 both images come from the same class
        (``label == 1``); otherwise they come from two different classes
        (``label == 0``). With a single class every pair is positive.

        Args:
            visualize: if True, show the pair side by side in an OpenCV window
                and wait for a key press (debug aid; requires a display).

        Returns:
            dict with ``input_a`` / ``input_b`` (augmented, resized images)
            and the integer ``label``.
        """
        num_classes = len(self.dataset)
        index1 = self.random_index()
        positive = num_classes < 2 or random.random() < 0.5
        if positive:
            index2 = index1
        else:
            # A non-zero offset (mod n) guarantees a different class and keeps
            # the choice uniform over the remaining classes.
            index2 = (index1 + random.randint(1, num_classes - 1)) % num_classes
        label = 1 if positive else 0

        im_path1 = random.choice(self.dataset[index1])
        im_path2 = random.choice(self.dataset[index2])

        image1, image2 = self.read_image([im_path1, im_path2])
        image1 = self.color_space_argumentation(image1)
        image2 = self.color_space_argumentation(image2)
        image1, image2 = self.resize_image([image1, image2])

        if visualize:
            cv.imshow('image', np.hstack([image1, image2]))
            cv.waitKey(0)
        return dict(input_a=image1, input_b=image2, label=label)

    def resize_image(self, images):
        """Resize every image to ``self.input_shape``.

        ``cv.resize`` expects ``dsize`` as ``(width, height)``; ``input_shape``
        is assumed to be ``(height, width)``, hence the swap (a no-op for
        square shapes).
        """
        height, width = self.input_shape
        return [cv.resize(image, (width, height)) for image in images]

    def color_space_argumentation(self, image):
        """Apply the random augmentation pipeline to a single image.

        The pipeline (see ``_build_augmenter``) is created once per loader and
        reused. It also contains a horizontal flip, which is geometric rather
        than colour-space; ``load`` applies it to each image of a pair
        independently.
        """
        augmenter = getattr(self, '_augmenter', None)
        if augmenter is None:
            augmenter = self._augmenter = self._build_augmenter()
        return augmenter.augment_image(image)

    @staticmethod
    def _build_augmenter():
        """Create the imgaug pipeline used by ``color_space_argumentation``."""
        return augmenters.Sequential([
            # Exactly one blur variant per image.
            augmenters.OneOf([
                augmenters.GaussianBlur((0, 3.0)),
                augmenters.AverageBlur(k=(2, 7)),
                augmenters.MedianBlur(k=(3, 7)),
            ]),
            augmenters.Sharpen(alpha=(0, 1.0), lightness=(0.75, 1.5)),
            # NOTE(review): the additive range (-2, 21) is asymmetric;
            # confirm it is intentional.
            augmenters.Add((-2, 21), per_channel=0.5),
            augmenters.Multiply((0.75, 1.25), per_channel=0.5),
            augmenters.Grayscale(alpha=(0.0, 0.50)),
            augmenters.Fliplr(0.5),
        ], random_order=False)

    def read_image(self, im_paths):
        """Read each path with ``cv.IMREAD_COLOR`` (3-channel BGR arrays).

        Raises:
            IOError: if OpenCV cannot read a file (``cv.imread`` signals this
                by returning ``None`` rather than raising).
        """
        images = []
        for im_path in im_paths:
            image = cv.imread(im_path, cv.IMREAD_COLOR)
            if image is None:
                raise IOError('Unable to read image at {}'.format(im_path))
            images.append(image)
        return images

    def _list_images(self, folder):
        """Return the sorted paths of the .jpg/.png files directly in ``folder``."""
        # Sorting makes sampling reproducible under a fixed random seed.
        return sorted(
            os.path.join(folder, ifile)
            for ifile in os.listdir(folder)
            if os.path.splitext(ifile)[1].lower() in self._IMAGE_EXTENSIONS
        )

    def _index_folders(self, folders):
        """Set ``self.dataset`` to consecutive ids -> image lists, skipping empty folders."""
        self.dataset = {}
        for folder in folders:
            files = self._list_images(folder)
            # An empty class would make the per-class random pick fail in load().
            if files:
                self.dataset[len(self.dataset)] = files

    def read_directory(self, dataset_dir):
        """Index every ``<dataset_dir>/<entry>/face`` directory that exists."""
        folders = [
            os.path.join(dataset_dir, entry, 'face')
            for entry in sorted(os.listdir(dataset_dir))
        ]
        # Plain files (e.g. the list file itself) have no 'face' sub-directory.
        self._index_folders([folder for folder in folders if os.path.isdir(folder)])

    def read_textfile(self, dataset_dir, filename):
        """Index the ``<dataset_dir>/<name>/face`` folders named in ``filename``."""
        with open(filename) as list_file:
            names = [line.strip() for line in list_file]
        self._index_folders([
            os.path.join(dataset_dir, name, 'face')
            for name in names
            if name  # ignore blank lines, e.g. a trailing empty line
        ])
def main(argv):
    """Load one random training pair from a dataset and print a summary.

    Usage: ``<script> DATASET_DIR``, where DATASET_DIR contains ``train.txt``.

    Args:
        argv: command-line arguments, ``argv[1]`` being the dataset directory.

    Returns:
        Exit status: 0 on success, 2 when the dataset directory is missing.
    """
    if len(argv) < 2:
        program = argv[0] if argv else 'dataloader'
        sys.stderr.write('usage: {} DATASET_DIR\n'.format(program))
        return 2
    loader = Dataloader(argv[1], 'train.txt', (224, 224, 3))
    sample = loader.load()
    print('label: {}, input_a: {}, input_b: {}'.format(
        sample['label'], sample['input_a'].shape, sample['input_b'].shape))
    return 0


if __name__ == '__main__':
    sys.exit(main(sys.argv))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment