import os
import glob
import math
import random
import datetime
import pickle

import numpy as np
import matplotlib.pyplot as plt
from tensorflow.python import keras
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.models import Model, Sequential
from tensorflow.python.keras.layers import *
from tensorflow.python.keras.optimizers import Adadelta
from tensorflow.python.keras.applications.vgg16 import VGG16
from tensorflow.python.keras.preprocessing.image import load_img, img_to_array, array_to_img, ImageDataGenerator
def residual_block(input_ts):
    """Residual block: two 3x3 convolutions with a skip connection."""
    x = Conv2D(filters=128, kernel_size=(3, 3), strides=(1, 1), padding='same')(input_ts)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(filters=128, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    return Add()([x, input_ts])
def build_encoder_decoder(input_shape=(224, 224, 3)):
    """Image transformation network: encoder -> residual blocks -> decoder."""
    input_ts = Input(shape=input_shape, name='input')
    # Normalize to [0, 1]
    x = Lambda(lambda a: a/255.)(input_ts)

    # Encoder
    x = Conv2D(filters=32, kernel_size=(9, 9), strides=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(filters=64, kernel_size=(3, 3), strides=(2, 2), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(filters=128, kernel_size=(3, 3), strides=(2, 2), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Residual blocks
    for _ in range(5):
        x = residual_block(x)

    # Decoder
    x = Conv2DTranspose(filters=64, kernel_size=(3, 3), strides=(2, 2), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2DTranspose(filters=32, kernel_size=(3, 3), strides=(2, 2), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2DTranspose(filters=3, kernel_size=(9, 9), strides=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('tanh')(x)

    # Rescale tanh output from [-1, 1] to [0, 255]
    gen_out = Lambda(lambda a: (a + 1) * 127.5)(x)

    model_gen = Model(inputs=[input_ts], outputs=[gen_out])
    return model_gen
input_shape = (224, 224, 3)
model_gen = build_encoder_decoder(input_shape=input_shape)
from google.colab import drive
DATA_DIR = ""  # path to the dataset root (contains style/, context/, test/)

input_size = input_shape[:2]
img_sty = load_img(DATA_DIR+'style/Piet_Mondrian_Composition.png', target_size=input_size)
display(img_sty)
img_arr_sty = np.expand_dims(img_to_array(img_sty), axis=0)
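# The code below uses vgg16, norm_vgg16, style_layer_names and
# contents_layer_names without defining them. A minimal sketch of the missing
# setup, assuming a frozen ImageNet-pretrained VGG16 as the loss network,
# standard VGG preprocessing, and the layer choices of Johnson-style fast
# style transfer (four style layers, matching the four style losses compiled
# further down):
vgg16 = VGG16(include_top=False, input_shape=input_shape)
for layer in vgg16.layers:
    layer.trainable = False  # the loss network stays fixed during training

def norm_vgg16(x):
    # RGB -> BGR, then subtract the ImageNet channel means (assumed preprocessing)
    return x[:, :, :, ::-1] - [103.939, 116.779, 123.68]

style_layer_names = ('block1_conv2', 'block2_conv2', 'block3_conv3', 'block4_conv3')
contents_layer_names = ('block3_conv3',)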
# Model that extracts style features from the style image
input_sty = Input(shape=input_shape, name='input_sty')
style_outputs = []
x = Lambda(norm_vgg16)(input_sty)
for layer in vgg16.layers[1:]:  # skip the InputLayer
    x = layer(x)
    if layer.name in style_layer_names:
        style_outputs.append(x)
model_sty = Model(inputs=input_sty, outputs=style_outputs)
y_true_sty = model_sty.predict(img_arr_sty)
# Model that extracts content features from content images
input_con = Input(shape=input_shape, name='input_con')
contents_outputs = []
y = Lambda(norm_vgg16)(input_con)
for layer in vgg16.layers[1:]:  # skip the InputLayer
    y = layer(y)
    if layer.name in contents_layer_names:
        contents_outputs.append(y)
model_con = Model(inputs=input_con, outputs=contents_outputs)
def load_imgs(img_paths, target_size=(224, 224)):
    """Load images from paths into a single (N, H, W, C) array."""
    _load_img = lambda x: img_to_array(load_img(x, target_size=target_size))
    img_list = [np.expand_dims(_load_img(img_path), axis=0) for img_path in img_paths]
    return np.concatenate(img_list, axis=0)
def train_generator(img_paths, batch_size, model, y_true_sty, shuffle=True, epochs=None):
    """Yield (images, [style targets..., content target]) batches,
    indefinitely or for a fixed number of epochs."""
    n_samples = len(img_paths)
    indices = list(range(n_samples))
    steps_per_epoch = math.ceil(n_samples/batch_size)
    img_paths = np.array(img_paths)
    cnt_epoch = 0
    while True:
        cnt_epoch += 1
        if shuffle:
            np.random.shuffle(indices)
        for i in range(steps_per_epoch):
            start = batch_size * i
            end = batch_size * (i + 1)
            X = load_imgs(img_paths[indices[start:end]])
            batch_size_act = X.shape[0]  # the last batch may be smaller
            # Tile the precomputed style targets to the actual batch size
            y_true_sty_t = [np.repeat(feat, batch_size_act, axis=0) for feat in y_true_sty]
            # Content targets are computed on the fly from the content model
            y_true_con = model.predict(X)
            yield (X, y_true_sty_t + [y_true_con])
        if epochs is not None and cnt_epoch >= epochs:
            # PEP 479: raising StopIteration inside a generator is an error;
            # a plain return ends the generator cleanly
            return
path_glob = os.path.join(DATA_DIR, 'context/*.jpg')
img_paths = glob.glob(path_glob)
batch_size = 2
epochs = 10
gen = train_generator(img_paths, batch_size, model_con, y_true_sty, epochs=epochs)
def feature_loss(y_true, y_pred):
    # Squared error over the feature map, normalized by its size
    norm = K.prod(K.cast(K.shape(y_true)[1:], 'float32'))
    return K.sum(K.square(y_pred - y_true), axis=(1, 2, 3)) / norm

def gram_matrix(X):
    # Gram matrix of channel-wise feature correlations, normalized by map size
    X_sw = K.permute_dimensions(X, (0, 3, 2, 1))  # (batch, ch, W, H)
    s = K.shape(X_sw)
    new_shape = (s[0], s[1], s[2]*s[3])
    X_rs = K.reshape(X_sw, new_shape)
    X_rs_t = K.permute_dimensions(X_rs, (0, 2, 1))
    dot = K.batch_dot(X_rs, X_rs_t)
    norm = K.prod(K.cast(s[1:], 'float32'))
    return dot / norm

def style_loss(y_true, y_pred):
    # Squared Frobenius distance between Gram matrices
    return K.sum(K.square(gram_matrix(y_pred) - gram_matrix(y_true)), axis=(1, 2))
dt = datetime.datetime.now()
dir_log = 'model/{:%y%m%d_%H%M%S}'.format(dt)
dir_weights = 'model/{:%y%m%d_%H%M%S}/weights'.format(dt)
dir_trans = 'model/{:%y%m%d_%H%M%S}/img_trans'.format(dt)
os.makedirs(dir_log, exist_ok=True)
os.makedirs(dir_weights, exist_ok=True)
os.makedirs(dir_trans, exist_ok=True)
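# The code below compiles and trains a `model` that is never defined in the
# gist. A minimal sketch of the assumed training model: the generator's output
# is passed through the frozen VGG16 so the network can be trained against the
# four style targets and one content target yielded by train_generator:
z = Lambda(norm_vgg16)(model_gen.output)
style_outs = []
contents_outs = []
for layer in vgg16.layers[1:]:  # skip the InputLayer
    z = layer(z)
    if layer.name in style_layer_names:
        style_outs.append(z)
    if layer.name in contents_layer_names:
        contents_outs.append(z)
model = Model(inputs=model_gen.input, outputs=style_outs + contents_outs)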
model.compile(
    optimizer=Adadelta(),
    loss=[style_loss, style_loss, style_loss, style_loss, feature_loss],
    loss_weights=[1.0, 1.0, 1.0, 1.0, 3.0]
)
img_test = load_img(DATA_DIR+'test/building.jpg', target_size=input_size)
img_arr_test = np.expand_dims(img_to_array(img_test), axis=0)
steps_per_epoch = math.ceil(len(img_paths)/batch_size)
iters_verbose = 1000
iters_save_img = 1000
iters_save_model = steps_per_epoch
now_epoch = 0
losses = []
path_tmp = 'epoch_{}_iters_{}_loss_{:.2f}_{}'
for i, (x_train, y_train) in enumerate(gen):
    if i % steps_per_epoch == 0:
        now_epoch += 1
    loss = model.train_on_batch(x_train, y_train)
    losses.append(loss)
    if i % iters_verbose == 0:
        print('epoch:{}, iters:{}, loss:{:.3f}'.format(now_epoch, i, loss[0]))
    if i % iters_save_img == 0:
        # Save a stylized version of the test image to track progress
        pred = model_gen.predict(img_arr_test)
        img_pred = array_to_img(pred.squeeze())
        path_trs_img = path_tmp.format(now_epoch, i, loss[0], '.jpg')
        img_pred.save(os.path.join(dir_trans, path_trs_img))
        print('# image saved:{}'.format(path_trs_img))
    if i % iters_save_model == 0:
        # Checkpoint the model and the loss history
        model.save(os.path.join(dir_weights, path_tmp.format(now_epoch, i, loss[0], '.h5')))
        path_loss = os.path.join(dir_log, 'loss.pkl')
        with open(path_loss, 'wb') as f:
            pickle.dump(losses, f)
# Show the original test image and its stylized output
display(img_test)
pred = model_gen.predict(img_arr_test)
img_pred = array_to_img(pred.squeeze())
display(img_pred)