Created
February 4, 2018 14:09
-
-
Save dfaker/1b154f548cf96ccac71bde9fc739a41a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from utils import get_image_paths, load_images, stack_images | |
from training_data import get_training_data | |
import random | |
import numpy | |
import cv2 | |
from keras.models import Model | |
from keras.layers import Input, Dense, Flatten, Reshape, concatenate, Add, add, Dropout | |
from keras.layers.advanced_activations import LeakyReLU | |
from keras.layers.convolutional import Conv2D | |
from keras.initializers import RandomNormal | |
from keras.optimizers import Adam | |
from keras.utils import conv_utils | |
from keras.engine.topology import Layer | |
import keras.backend as K | |
import tensorflow as tf | |
import time | |
from keras_contrib.losses import DSSIMObjective | |
images_A = get_image_paths( "data/A" ) | |
images_B = get_image_paths( "data/B" ) | |
images_C = get_image_paths( "data/C" ) | |
images_D = get_image_paths( "data/D" ) | |
minImages = max(len(images_A),len(images_B),len(images_C),len(images_D)) | |
random.shuffle(images_A) | |
random.shuffle(images_B) | |
random.shuffle(images_C) | |
random.shuffle(images_D) | |
images_A,landmarks_A = load_images( images_A[:minImages] ) | |
images_B,landmarks_B = load_images( images_B[:minImages] ) | |
images_C,landmarks_C = load_images( images_C[:minImages] ) | |
images_D,landmarks_D = load_images( images_D[:minImages] ) | |
images_A = images_A/255.0 | |
images_B = images_B/255.0 | |
images_C = images_C/255.0 | |
images_D = images_D/255.0 | |
random_transform_args = { | |
'rotation_range': 10, | |
'zoom_range': 0.05, | |
'shift_range': 0.05, | |
'random_flip': 0.0, | |
} | |
conv_init = RandomNormal(0, 0.02) | |
class penalized_loss(object): | |
def __init__(self,mask,lossFunc,maskProp= 1.0): | |
self.mask = mask | |
self.lossFunc=lossFunc | |
self.maskProp = maskProp | |
self.maskaskinvProp = 1-maskProp | |
def __call__(self,y_true, y_pred): | |
tro, tgo, tbo = tf.split(y_true,3, 3 ) | |
pro, pgo, pbo = tf.split(y_pred,3, 3 ) | |
tr = tro | |
tg = tgo | |
tb = tbo | |
pr = pro | |
pg = pgo | |
pb = pbo | |
m = self.mask | |
m = m*self.maskProp | |
m += self.maskaskinvProp | |
tr *= m | |
tg *= m | |
tb *= m | |
pr *= m | |
pg *= m | |
pb *= m | |
y = tf.concat([tr, tg, tb],3) | |
p = tf.concat([pr, pg, pb],3) | |
#yo = tf.stack([tro,tgo,tbo],3) | |
#po = tf.stack([pro,pgo,pbo],3) | |
return self.lossFunc(y,p) | |
def random_transform( image, seed, rotation_range, zoom_range, shift_range, random_flip ): | |
h,w = image.shape[0:2] | |
numpy.random.seed( seed ) | |
rotation = numpy.random.uniform( -rotation_range, rotation_range ) | |
scale = numpy.random.uniform( 1 - zoom_range, 1 + zoom_range ) | |
tx = numpy.random.uniform( -shift_range, shift_range ) * w | |
ty = numpy.random.uniform( -shift_range, shift_range ) * h | |
mat = cv2.getRotationMatrix2D( (w//2,h//2), rotation, scale ) | |
mat[:,2] += (tx,ty) | |
result = cv2.warpAffine( image, mat, (w,h), borderMode=cv2.BORDER_REPLICATE ) | |
if numpy.random.random() < random_flip: | |
result = result[:,::-1] | |
return result | |
def random_warp(image): | |
range_ = numpy.linspace(0, 256, 20) | |
mapx = numpy.broadcast_to(range_, (20, 20)) | |
mapy = mapx.T | |
mapx = mapx + numpy.random.normal(size=(20, 20), scale=5) | |
mapy = mapy + numpy.random.normal(size=(20, 20), scale=5) | |
interp_mapx = cv2.resize(mapx, (256, 256)).astype('float32') | |
interp_mapy = cv2.resize(mapy, (256, 256)).astype('float32') | |
return cv2.remap(image, interp_mapx, interp_mapy, cv2.INTER_LINEAR) | |
def getOpposingBestMatches(image_A,landmark_A,images_B,landmarks_B): | |
closest = ( numpy.mean(numpy.square(landmark_A-landmarks_B),axis=(1,2)) ).argsort()[0:7] | |
source = cv2.resize( image_A[:,:,:3], (64,64)) | |
closestMerged = numpy.dstack([ | |
cv2.resize( images_B[closest[0]][:,:,:3], (64,64)), | |
cv2.resize( images_B[closest[1]][:,:,:3], (64,64)), | |
cv2.resize( images_B[closest[2]][:,:,:3], (64,64)), | |
cv2.resize( images_B[closest[3]][:,:,:3], (64,64)), | |
cv2.resize( images_B[closest[4]][:,:,:3], (64,64)), | |
cv2.resize( images_B[closest[5]][:,:,:3], (64,64)), | |
]) | |
return source,closestMerged | |
def get_training_data( images,landmarks,batch_size): | |
while 1: | |
indices = numpy.random.choice(range(0,images.shape[0]),size=batch_size,replace=True) | |
for i,index in enumerate(indices): | |
image = images[index] | |
seed = int(time.time()) | |
image = random_transform( image, seed, **random_transform_args ) | |
closest = ( numpy.mean(numpy.square(landmarks[index]-landmarks),axis=(1,2)) ).argsort()[1:7] | |
closest = numpy.random.choice(closest, 6, replace=False) | |
closestMerged = numpy.dstack([ | |
cv2.resize( random_transform( images[closest[0]][:,:,:3] ,seed, **random_transform_args) , (64,64)), | |
cv2.resize( random_transform( images[closest[1]][:,:,:3] ,seed, **random_transform_args) , (64,64)), | |
cv2.resize( random_transform( images[closest[2]][:,:,:3] ,seed, **random_transform_args) , (64,64)), | |
cv2.resize( random_transform( images[closest[3]][:,:,:3] ,seed, **random_transform_args) , (64,64)), | |
cv2.resize( random_transform( images[closest[4]][:,:,:3] ,seed, **random_transform_args) , (64,64)), | |
cv2.resize( random_transform( images[closest[5]][:,:,:3] ,seed, **random_transform_args) , (64,64)), | |
]) | |
if i == 0: | |
warped_images = numpy.empty( (batch_size,) + (64,64,3), image.dtype ) | |
example_images = numpy.empty( (batch_size,) + (64,64,18), image.dtype ) | |
target_images = numpy.empty( (batch_size,) + (128,128,3), image.dtype ) | |
mask_images = numpy.empty( (batch_size,) + (128,128,1), image.dtype ) | |
warped_image = random_warp( image[:,:,:3] ) | |
warped_image = cv2.GaussianBlur( warped_image,(51,51),0 ) | |
image_mask = image[:,:,3].reshape((image.shape[0],image.shape[1],1)) * numpy.ones((image.shape[0],image.shape[1],3)).astype(float) | |
foreground = cv2.multiply(image_mask, warped_image.astype(float)) | |
background = cv2.multiply(1.0 - image_mask, image[:,:,:3].astype(float)) | |
warped_image = numpy.add(background,foreground) | |
warped_image = cv2.resize(warped_image,(64,64)) | |
warped_images[i] = warped_image | |
example_images[i] = closestMerged | |
target_images[i] = cv2.resize( image[:,:,:3], (128,128) ) | |
mask_images[i] = cv2.resize( image[:,:,3], (128,128) ).reshape((128,128,1)) | |
yield warped_images,example_images,target_images,mask_images | |
batch_size = 8 | |
class PixelShuffler(Layer): | |
def __init__(self, size=(2, 2), data_format=None, **kwargs): | |
super(PixelShuffler, self).__init__(**kwargs) | |
self.data_format = conv_utils.normalize_data_format(data_format) | |
self.size = conv_utils.normalize_tuple(size, 2, 'size') | |
def call(self, inputs): | |
input_shape = K.int_shape(inputs) | |
if len(input_shape) != 4: | |
raise ValueError('Inputs should have rank ' + | |
str(4) + | |
'; Received input shape:', str(input_shape)) | |
if self.data_format == 'channels_first': | |
batch_size, c, h, w = input_shape | |
if batch_size is None: | |
batch_size = -1 | |
rh, rw = self.size | |
oh, ow = h * rh, w * rw | |
oc = c // (rh * rw) | |
out = K.reshape(inputs, (batch_size, rh, rw, oc, h, w)) | |
out = K.permute_dimensions(out, (0, 3, 4, 1, 5, 2)) | |
out = K.reshape(out, (batch_size, oc, oh, ow)) | |
return out | |
elif self.data_format == 'channels_last': | |
batch_size, h, w, c = input_shape | |
if batch_size is None: | |
batch_size = -1 | |
rh, rw = self.size | |
oh, ow = h * rh, w * rw | |
oc = c // (rh * rw) | |
out = K.reshape(inputs, (batch_size, h, w, rh, rw, oc)) | |
out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5)) | |
out = K.reshape(out, (batch_size, oh, ow, oc)) | |
return out | |
def compute_output_shape(self, input_shape): | |
if len(input_shape) != 4: | |
raise ValueError('Inputs should have rank ' + | |
str(4) + | |
'; Received input shape:', str(input_shape)) | |
if self.data_format == 'channels_first': | |
height = input_shape[2] * self.size[0] if input_shape[2] is not None else None | |
width = input_shape[3] * self.size[1] if input_shape[3] is not None else None | |
channels = input_shape[1] // self.size[0] // self.size[1] | |
if channels * self.size[0] * self.size[1] != input_shape[1]: | |
raise ValueError('channels of input and size are incompatible') | |
return (input_shape[0], | |
channels, | |
height, | |
width) | |
elif self.data_format == 'channels_last': | |
height = input_shape[1] * self.size[0] if input_shape[1] is not None else None | |
width = input_shape[2] * self.size[1] if input_shape[2] is not None else None | |
channels = input_shape[3] // self.size[0] // self.size[1] | |
if channels * self.size[0] * self.size[1] != input_shape[3]: | |
raise ValueError('channels of input and size are incompatible') | |
return (input_shape[0], | |
height, | |
width, | |
channels) | |
def get_config(self): | |
config = {'size': self.size, | |
'data_format': self.data_format} | |
base_config = super(PixelShuffler, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
optimizer = Adam( lr=5e-5, beta_1=0.5, beta_2=0.999 ) | |
import numpy as np | |
def res_block(input_tensor, f): | |
x = input_tensor | |
x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init, use_bias=False, padding="same")(x) | |
x = LeakyReLU(alpha=0.2)(x) | |
x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init, use_bias=False, padding="same")(x) | |
x = Add()([x, input_tensor]) | |
x = LeakyReLU(alpha=0.2)(x) | |
return x | |
def conv( filters ): | |
def block(x): | |
x = Conv2D( filters, kernel_size=5, strides=2, padding='same' )(x) | |
x = LeakyReLU(0.1)(x) | |
return x | |
return block | |
def upscale( filters ): | |
def block(x): | |
x = Conv2D( filters*4, kernel_size=3, padding='same' )(x) | |
x = LeakyReLU(0.1)(x) | |
x = PixelShuffler()(x) | |
return x | |
return block | |
input_warped = Input( shape=(64,64,3) ) | |
input_examples = Input( shape=(64,64,18) ) | |
x = conv( 32)(input_warped) | |
x = conv( 64)(x) | |
x = conv( 128)(x) | |
x = conv( 256)(x) | |
x = res_block(x, 256) | |
e = conv( 32)(input_examples) | |
e = conv( 64)(e) | |
e = conv( 128)(e) | |
e = conv( 256)(e) | |
e = res_block(e, 256) | |
x = Dense( 2048 )( Flatten()(x) ) | |
x = Dense(4*4*1024)(x) | |
x = Reshape((4,4,1024))(x) | |
x = upscale(512)(x) | |
e = Dense( 2048 )( Flatten()(e) ) | |
e = Dense(4*4*1024)(e) | |
e = Reshape((4,4,1024))(e) | |
e = upscale(512)(e) | |
x = add([e,x]) | |
#x = concatenate([e,x]) | |
e = upscale(512)(e) | |
x = res_block(x, 512) | |
x = upscale(256)(x) | |
x = res_block(x, 256) | |
x = upscale(128)(x) | |
x = res_block(x, 128) | |
x = upscale(64)(x) | |
x = res_block(x, 64) | |
x = upscale(32)(x) | |
#x = res_block(x, 32) | |
#x = upscale(16)(x) | |
x = Conv2D( 3, kernel_size=5, padding='same', activation='sigmoid' )(x) | |
m1 = Input( shape=(128,128,1) ) | |
sumModel = Model( [input_warped,input_examples,m1], [x] ) | |
print(sumModel.summary()) | |
try: | |
sumModel.load_weights('weights.dat') | |
except: | |
print('Weights not found') | |
DSSIM = DSSIMObjective() | |
sumModel.compile( optimizer=optimizer, loss=['mae'] ) | |
from keras_contrib.losses import DSSIMObjective | |
for n in range(900000): | |
imaegsAGen = get_training_data(images_A, landmarks_A, batch_size) | |
imaegsBGen = get_training_data(images_B, landmarks_B, batch_size) | |
imaegsCGen = get_training_data(images_C, landmarks_C, batch_size) | |
imaegsDGen = get_training_data(images_D, landmarks_D, batch_size) | |
xa,xae,ya,ma = next(imaegsAGen) | |
xb,xbe,yb,mb = next(imaegsBGen) | |
xc,xce,yc,mc = next(imaegsCGen) | |
xd,xde,yd,md = next(imaegsDGen) | |
xl = numpy.concatenate((xa, xb, xc, xd), axis=0) | |
xel = numpy.concatenate((xae, xbe, xce, xde), axis=0) | |
yl = numpy.concatenate((ya, yb, yc, yd), axis=0) | |
ml = numpy.concatenate((ma, mb, mc, md), axis=0) | |
indices = numpy.random.choice(range(0,xl.shape[0]),size=xl.shape[0],replace=False) | |
print(sumModel.train_on_batch([xl[indices],xel[indices],ml[indices]], yl[indices])) | |
if n%100 == 0: | |
sumModel.save_weights('weights.dat') | |
print("saving weights") | |
if n%5 == 0: | |
testindexA = numpy.random.choice(range(0,images_A.shape[0])) | |
testindexB = numpy.random.choice(range(0,images_B.shape[0])) | |
testindexC = numpy.random.choice(range(0,images_C.shape[0])) | |
testindexD = numpy.random.choice(range(0,images_D.shape[0])) | |
zmask = numpy.zeros((1,128,128,1),float) | |
genSet = [ | |
("A->B",testindexA,images_A,images_B,landmarks_A,landmarks_B), | |
("B->A",testindexB,images_B,images_A,landmarks_B,landmarks_A), | |
("C->D",testindexC,images_C,images_D,landmarks_C,landmarks_D), | |
("D->C",testindexD,images_D,images_C,landmarks_D,landmarks_C), | |
("A->C",testindexA,images_A,images_C,landmarks_A,landmarks_C), | |
("B->D",testindexB,images_B,images_D,landmarks_B,landmarks_D), | |
] | |
figList = [] | |
for name,index,imagesSrc,imagesDst,landmarksSrc,landmarksDst in genSet: | |
testImage = imagesSrc[index] | |
landmarkSrc = landmarksSrc[index] | |
src,examp = getOpposingBestMatches(testImage, landmarkSrc, imagesDst, landmarksDst) | |
pred = sumModel.predict( [numpy.array([src]),numpy.array([examp]),zmask] ) | |
figList.append( | |
numpy.concatenate( [ | |
cv2.resize( testImage[:,:,:3], (128,128)), | |
cv2.resize( pred[0], (128,128)), | |
cv2.resize( examp[:,:,:3], (128,128)), | |
] , axis=1) | |
) | |
stack = numpy.concatenate(figList,axis=0) | |
stack = numpy.clip( stack * 255, 0, 255 ).astype('uint8') | |
stack = stack_images( stack ) | |
cv2.imshow("p", stack) | |
if cv2.waitKey(1) == ord('q'): | |
exit() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment