import os
import glob
import math
import random
import datetime
import pickle
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.python import keras
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.models import Model, Sequential
from tensorflow.python.keras.layers import *
from tensorflow.python.keras.optimizers import Adadelta
from tensorflow.python.keras.preprocessing.image import load_img, img_to_array, array_to_img, ImageDataGenerator
def residual_block(input_ts):
    """Residual block: two 3x3 convolutions plus a skip connection."""
    x = Conv2D(filters=128, kernel_size=(3, 3), strides=(1, 1), padding='same')(input_ts)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(filters=128, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    return Add()([x, input_ts])
def build_encoder_decoder(input_shape=(224, 224, 3)):
    input_ts = Input(shape=input_shape, name='input')
    # Normalize input to [0, 1]
    x = Lambda(lambda a: a / 255.)(input_ts)
    # Encoder
    x = Conv2D(filters=32, kernel_size=(9, 9), strides=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(filters=64, kernel_size=(3, 3), strides=(2, 2), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(filters=128, kernel_size=(3, 3), strides=(2, 2), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # Residual blocks
    for _ in range(5):
        x = residual_block(x)
    # Decoder
    x = Conv2DTranspose(filters=64, kernel_size=(3, 3), strides=(2, 2), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2DTranspose(filters=32, kernel_size=(3, 3), strides=(2, 2), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2DTranspose(filters=3, kernel_size=(9, 9), strides=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('tanh')(x)
    # Rescale tanh output from [-1, 1] to [0, 255]
    gen_out = Lambda(lambda a: (a + 1) * 127.5)(x)
    model_gen = Model(inputs=[input_ts], outputs=[gen_out])
    return model_gen
input_shape = (224, 224, 3)
model_gen = build_encoder_decoder(input_shape=input_shape)
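# Optional sanity check (not in the original gist): the two stride-2
# convolutions in the encoder are undone by the two stride-2 transposed
# convolutions, so the generator should preserve the input resolution.
_dummy = np.zeros((1,) + input_shape, dtype='float32')
assert model_gen.predict(_dummy).shape == (1,) + input_shape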
from google.colab import drive
# drive.mount('/content/drive')  # presumably intended here; DATA_DIR was left blank in the original
DATA_DIR = ""
input_size = input_shape[:2]
img_sty = load_img(DATA_DIR+'style/Piet_Mondrian_Composition.png', target_size=input_size)
display(img_sty)
img_arr_sty = np.expand_dims(img_to_array(img_sty), axis=0)
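# The code below references `vgg16`, `norm_vgg16`, `style_layer_names` and
# `contents_layer_names`, none of which are defined in the gist as captured.
# The following is a reconstruction under the usual fast-style-transfer
# assumptions: an ImageNet-pretrained VGG16 used as a frozen loss network.
# The exact layer choices, the mean value 120 and the 1/255 scaling are
# assumptions, not taken from the original.
from tensorflow.python.keras.applications.vgg16 import VGG16

vgg16 = VGG16()
for layer in vgg16.layers:
    layer.trainable = False  # the loss network stays fixed during training

def norm_vgg16(x):
    """RGB -> BGR reordering and rough centering before feeding VGG16."""
    return (x[:, :, :, ::-1] - 120) / 255.

style_layer_names = (
    'block1_conv2',
    'block2_conv2',
    'block3_conv3',
    'block4_conv3',
)
contents_layer_names = ('block3_conv3',)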
input_sty = Input(shape=input_shape, name='input_sty')
style_outputs = []
x = Lambda(norm_vgg16)(input_sty)
for layer in vgg16.layers:
    x = layer(x)
    if layer.name in style_layer_names:
        style_outputs.append(x)
model_sty = Model(inputs=input_sty, outputs=style_outputs)
y_true_sty = model_sty.predict(img_arr_sty)
input_con = Input(shape=input_shape, name='input_con')
contents_outputs = []
y = Lambda(norm_vgg16)(input_con)
for layer in vgg16.layers:
    y = layer(y)
    if layer.name in contents_layer_names:
        contents_outputs.append(y)
model_con = Model(inputs=input_con, outputs=contents_outputs)
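# The training model `model` compiled further below is also missing from the
# gist as captured. A reconstruction consistent with its five losses (four
# style outputs followed by one content output): pass the generator's output
# through the frozen VGG16 and collect the same feature maps, so that only
# the generator's weights are updated. Names here are hypothetical.
input_gen = Input(shape=input_shape, name='input_gen')
img_gen = model_gen(input_gen)
gen_style_outputs = []
gen_contents_outputs = []
z = Lambda(norm_vgg16)(img_gen)
for layer in vgg16.layers:
    z = layer(z)
    if layer.name in style_layer_names:
        gen_style_outputs.append(z)
    if layer.name in contents_layer_names:
        gen_contents_outputs.append(z)
model = Model(inputs=input_gen, outputs=gen_style_outputs + gen_contents_outputs)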
def load_imgs(img_paths, target_size=(224, 224)):
    _load_img = lambda x: img_to_array(load_img(x, target_size=target_size))
    img_list = [np.expand_dims(_load_img(img_path), axis=0) for img_path in img_paths]
    return np.concatenate(img_list, axis=0)
def train_generator(img_paths, batch_size, model, y_true_sty, shuffle=True, epochs=None):
    n_samples = len(img_paths)
    indices = list(range(n_samples))
    steps_per_epoch = math.ceil(n_samples / batch_size)
    img_paths = np.array(img_paths)
    cnt_epoch = 0
    while True:
        cnt_epoch += 1
        if shuffle:
            np.random.shuffle(indices)
        for i in range(steps_per_epoch):
            start = batch_size * i
            end = batch_size * (i + 1)
            X = load_imgs(img_paths[indices[start:end]])
            batch_size_act = X.shape[0]
            # Tile the precomputed style features to the actual batch size
            y_true_sty_t = [np.repeat(feat, batch_size_act, axis=0) for feat in y_true_sty]
            y_true_con = model.predict(X)
            yield (X, y_true_sty_t + [y_true_con])
        if epochs is not None and cnt_epoch >= epochs:
            # PEP 479: raising StopIteration inside a generator becomes a
            # RuntimeError on Python 3.7+, so end the generator with return
            return
path_glob = os.path.join(DATA_DIR, 'context', '*.jpg')
img_paths = glob.glob(path_glob)
batch_size = 2
epochs = 10
gen = train_generator(img_paths, batch_size, model_con, y_true_sty, epochs=epochs)
def feature_loss(y_true, y_pred):
    """Squared error over a feature map, normalized by its size."""
    norm = K.prod(K.cast(K.shape(y_true)[1:], 'float32'))
    return K.sum(K.square(y_pred - y_true), axis=(1, 2, 3)) / norm

def gram_matrix(X):
    """Channel-wise Gram matrix of a (batch, H, W, C) feature map."""
    X_sw = K.permute_dimensions(X, (0, 3, 2, 1))    # -> (batch, C, W, H)
    s = K.shape(X_sw)
    new_shape = (s[0], s[1], s[2] * s[3])
    X_rs = K.reshape(X_sw, new_shape)               # -> (batch, C, W*H)
    X_rs_t = K.permute_dimensions(X_rs, (0, 2, 1))  # -> (batch, W*H, C)
    dot = K.batch_dot(X_rs, X_rs_t)                 # -> (batch, C, C)
    norm = K.prod(K.cast(s[1:], 'float32'))
    return dot / norm

def style_loss(y_true, y_pred):
    """Squared distance between the Gram matrices of two feature maps."""
    return K.sum(K.square(gram_matrix(y_pred) - gram_matrix(y_true)), axis=(1, 2))
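# Optional sanity check (not in the original gist): on a dummy
# (batch, H, W, C) tensor, gram_matrix should return (batch, C, C).
_feat = K.constant(np.ones((2, 4, 4, 8), dtype='float32'))
assert K.eval(gram_matrix(_feat)).shape == (2, 8, 8)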
dt = datetime.datetime.now()
dir_log = 'model/{:%y%m%d_%H%M%S}'.format(dt)
dir_weights = 'model/{:%y%m%d_%H%M%S}/weights'.format(dt)
dir_trans = 'model/{:%y%m%d_%H%M%S}/img_trans'.format(dt)
os.makedirs(dir_log, exist_ok=True)
os.makedirs(dir_weights, exist_ok=True)
os.makedirs(dir_trans, exist_ok=True)
model.compile(
    optimizer=Adadelta(),
    loss=[style_loss, style_loss, style_loss, style_loss, feature_loss],
    loss_weights=[1.0, 1.0, 1.0, 1.0, 3.0]
)
img_test = load_img(DATA_DIR+'test/building.jpg', target_size=input_size)
img_arr_test = np.expand_dims(img_to_array(img_test), axis=0)
steps_per_epoch = math.ceil(len(img_paths)/batch_size)
iters_verbose = 1000
iters_save_img = 1000
iters_save_model = steps_per_epoch
now_epoch = 0
losses = []
path_tmp = 'epoch_{}_iters_{}_loss_{:.2f}_{}'
for i, (x_train, y_train) in enumerate(gen):
    if i % steps_per_epoch == 0:
        now_epoch += 1
    loss = model.train_on_batch(x_train, y_train)
    losses.append(loss)
    if i % iters_verbose == 0:
        print('epoch:{}, iters:{}, loss:{:.3f}'.format(now_epoch, i, loss[0]))
    if i % iters_save_img == 0:
        pred = model_gen.predict(img_arr_test)
        img_pred = array_to_img(pred.squeeze())
        path_trs_img = path_tmp.format(now_epoch, i, loss[0], '.jpg')
        img_pred.save(os.path.join(dir_trans, path_trs_img))
        print('# image saved:{}'.format(path_trs_img))
    if i % iters_save_model == 0:
        model.save(os.path.join(dir_weights, path_tmp.format(now_epoch, i, loss[0], '.h5')))
        path_loss = os.path.join(dir_log, 'loss.pkl')
        with open(path_loss, 'wb') as f:
            pickle.dump(losses, f)
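# Optional (not in the original gist): plot the total-loss history with
# matplotlib, which the imports above already bring in.
plt.plot([l[0] for l in losses])
plt.xlabel('iteration')
plt.ylabel('total loss')
plt.show()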
display(img_test)
pred = model_gen.predict(img_arr_test)
img_pred = array_to_img(pred.squeeze())
display(img_pred)