Skip to content

Instantly share code, notes, and snippets.

@iKrishneel
Last active February 13, 2019 16:04
Show Gist options
  • Save iKrishneel/cdcf82ec93b8e2524e722c510b08a5ea to your computer and use it in GitHub Desktop.
Save iKrishneel/cdcf82ec93b8e2524e722c510b08a5ea to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
#! -*- coding: utf-8 -*-
import os
import sys
import random
import numpy as np
import cv2 as cv
from imgaug import augmenters
from sklearn.model_selection import train_test_split
class Dataloader(object):
    """Random image-pair sampler over a directory of per-class image folders.

    The dataset is laid out as ``<dataset_dir>/<class_name>/face/<image>``.
    Each class folder is assigned an integer index and ``self.dataset`` maps
    that index to the list of image paths found in the folder.  ``load()``
    draws two images and reports whether they came from the same class.

    Args:
        dataset_dir: Root directory holding the class folders.
        list_fname: Name of a text file inside ``dataset_dir`` listing one
            class folder name per line.
        input_shape: Sequence of length 3; only its first two entries are
            kept and later passed to ``cv.resize`` as the target size.
        visualize: When true (the default, matching the historical
            behaviour) every ``load()`` call shows the sampled pair in an
            OpenCV window and waits for a key press.  Pass ``False`` for
            unattended use.

    Raises:
        AssertionError: If ``input_shape`` does not have three entries, the
            dataset directory or list file is missing, or no class folder
            containing images was found.
    """

    # Accepted image extensions.  A file name only qualifies when it contains
    # exactly one dot and the text after that dot is one of these (so
    # ``a.b.jpg`` and ``x.JPG`` are skipped, as before).
    IMAGE_EXTENSIONS = ('jpg', 'png')

    def __init__(self, dataset_dir, list_fname, input_shape, visualize=True):
        # ``==`` rather than ``is``: identity comparison against an int
        # literal is implementation-dependent and warns on Python 3.8+.
        assert len(input_shape) == 3, 'Input shape should be of lenght 3'
        assert os.path.isdir(dataset_dir), \
            'Dataset directory not found at {}'.format(dataset_dir)

        filename = os.path.join(dataset_dir, list_fname)
        assert os.path.isfile(filename), \
            'Data list file not found at {}'.format(filename)

        # ``read_directory(dataset_dir)`` is the alternative that scans every
        # sub-folder instead of following the list file.
        self.read_textfile(dataset_dir, filename)
        assert len(self.dataset) != 0, 'No images found'

        self.input_shape = input_shape[:2]
        self.visualize = visualize

    def random_index(self):
        """Return a uniformly random class index from ``self.dataset``."""
        return random.randint(0, len(self.dataset) - 1)

    def load(self):
        """Sample one augmented, resized image pair.

        Returns:
            dict: ``input_a`` and ``input_b`` hold the two images and
            ``label`` is 1 when both were drawn from the same class index,
            otherwise 0.

        Raises:
            IOError: If OpenCV could not decode one of the chosen files.
        """
        index1 = self.random_index()
        index2 = self.random_index()
        # randint(0, 10) has 11 outcomes, 6 of them even, so the pair is
        # forced onto the same class with probability 6/11.
        if random.randint(0, 10) % 2 == 0:
            index2 = index1
        label = 1 if index1 == index2 else 0

        idx1 = random.randint(0, len(self.dataset[index1]) - 1)
        idx2 = random.randint(0, len(self.dataset[index2]) - 1)
        im_paths = [self.dataset[index1][idx1], self.dataset[index2][idx2]]

        images = self.read_image(im_paths)
        # cv.imread signals failure by returning None; fail here with the
        # offending path instead of deep inside the augmentation pipeline.
        for im_path, image in zip(im_paths, images):
            if image is None:
                raise IOError('Failed to read image at {}'.format(im_path))

        images = [self.color_space_argumentation(image) for image in images]
        image1, image2 = self.resize_image(images)

        if self.visualize:
            # Debug preview: blocks until a key is pressed in the window.
            cv.imshow('image', np.hstack([image1, image2]))
            cv.waitKey(0)

        return dict(input_a=image1, input_b=image2, label=label)

    def resize_image(self, images):
        """Resize every image in ``images`` to ``self.input_shape``.

        NOTE(review): ``cv.resize`` takes its size as (width, height).  If
        ``input_shape`` is meant as (height, width, channels) the two are
        swapped for non-square shapes -- confirm against callers.
        """
        return [cv.resize(image, self.input_shape) for image in images]

    def color_space_argumentation(self, image):
        """Return ``image`` after a fixed-order imgaug augmentation chain.

        The chain applies one of three blurs, then sharpen, additive and
        multiplicative intensity changes, partial grayscale blending and a
        left-right flip configured with probability 0.5.
        """
        seq = augmenters.Sequential([
            augmenters.OneOf([
                augmenters.GaussianBlur((0, 3.0)),
                augmenters.AverageBlur(k=(2, 7)),
                augmenters.MedianBlur(k=(3, 7)),
            ]),
            augmenters.Sharpen(alpha=(0, 1.0), lightness=(0.75, 1.5)),
            augmenters.Add((-2, 21), per_channel=0.5),
            augmenters.Multiply((0.75, 1.25), per_channel=0.5),
            augmenters.Grayscale(alpha=(0.0, 0.50)),
            augmenters.Fliplr(0.5),
        ], random_order=False)
        return seq.augment_image(image)

    def read_image(self, im_paths):
        """Read each path with ``cv.imread`` in colour mode.

        Returns:
            list: One entry per path, in order.  An entry is ``None`` when
            OpenCV could not read the file.
        """
        return [cv.imread(im_path, cv.IMREAD_COLOR) for im_path in im_paths]

    def read_directory(self, dataset_dir):
        """Build ``self.dataset`` from every ``<entry>/face`` folder.

        Entries of ``dataset_dir`` without a ``face`` sub-directory (for
        example the list file itself) are ignored.  Entries are visited in
        sorted order so class indices are reproducible.
        """
        candidates = (
            os.path.join(dataset_dir, entry, 'face')
            for entry in sorted(os.listdir(dataset_dir))
        )
        self._index_folders(
            [folder for folder in candidates if os.path.isdir(folder)])

    def read_textfile(self, dataset_dir, filename):
        """Build ``self.dataset`` from the class folders listed in a file.

        Args:
            dataset_dir: Root directory the listed names are relative to.
            filename: Path of a text file with one class folder name per
                line.  Blank lines are ignored.

        Raises:
            OSError: If a listed ``<name>/face`` folder cannot be listed.
        """
        # The context manager closes the list file deterministically.
        with open(filename) as list_file:
            names = [line.rstrip('\r\n') for line in list_file]
        self._index_folders([
            os.path.join(dataset_dir, name, 'face')
            for name in names
            if name.strip()
        ])

    def _index_folders(self, folders):
        """Populate ``self.dataset`` with the image paths of each folder.

        Folders without any accepted image are dropped and the remaining
        ones are numbered consecutively from 0, so every index returned by
        ``random_index()`` refers to a non-empty list.
        """
        dataset = {}
        for folder in folders:
            files = [
                os.path.join(folder, ifile)
                for ifile in sorted(os.listdir(folder))
                if self._is_image(ifile)
            ]
            if files:
                dataset[len(dataset)] = files
        self.dataset = dataset

    @classmethod
    def _is_image(cls, ifile):
        """Return True when ``ifile`` is ``<stem>.<ext>`` with an accepted ext."""
        parts = ifile.split('.')
        return len(parts) == 2 and parts[1] in cls.IMAGE_EXTENSIONS
def main(argv):
    """Command-line entry point: build a loader and draw one sample pair.

    Args:
        argv: Argument vector in ``sys.argv`` form; ``argv[1]`` is the
            dataset directory, which must contain ``train.txt``.

    Raises:
        SystemExit: With a usage message when the dataset directory
            argument is missing.
    """
    if len(argv) < 2:
        program = argv[0] if argv else 'dataloader.py'
        sys.exit('usage: {} <dataset_dir>'.format(program))

    loader = Dataloader(argv[1], 'train.txt', (224, 224, 3))
    # The sample itself is not used here; the call exercises the pipeline.
    loader.load()
# Run the demo only when executed as a script, not when imported.
if "__main__" == __name__:
    main(sys.argv)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment