Last active
June 21, 2017 09:59
-
-
Save scturtle/4025dd920284becbf5150bd527e04544 to your computer and use it in GitHub Desktop.
Using a GAN to learn a Gaussian distribution
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function | |
import numpy as np | |
from scipy.stats import norm | |
import matplotlib.pyplot as plt | |
import tensorflow as tf | |
import tensorflow.contrib.layers as tfl | |
import tensorflow.contrib.framework as tff | |
def minibatch(features, *, num_kernels=3, kernel_dim=3):
    """Append minibatch-discrimination features to *features*.

    Each sample is projected through a learned variable ``T`` into
    ``num_kernels`` vectors of length ``kernel_dim``. For every kernel, the
    sample's closeness to each sample of the batch is scored as
    ``exp(-mean |difference|)`` and those scores are averaged over the batch.
    Presumably this lets the discriminator see how similar the samples of a
    batch are (minibatch discrimination, Salimans et al. 2016) — the code
    itself only computes the statistic.

    Args:
        features: rank-2 tensor ``(batch, size_a)`` (``tf.matmul`` with the
            rank-2 ``T`` requires rank 2); ``size_a`` must be statically known
            because it sizes ``T``.
        num_kernels: number of projected kernels, i.e. number of columns
            appended to *features*.
        kernel_dim: length of each projected kernel vector.

    Returns:
        *features* concatenated along axis 1 with a ``(batch, num_kernels)``
        tensor of similarity scores.

    Creates (or, under a reusing variable scope, reuses) a variable named 'T'
    in the current variable scope.
    """
    size_a = features.get_shape().as_list()[1]
    t = tf.get_variable('T', (size_a, num_kernels * kernel_dim),
                        initializer=tf.random_normal_initializer())
    # (batch, num_kernels, kernel_dim)
    m = tf.reshape(tf.matmul(features, t), (-1, num_kernels, kernel_dim))
    # Pairwise differences, shape (batch, num_kernels, kernel_dim, batch):
    # diff[i, b, c, j] = m[i, b, c] - m[j, b, c].
    diff = tf.expand_dims(m, 3) - tf.expand_dims(tf.transpose(m, (1, 2, 0)), 0)
    # Mean absolute difference over kernel_dim -> (batch, num_kernels, batch).
    l1 = tf.reduce_mean(tf.abs(diff), 2)
    # Average similarity to every sample of the batch, including the sample
    # itself (which contributes exp(0) = 1) -> (batch, num_kernels).
    o = tf.reduce_mean(tf.exp(-l1), 2)
    return tf.concat(values=[features, o], axis=1)
def network(inputs, *, use_minibatch=False, last_fn=tf.nn.sigmoid):
    """Build a small MLP: two 10-unit tanh layers and a one-unit output layer.

    The layers are created under the scopes 'h0', 'h1' and 'h2' of the
    current variable scope, so a second call inside a reusing scope shares
    the same weights. Every layer's weights start from one shared
    ``tf.random_normal_initializer()`` and every bias starts at zero.

    Args:
        inputs: tensor fed to the first fully connected layer.
        use_minibatch: when true, ``minibatch`` features are appended to the
            second hidden layer before the output layer.
        last_fn: activation of the output layer (``None`` for a linear
            output).

    Returns:
        The output tensor of the 'h2' layer.
    """
    layer_kwargs = dict(
        weights_initializer=tf.random_normal_initializer(),
        biases_initializer=tf.constant_initializer(0.),
    )

    hidden = inputs
    for layer_name in ('h0', 'h1'):
        hidden = tfl.fully_connected(
            hidden, 10, scope=layer_name, activation_fn=tf.nn.tanh,
            **layer_kwargs)

    if use_minibatch:
        hidden = minibatch(hidden)

    return tfl.fully_connected(
        hidden, 1, scope='h2', activation_fn=last_fn, **layer_kwargs)
def optimizer(loss, lr, var_list, *, decay_steps=1000, decay_rate=0.95,
              momentum=0.5):
    """Return a momentum-SGD op that minimizes *loss* over *var_list*.

    The learning rate starts at *lr* and is multiplied by *decay_rate* every
    *decay_steps* updates of THIS optimizer (staircase schedule).

    Each call owns a private, non-trainable step counter. Sharing the
    graph-wide global step between several optimizers would make every
    optimizer's update advance all the others' decay schedules (e.g. a
    discriminator pre-training phase would already decay the learning rates
    used later).

    Args:
        loss: scalar loss tensor to minimize.
        lr: initial learning rate.
        var_list: variables to update.
        decay_steps: number of updates between learning-rate decays.
        decay_rate: multiplicative decay applied every *decay_steps* updates.
        momentum: momentum coefficient of ``tf.train.MomentumOptimizer``.

    Returns:
        The training op; running it applies one update and increments this
        optimizer's step counter.
    """
    # Non-trainable so it never lands in TRAINABLE_VARIABLES collections.
    step = tf.Variable(0, trainable=False, dtype=tf.int64, name='step')
    lr_gen = tf.train.exponential_decay(lr, step, decay_steps, decay_rate,
                                        staircase=True)
    return tf.train.MomentumOptimizer(lr_gen, momentum).minimize(
        loss, global_step=step, var_list=var_list)
def _sorted_column(samples):
    """Return the 1-D array *samples* sorted ascending, as an (n, 1) column."""
    return np.sort(samples).reshape((-1, 1))


def _plot_histograms(sess, G, z_input, mu, sigma, num_samples):
    """Redraw the current axes with histograms of real and generated samples.

    Draws *num_samples* values from N(mu, sigma) and *num_samples* generator
    outputs for noise uniform on [0, 4). The generator network is built
    without a minibatch layer, so each output depends only on its own noise
    value and all samples can be produced by a single ``sess.run``.
    """
    z = np.random.uniform(0, 4, (num_samples, 1))
    gxs = sess.run(G, {z_input: z}).ravel()
    xs = np.random.normal(mu, sigma, num_samples)
    plt.cla()
    plt.hist(xs, bins=100, alpha=0.8, label='real')
    plt.hist(gxs, bins=100, alpha=0.8, label='gan')
    plt.legend()
    plt.pause(0.1)


def main():
    """Train a 1-D GAN whose generator learns to sample N(mu=2, sigma=0.5).

    Builds the generator G (noise uniform on [0, 4) -> scalar sample) and a
    discriminator D with minibatch discrimination. It then optionally
    pre-trains D by regressing its output onto the target pdf. Finally it
    alternates one D update and one G update per iteration for 3000
    iterations, redrawing real vs. generated histograms every 100 iterations.
    Blocks on matplotlib windows after pre-training and at the end.
    """
    mu = 2
    sigma = 0.5

    with tf.variable_scope('G'):
        z_input = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        G = network(z_input, last_fn=None)

    with tf.variable_scope('D') as scope:
        x_input = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        use_minibatch = True
        # D is built on logits so the GAN losses below can use the stable
        # softplus form; D1 is the sigmoid probability used for pre-training
        # and plotting (identical to a sigmoid output layer).
        D1_logits = network(x_input, use_minibatch=use_minibatch,
                            last_fn=None)
        scope.reuse_variables()
        D2_logits = network(G, use_minibatch=use_minibatch, last_fn=None)
        D1 = tf.nn.sigmoid(D1_logits)

    x_label = tf.placeholder(shape=(None, 1), dtype=tf.float32)
    # Pre-training loss: squared error between D(x) and the target pdf.
    loss_dp = tf.reduce_mean(tf.square(D1 - x_label))
    # Same objectives as -log(D1) - log(1 - D2) and -log(D2) with
    # D = sigmoid(logits), rewritten via -log(sigmoid(l)) = softplus(-l) and
    # -log(1 - sigmoid(l)) = softplus(l). Taking log of a float32 sigmoid
    # gives log(0) = -inf (NaN gradients) as soon as D saturates.
    loss_d = tf.reduce_mean(tf.nn.softplus(-D1_logits) +
                            tf.nn.softplus(D2_logits))
    loss_g = tf.reduce_mean(tf.nn.softplus(-D2_logits))

    var_d = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')
    opt_dp = optimizer(loss_dp, 0.3, var_d)
    opt_d = optimizer(loss_d, 0.03, var_d)
    var_g = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')
    opt_g = optimizer(loss_g, 0.01, var_g)

    with tf.Session() as sess:
        tf.global_variables_initializer().run()

        pre_train = True
        if pre_train:
            # Supervised warm-up: fit D's output to the target pdf on [-2, 6).
            batch_size = 128
            for step in range(2000):
                x = _sorted_column(np.random.uniform(-2, 6, batch_size))
                y = norm.pdf(x, loc=mu, scale=sigma)
                ldp, _ = sess.run([loss_dp, opt_dp],
                                  {x_input: x, x_label: y})
                print('[D_pre] step: {} loss: {}'.format(step, ldp))
            tx = np.linspace(0, 4, 200)
            ty = norm.pdf(tx, loc=mu, scale=sigma)
            pred = sess.run(D1, {x_input: tx.reshape((-1, 1))}).ravel()
            plt.plot(tx, ty, label='real')
            plt.plot(tx, pred, label='pred')
            plt.legend()
            plt.show()

        plt.ion()
        batch_size = 32
        for step in range(3000):
            x = _sorted_column(np.random.normal(mu, sigma, batch_size))
            z = _sorted_column(np.random.uniform(0, 4, batch_size))
            # One discriminator update, then one generator update on the
            # same noise batch.
            ld, _ = sess.run([loss_d, opt_d], {x_input: x, z_input: z})
            lg, _ = sess.run([loss_g, opt_g], {z_input: z})
            print('[GAN] step: {} lg: {} ld: {}'.format(step, lg, ld))
            if step and step % 100 == 0:
                _plot_histograms(sess, G, z_input, mu, sigma,
                                 3000 * batch_size)
        plt.ioff()
        plt.show()


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment