GAN with normalization
import os
import sys
import time

import numpy as np
import tensorflow as tf
from tqdm import tqdm

from depot import inits
from depot.utils import find_trainable_variables, find_variables, iter_data, shuffle
from depot.vis import color_grid
from utils.inception import get_inception_score
# from depot.load import cifar10_with_valid_set
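# depot.* and utils.inception appear to be the author's local helper modules
# (initializers, variable lookup, minibatch iteration, image grids, Inception
# scoring) rather than pip-installable packages.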
desc = 'dog_128_normalize'
t = time.time()
trX = np.load('/home/hanzhang/Data/imagenet/doggonet_128px_imgs.npy')
print('%.3f seconds to load' % (time.time()-t))
ntrain = len(trX)
print(trX.shape)

# X = tf.placeholder(tf.float32, [128, 64, 64, 3])
X = tf.placeholder(tf.float32, [128, 128, 128, 3])
# X = tf.placeholder(tf.float32, [128, 32, 32, 3])
Z = tf.placeholder(tf.float32, [128, 100])
test_batches = 5000 // 128 + 1
bn_updates = []

# make sure the sample-grid directories exist before anything is written to them
os.makedirs('vis/%s/cur' % desc, exist_ok=True)
os.makedirs('vis/%s/ema' % desc, exist_ok=True)
file_name = './' + desc + '_log.txt'
file = open(file_name, 'w+')
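# leaky ReLU written as a single affine-plus-abs expression:
# f1*x + f2*|x| equals x for x > 0 and leak*x for x < 0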
def lrelu(x, leak=0.2):
    f1 = 0.5 * (1 + leak)
    f2 = 0.5 * (1 - leak)
    return f1 * x + f2 * tf.abs(x)
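# gated linear unit: split the last axis in half and gate one half with the
# sigmoid of the other, so the output has half as many channels as the input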
def glu(x):
    dim = len(x.get_shape()) - 1
    a, b = tf.split(x, 2, dim)
    return a * tf.nn.sigmoid(b)
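# batch norm with explicitly managed statistics: at training time (ema=None) it
# normalizes with the current batch moments and queues assign ops that snapshot
# those moments into the non-trainable u/s variables; at evaluation time it
# normalizes with moving averages of those snapshots instead (assuming
# find_variables below also picks up these non-trainable variables)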
def _bn(x, g, b, e=1e-5, axes=[1], ema=None):
    shape = [s.value for s in x.get_shape()]
    for axis in axes:
        shape[axis] = 1
    uv = tf.get_variable("u", shape, initializer=inits.constant_init(0.0), trainable=False)
    sv = tf.get_variable("s", shape, initializer=inits.constant_init(1.0), trainable=False)
    if ema is not None:
        u = ema.average(uv)
        s = ema.average(sv)
    else:
        u, s = tf.nn.moments(x, axes=axes, keep_dims=True)
        bn_updates.append(uv.assign(u))
        bn_updates.append(sv.assign(s))
    x = (x - u) / tf.sqrt(s + e)
    x = x * g + b
    return x
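# layer helpers: plain and batch-normalized conv / transposed-conv / fully
# connected layers; passing an ExponentialMovingAverage makes a layer read the
# averaged copies of its weights instead of the live training values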
def conv(x, scope, rf, nf, act, stride=1, pad='SAME', winit=inits.ortho_init(1.0), binit=inits.constant_init(0.0), ema=None):
    with tf.variable_scope(scope):
        nin = x.get_shape()[-1].value
        w = tf.get_variable("w", [rf, rf, nin, nf], initializer=winit)
        b = tf.get_variable("b", [nf], initializer=binit)
        if ema is not None:
            w = ema.average(w)
            b = ema.average(b)
        z = tf.nn.conv2d(x, w, [1, stride, stride, 1], padding=pad)
        z = z + b
        h = act(z)
        return h
def bnconv(x, scope, rf, nf, act, stride=1, pad='SAME', winit=inits.ortho_init(1.0), ema=None):
    with tf.variable_scope(scope):
        nin = x.get_shape()[-1].value
        w = tf.get_variable("w", [rf, rf, nin, nf], initializer=winit)
        g = tf.get_variable("g", [nf], initializer=inits.constant_init(1.0))
        b = tf.get_variable("b", [nf], initializer=inits.constant_init(0.0))
        if ema is not None:
            w = ema.average(w)
            g = ema.average(g)
            b = ema.average(b)
        z = tf.nn.conv2d(x, w, [1, stride, stride, 1], padding=pad)
        z = _bn(z, g, b, axes=[0, 1, 2], ema=ema)
        h = act(z)
        return h
def deconv(x, scope, shape, rf, nf, act, stride=2, pad='SAME', winit=inits.ortho_init(1.0), binit=inits.constant_init(0.0), ema=None):
    with tf.variable_scope(scope):
        nin = x.get_shape()[-1].value
        w = tf.get_variable("w", [rf, rf, nf, nin], initializer=winit)
        b = tf.get_variable("b", [nf], initializer=binit)
        if ema is not None:
            w = ema.average(w)
            b = ema.average(b)
        z = tf.nn.conv2d_transpose(x, w, shape, [1, stride, stride, 1], padding=pad)
        z = z + b
        h = act(z)
        return h
def bndeconv(x, scope, shape, rf, nf, act, stride=2, pad='SAME', winit=inits.ortho_init(1.0), ema=None):
    with tf.variable_scope(scope):
        nin = x.get_shape()[-1].value
        w = tf.get_variable("w", [rf, rf, nf, nin], initializer=winit)
        g = tf.get_variable("g", [nf], initializer=inits.constant_init(1.0))
        b = tf.get_variable("b", [nf], initializer=inits.constant_init(0.0))
        if ema is not None:
            w = ema.average(w)
            g = ema.average(g)
            b = ema.average(b)
        z = tf.nn.conv2d_transpose(x, w, shape, [1, stride, stride, 1], padding=pad)
        z = _bn(z, g, b, axes=[0, 1, 2], ema=ema)
        h = act(z)
        return h
def bnfc(x, scope, nh, act, ema=None):
    with tf.variable_scope(scope):
        nin = x.get_shape()[1].value
        w = tf.get_variable("w", [nin, nh], initializer=inits.ortho_init(1.0))
        g = tf.get_variable("g", [nh], initializer=inits.constant_init(1.0))
        b = tf.get_variable("b", [nh], initializer=inits.constant_init(0.0))
        if ema is not None:
            w = ema.average(w)
            g = ema.average(g)
            b = ema.average(b)
        z = tf.matmul(x, w)
        z = _bn(z, g, b, axes=[0], ema=ema)
        h = act(z)
        return h
def fc(x, scope, nh, act, winit=inits.ortho_init(1.0), binit=inits.constant_init(0.0), ema=None):
    with tf.variable_scope(scope):
        nin = x.get_shape()[1].value
        w = tf.get_variable("w", [nin, nh], initializer=winit)
        b = tf.get_variable("b", [nh], initializer=binit)
        if ema is not None:
            w = ema.average(w)
            b = ema.average(b)
        z = tf.matmul(x, w)
        z = z + b
        h = act(z)
        return h
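# DCGAN-style generator: project z to a 4x4 feature map and upsample five times
# to 128x128x3. Because glu halves the channel count, each layer's nf is twice
# the width the next layer actually sees (e.g. nh=4*4*1024 reshapes to
# [128, 4, 4, 512] after the gate).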
def generator(Z, reuse=False, ema=None):
    with tf.variable_scope('generator', reuse=reuse):
        h = bnfc(Z, scope='h', nh=4*4*1024, act=glu, ema=ema)
        h = tf.reshape(h, [128, 4, 4, 512])
        h2 = bndeconv(h, scope='h2', shape=[128, 8, 8, 512], rf=5, nf=512, act=glu, ema=ema)
        h3 = bndeconv(h2, scope='h3', shape=[128, 16, 16, 256], rf=5, nf=256, act=glu, ema=ema)
        h4 = bndeconv(h3, scope='h4', shape=[128, 32, 32, 128], rf=5, nf=128, act=glu, ema=ema)
        h5 = bndeconv(h4, scope='h5', shape=[128, 64, 64, 64], rf=5, nf=64, act=glu, ema=ema)
        h6 = deconv(h5, scope='h6', shape=[128, 128, 128, 3], rf=5, nf=3, act=tf.nn.tanh, ema=ema)
        return h6
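# DCGAN-style discriminator: five stride-2 convolutions (batch norm on all but
# the first layer, as in the DCGAN paper), then a linear layer giving one logit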
def discriminator(X, reuse=False):
    with tf.variable_scope('discriminator', reuse=reuse):
        h = conv(X, scope='h', rf=5, nf=32, act=lrelu, stride=2)
        h2 = bnconv(h, scope='h2', rf=5, nf=64, act=lrelu, stride=2)
        h3 = bnconv(h2, scope='h3', rf=5, nf=128, act=lrelu, stride=2)
        h4 = bnconv(h3, scope='h4', rf=5, nf=256, act=lrelu, stride=2)
        h5 = bnconv(h4, scope='h5', rf=5, nf=512, act=lrelu, stride=2)
        h5 = tf.reshape(h5, [128, -1])
        logits = fc(h5, scope='out', nh=1, act=lambda x: x, winit=inits.ortho_init(1.0))
        return logits
gz = generator(Z)
dx = discriminator(X)
dgz = discriminator(gz, reuse=True)
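# keep an exponential moving average of every generator variable (weights and
# batch norm snapshots alike) and build a second, EMA-parameterized generator
# that is used only for sampling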
ema_params = find_variables('generator')
for p in ema_params:
    print(p)
ema = tf.train.ExponentialMovingAverage(decay=0.999)
avg_params = ema.apply(ema_params)
ema_params = [ema.average(p) for p in ema_params]
gz_ema = generator(Z, reuse=True, ema=ema)
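# standard non-saturating GAN losses: the generator is trained to make the
# discriminator label its samples as real (1), and the discriminator averages
# its loss over the real and generated batches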
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=dgz, labels=tf.ones((128, 1))))
dx_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=dx, labels=tf.ones((128, 1))))
dgz_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=dgz, labels=tf.zeros((128, 1))))
d_loss = dx_loss*0.5 + dgz_loss*0.5

d_params = find_trainable_variables('discriminator')
g_params = find_trainable_variables('generator')
for p in g_params:
    print(p.name)
for p in d_params:
    print(p.name)
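# gradient normalization: rescale every gradient tensor to unit L2 norm before
# handing it to Adam, so each parameter tensor takes a direction-only step
# regardless of the raw gradient magnitude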
d_grads = tf.gradients(d_loss, d_params)
g_grads = tf.gradients(g_loss, g_params)
for i in range(len(d_grads)):
    d_grads[i] = d_grads[i] / tf.norm(d_grads[i])
for i in range(len(g_grads)):
    g_grads[i] = g_grads[i] / tf.norm(g_grads[i])

d_trainer = tf.train.AdamOptimizer(learning_rate=0.0005, beta1=0.5)
g_trainer = tf.train.AdamOptimizer(learning_rate=0.0005, beta1=0.5)
d_train = d_trainer.apply_gradients(list(zip(d_grads, d_params)))
g_train = g_trainer.apply_gradients(list(zip(g_grads, g_params)))
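# only the generator's batch norm snapshots need to run during training; the
# discriminator always normalizes with live batch moments, so its queued
# assign ops are dropped here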
bn_updates = [bn_update for bn_update in bn_updates if 'generator' in bn_update.name]
bn_updates = tf.group(*bn_updates)
sample_zmb = np.random.randn(128, 100).astype(np.float32)
nepochs = 0
nupdates = 0
nseconds = 0
config = tf.ConfigProto(allow_soft_placement=True,
                        intra_op_parallelism_threads=4,
                        inter_op_parallelism_threads=4)
with tf.Session(config=config) as sess:
    tf.global_variables_initializer().run()
    samples = sess.run(gz, {Z: sample_zmb})
    print(samples.mean(), samples.std(), samples.min(), samples.max())
    img = color_grid((samples+1)/2., path='vis/%s/init.png' % desc)
    tstart = time.time()
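    # one discriminator update followed by one generator update per minibatch;
    # the generator step also refreshes the EMA copies and batch norm snapshots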
    for i in range(1000):
        for xmb in tqdm(iter_data(*shuffle(trX), size=128), total=ntrain//128, leave=False, ncols=80):
            if len(xmb) == 128:
                zmb = np.random.randn(128, 100).astype(np.float32)
                sess.run(d_train, {X: xmb/127.5-1., Z: zmb})
                zmb = np.random.randn(128, 100).astype(np.float32)
                sess.run([g_train, avg_params, bn_updates], {Z: zmb})
                nupdates += 1
        nseconds = (time.time()-tstart)
        zmb = np.random.randn(128, 100).astype(np.float32)
        xmb = trX[:128]
        print(i)
        samples = sess.run(gz, {Z: sample_zmb})
        img = color_grid((samples+1)/2., path='vis/%s/cur/%d.png' % (desc, i))
        samples = sess.run(gz_ema, {Z: sample_zmb})
        img = color_grid((samples+1)/2., path='vis/%s/ema/%d.png' % (desc, i))
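        # every 30 epochs, draw roughly 5000 samples from both the live and the
        # EMA generator, map them back to [0, 255], and log their Inception scores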
        test_sample = []
        test_sample_ema = []
        if (i+1) % 30 == 0:
            for t in range(test_batches):
                test_zmb = np.random.randn(128, 100).astype(np.float32)
                samples = sess.run(gz, {Z: test_zmb})
                samples_ema = sess.run(gz_ema, {Z: test_zmb})
                test_sample.append(samples)
                test_sample_ema.append(samples_ema)
            test_sample = np.concatenate(test_sample)
            test_sample_ema = np.concatenate(test_sample_ema)
            # index with j rather than i so the epoch counter is not shadowed
            test_sample = [127.5*(test_sample[j]+1.) for j in range(test_sample.shape[0])]
            test_sample_ema = [127.5*(test_sample_ema[j]+1.) for j in range(test_sample_ema.shape[0])]
            inception_score = get_inception_score(test_sample, splits=1)
            file.write('epoch %d inception score was %.6f\n' % (i, inception_score[0]))
            inception_score = get_inception_score(test_sample_ema, splits=1)
            file.write('epoch %d EMA inception score was %.6f\n\n' % (i, inception_score[0]))
            file.flush()
file.close()