Skip to content

Instantly share code, notes, and snippets.

@mlzxy
Created November 25, 2016 19:28
Show Gist options
  • Save mlzxy/d8c843863873b2d14eb22a263cfef847 to your computer and use it in GitHub Desktop.
gan.py
from util import *
from tf_util import *
from matplotlib import pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
# %matplotlib inline
"""
Parameter Sections
"""
# Experiment identifier — used for the logger, the checkpoint path, and plot titles.
model_name = 'basic_gan_mnist'
log = Log(model_name)  # project logger from util — TODO confirm call signature
model_path = make_model_path(model_name)  # directory for model checkpoints
# Directory where generated-image snapshots are written during training.
# NOTE(review): `data__` is presumably the project data root — verify in util.
images_path = join(data__, 'mnist_data/gen_gan_images')
mkdir(images_path)
n_step = 20000     # total number of training iterations
test_freq = 100    # run test_fun every this many iterations
batch_size = 32*3  # 96 samples per batch
z_dim_list = [8, 8, 1]         # latent code is reshaped to an 8x8x1 map in g_net
z_dim = list_prod(z_dim_list)  # flattened latent dimensionality (8*8*1 = 64)
lr = 0.0002  # Adam learning rate (the usual DCGAN setting)
def g_net(z):
    """Generator: map a flat latent code ``z`` to a 28x28x1 image in [0, 1]."""
    # Arrange the latent vector as an 8x8x1 feature map (shape from z_dim_list).
    out = tf_reshape_for_conv(z, z_dim_list)
    out = tf_conv2d(out, 128, k_w=3, k_h=3, name="g_conv_1",
                    activation=tf_relu, bn=True)
    # Progressively upsample 8x8 -> 14 -> 20 -> 25 -> 28 with transposed convs.
    upsample_stack = [
        ([14, 14, 64], 7, "g_deconv2d_2"),
        ([20, 20, 32], 7, "g_deconv2d_3"),
        ([25, 25, 16], 6, "g_deconv2d_4"),
        ([28, 28, 8], 4, "g_deconv2d_5"),
    ]
    for out_shape, kernel, layer_name in upsample_stack:
        out = tf_deconv2d(out, out_shape, k_w=kernel, k_h=kernel,
                          name=layer_name, activation=tf_relu, bn=True)
    # 1x1 conv + sigmoid collapses channels to a single image plane in [0, 1].
    return tf_conv2d(out, 1, k_w=1, k_h=1, name="g_conv2d_6",
                     activation=tf_sigmoid)
def d_net(input_var):
    """Discriminator: conv/pool stack reducing an image to one real/fake logit.

    Returns a tuple ``(logit, sigmoid(logit))``.
    """
    out = tf_conv2d(input_var, 128, k_w=3, k_h=3, name="d_conv_1",
                    activation=tf_leaky_relu(0.2), bn=True)
    # Max pooling builds classification abstraction — it appears to be
    # the key to making this work. (translated from original comment)
    out = tf_max_pool2d(out, d_h=2, d_w=2, name="pool_1")
    out = tf_dropout(out, 0.5)
    out = tf_conv2d(out, 64, k_w=3, k_h=3, name="d_conv_2",
                    activation=tf_leaky_relu(0.2), bn=True)
    out = tf_max_pool2d(out, d_h=2, d_w=2, name="pool_2")
    out = tf_dropout(out, 0.5)
    out = tf_conv2d(out, 32, k_w=3, k_h=3, name="d_conv_3",
                    activation=tf_leaky_relu(0.2), bn=True)
    out = tf_max_pool2d(out, d_h=2, d_w=2, name="pool_3")
    out = tf_conv2d(out, 16, k_w=3, k_h=3, name="d_conv_4",
                    activation=tf_leaky_relu(0.2), bn=True)
    out = tf_flatten_for_dense(out)
    out = tf_dropout(out, 0.5)
    logit = tf_dense(out, 1, name="dense_1")
    return logit, tf_sigmoid(logit)
def test_fun(model=None, n_iter=0, losses=None):
    """Periodic evaluation hook: save a loss plot and a grid of generated digits.

    Args:
        model: trained GanModel; must provide ``generate_data(batch_size=...)``.
        n_iter: current training iteration, used to name the output image files.
        losses: sequence of loss values collected so far (treated as empty
            when omitted).
    """
    # FIX: the original signature used `losses=[]` — a mutable default is
    # shared across calls, so losses would silently accumulate between
    # invocations. Use the None sentinel instead.
    if losses is None:
        losses = []
    plot_loss(losses, title=model_name+"_loss",
              save_to=join(images_path, '{0}_loss.png'.format(n_iter)))
    gen_images = model.generate_data(batch_size=16)
    plt.figure(figsize=(12, 12))
    dim = (8, 8)  # subplot grid; only the first 16 cells are filled
    for i, image in enumerate(gen_images):
        plt.subplot(dim[0], dim[1], i + 1)
        plt.imshow(image.reshape((28, 28)))
        plt.axis('off')
    plt.tight_layout()
    plt.savefig(join(images_path, '{0}_gen.png'.format(n_iter)))
    plt.show()
# Training script: build the GAN graph inside a TF1 session and run training.
with tf.Session() as session:
    log('Loading Data')
    # MNIST via the (deprecated) tf.examples tutorial loader; one-hot labels.
    mnist = input_data.read_data_sets(join(data__, 'mnist_data'), one_hot=True)
    log('Finished!')
    log('Building Graph')
    # GanModel is a project class (presumably from tf_util) — TODO confirm API.
    # It is handed both sub-network builders, a factory for the optimizer
    # (beta1=0.5 is the usual DCGAN Adam setting), and the MNIST batch source.
    gan = GanModel(input_dim=(28, 28, 1), z_dim=z_dim,
                   g_net=g_net,
                   d_net=d_net,
                   optimizer=lambda: tf.train.AdamOptimizer(lr, beta1=0.5),
                   name=model_name,
                   session=session,
                   next_batch=mnist.train.next_batch, reuse=False,
                   batch_size=batch_size, log=log, test_freq=test_freq, test_fun=test_fun)
    log('Finished!')
    # NOTE(review): log=None here disables per-step logging even though a log
    # object is configured above — confirm this is intentional.
    loss_list = gan.train(log=None)
    ok()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment