MNIST_GAN #GAN #python
import os                # File-system utilities
import shutil            # Recursive removal of directory trees
import tensorflow as tf  # Neural network library (TensorFlow 1.x)
import numpy as np       # Matrix computation
from skimage.io import imsave                              # Image output
from tensorflow.examples.tutorials.mnist import input_data

# Download MNIST into ./MNIST_data/ if it is not already there
data = input_data.read_data_sets('MNIST_data/')
# The size of each image is (28, 28, 1)
image_height = 28
image_width = 28
image_size = image_height * image_width

# Train / restore / output settings
to_train = True          # Renamed from `train` so the flag is not shadowed by the train() function below
restore = False
output_path = "./output/"

# Hyper-parameters
max_epoch = 500
batch_size = 256
z_size = 128   # Size of the noise input
h1_size = 256  # Hidden layer 1 size
h2_size = 512  # Hidden layer 2 size
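
# With these sizes both networks are plain fully connected MLPs
# (shapes read off the generator() and discriminator() definitions below):
#   generator:      z (128) -> 256 -> 512 -> 784, tanh output in [-1, 1]
#   discriminator:  image (784) -> 512 -> 256 -> 1, sigmoid output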
def load_data(data_path):
    """
    Load the raw MNIST training set from IDX files.
    :param data_path: directory containing the MNIST IDX files
    :return train_data: (60000, 784) float array of flattened images
    :return train_label: (60000,) float array of digit labels
    """
    # The IDX image file has a 16-byte header before the pixel data
    with open(os.path.join(data_path, 'train-images.idx3-ubyte'), 'rb') as f_data:
        loaded_data = np.fromfile(file=f_data, dtype=np.uint8)
    train_data = loaded_data[16:].reshape((-1, 784)).astype(np.float32)
    # The IDX label file has an 8-byte header before the labels
    with open(os.path.join(data_path, 'train-labels.idx1-ubyte'), 'rb') as f_label:
        loaded_label = np.fromfile(file=f_label, dtype=np.uint8)
    train_label = loaded_label[8:].reshape((-1,)).astype(np.float32)
    return train_data, train_label
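
# Quick sanity check of load_data (a minimal sketch, assuming the uncompressed
# IDX files train-images.idx3-ubyte and train-labels.idx1-ubyte sit in ./MNIST_data):
#
#   imgs, labels = load_data("MNIST_data")
#   print(imgs.shape, labels.shape)   # expected: (60000, 784) (60000,)
#   print(imgs.min(), imgs.max())     # raw pixel range: 0.0 .. 255.0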
def generator(z_prior):
    """
    Generate images from noise.
    :param z_prior: random noise input of shape (batch_size, z_size)
    :return x_generate: the generated images, shape (batch_size, image_size)
    :return g_params: all parameters of the generator
    """
    # The first hidden layer
    w1 = tf.Variable(tf.truncated_normal([z_size, h1_size], stddev=0.1), name="g_w1", dtype=tf.float32)
    b1 = tf.Variable(tf.zeros([h1_size]), name="g_b1", dtype=tf.float32)
    h1 = tf.nn.relu(tf.matmul(z_prior, w1) + b1)
    # The second hidden layer
    w2 = tf.Variable(tf.truncated_normal([h1_size, h2_size], stddev=0.1), name="g_w2", dtype=tf.float32)
    b2 = tf.Variable(tf.zeros([h2_size]), name="g_b2", dtype=tf.float32)
    h2 = tf.nn.relu(tf.matmul(h1, w2) + b2)
    # The output layer: tanh keeps the generated pixels in [-1, 1],
    # matching the rescaling applied to the real images in train()
    w3 = tf.Variable(tf.truncated_normal([h2_size, image_size], stddev=0.1), name="g_w3", dtype=tf.float32)
    b3 = tf.Variable(tf.zeros([image_size]), name="g_b3", dtype=tf.float32)
    x_generate = tf.nn.tanh(tf.matmul(h2, w3) + b3)
    # All parameters of the generator
    g_params = [w1, b1, w2, b2, w3, b3]
    return x_generate, g_params
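
# Standalone sketch of the generator (hypothetical usage, not part of train()):
# build it once on a constant noise batch and inspect the output shape.
#
#   z = tf.constant(np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32))
#   fake_images, g_vars = generator(z)   # fake_images: (256, 784), values in [-1, 1]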
# Define the GAN discriminator
def discriminator(x_data, x_generated, keep_prob):
    """
    Score images as real or generated.
    :param x_data: real images, shape (batch_size, image_size)
    :param x_generated: generated images
    :param keep_prob: dropout keep probability
    :return y_data: scores for the real images
    :return y_generated: scores for the generated images
    :return d_params: all parameters of the discriminator
    """
    # Stack the real and the generated images into one batch
    x_in = tf.concat([x_data, x_generated], 0)
    # The first hidden layer
    w1 = tf.Variable(tf.truncated_normal([image_size, h2_size], stddev=0.1), name="d_w1", dtype=tf.float32)
    b1 = tf.Variable(tf.zeros([h2_size]), name="d_b1", dtype=tf.float32)
    h1 = tf.nn.dropout(tf.nn.relu(tf.matmul(x_in, w1) + b1), keep_prob)
    # The second hidden layer
    w2 = tf.Variable(tf.truncated_normal([h2_size, h1_size], stddev=0.1), name="d_w2", dtype=tf.float32)
    b2 = tf.Variable(tf.zeros([h1_size]), name="d_b2", dtype=tf.float32)
    h2 = tf.nn.dropout(tf.nn.relu(tf.matmul(h1, w2) + b2), keep_prob)
    # The output layer: a single logit per image
    w3 = tf.Variable(tf.truncated_normal([h1_size, 1], stddev=0.1), name="d_w3", dtype=tf.float32)
    b3 = tf.Variable(tf.zeros([1]), name="d_b3", dtype=tf.float32)
    h3 = tf.matmul(h2, w3) + b3
    # The first batch_size rows correspond to the real images
    y_data = tf.nn.sigmoid(tf.slice(h3, [0, 0], [batch_size, -1], name=None))
    # The remaining rows correspond to the generated images
    y_generated = tf.nn.sigmoid(tf.slice(h3, [batch_size, 0], [-1, -1], name=None))
    # All parameters of the discriminator
    d_params = [w1, b1, w2, b2, w3, b3]
    return y_data, y_generated, d_params
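
# These two outputs feed the GAN objectives built in train() below (with the
# non-saturating generator loss), written out here for reference:
#   d_loss = -( log D(x) + log(1 - D(G(z))) )
#   g_loss = -  log D(G(z))
# where D(.) is the sigmoid output of this discriminator and G(z) = x_generated.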
def show_result(batch_result, fname, grid_size=(8, 8), grid_pad=5):
    """
    Tile a batch of generated images into one grid image and save it.
    :param batch_result: batch of generator outputs, values in [-1, 1]
    :param fname: output file path
    :param grid_size: grid layout (default 8*8 images)
    :param grid_pad: padding between images (default 5 pixels)
    :return: None
    """
    # Rescale from [-1, 1] to [0, 1] and reshape to (batch_size, image_height, image_width)
    batch_res = 0.5 * batch_result.reshape((batch_result.shape[0], image_height, image_width)) + 0.5
    img_h, img_w = batch_res.shape[1], batch_res.shape[2]
    grid_h = img_h * grid_size[0] + grid_pad * (grid_size[0] - 1)
    grid_w = img_w * grid_size[1] + grid_pad * (grid_size[1] - 1)
    img_grid = np.zeros((grid_h, grid_w), dtype=np.uint8)
    for i, res in enumerate(batch_res):
        if i >= grid_size[0] * grid_size[1]:
            break
        img = res * 255.
        img = img.astype(np.uint8)
        # Images fill the grid row by row: grid_size[1] is the number of columns
        row = (i // grid_size[1]) * (img_h + grid_pad)
        col = (i % grid_size[1]) * (img_w + grid_pad)
        img_grid[row:row + img_h, col:col + img_w] = img
    imsave(fname, img_grid)
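
# Usage sketch (hypothetical file name; x_gen_val as produced in train() below):
#
#   show_result(x_gen_val, os.path.join(output_path, "preview.jpg"))
#
# writes an 8x8 grid of 28x28 digits with 5-pixel gaps to ./output/preview.jpg.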
# Training procedure
def train():
    """
    Train the whole GAN and periodically generate handwritten digits.
    Input: none
    Output: checkpoints saved via tf.train.Saver
    """
    # Load the data
    train_data, train_label = load_data("MNIST_data")
    size = train_data.shape[0]
    # Build the model ---------------------------------------------------------------
    # Inputs of the GAN: x_data is [batch_size, image_size], z_prior is [batch_size, z_size]
    x_data = tf.placeholder(tf.float32, [batch_size, image_size], name="x_data")
    z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")
    # Dropout keep probability
    keep_prob = tf.placeholder(tf.float32, name="keep_prob")
    global_step = tf.Variable(0, name="global_step", trainable=False)
    # Generate images x_generated and collect the generator parameters g_params
    x_generated, g_params = generator(z_prior)
    # Score both the real data and the generator's output with the discriminator
    y_data, y_generated, d_params = discriminator(x_data, x_generated, keep_prob)
    # Loss functions of the discriminator and the generator
    # (reduced to scalars with tf.reduce_mean so each loss is a per-batch average)
    d_loss = - tf.reduce_mean(tf.log(y_data) + tf.log(1 - y_generated))
    g_loss = - tf.reduce_mean(tf.log(y_generated))
    # Optimise with Adam at a learning rate of 0.0001
    optimizer = tf.train.AdamOptimizer(0.0001)
    # Minimise each loss with respect to its own network's parameters only
    d_trainer = optimizer.minimize(d_loss, var_list=d_params)
    g_trainer = optimizer.minimize(g_loss, var_list=g_params)
    # Model building finished --------------------------------------------------------
    # Initialise the global variables
    init = tf.global_variables_initializer()
    # Start the session
    saver = tf.train.Saver()
    sess = tf.Session()
    sess.run(init)
    # Check whether to restore a previous run
    if restore:
        # If so, restore the latest checkpoint found under output_path
        chkpt_fname = tf.train.latest_checkpoint(output_path)
        saver.restore(sess, chkpt_fname)
    else:
        # If not, wipe the output directory (if it exists) and recreate it
        if os.path.exists(output_path):
            shutil.rmtree(output_path)
        os.mkdir(output_path)
    # A fixed noise sample of shape (batch_size, z_size), reused every epoch to track progress
    z_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
    # Train epoch by epoch
    for i in range(max_epoch):
        # Each epoch holds (size // batch_size) full batches
        for j in range(size // batch_size):
            if j % 50 == 0:
                print("epoch:%s, iter:%s" % (i, j))
            # Take one batch of real images; the loop bound guarantees batch_end <= size,
            # so every batch has exactly batch_size rows (required by the fixed placeholder shape)
            batch_end = j * batch_size + batch_size
            x_value = train_data[j * batch_size: batch_end]
            # Rescale the pixel values to [-1, 1]
            x_value = x_value / 255.
            x_value = 2 * x_value - 1
            # Sample random noise from a standard normal distribution
            z_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
            # Train the discriminator on this batch
            # (j % 1 is always true; raise the modulus to train D less often)
            if j % 1 == 0:
                sess.run(d_trainer,
                         feed_dict={x_data: x_value, z_prior: z_value, keep_prob: 0.7})
            # Train the generator on this batch
            if j % 1 == 0:
                sess.run(g_trainer,
                         feed_dict={x_data: x_value, z_prior: z_value, keep_prob: 0.7})
        # After all batches of the epoch, run the generator on the fixed noise sample
        x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_sample_val})
        # Save the grid of generated digits and print the raw generator output
        show_result(x_gen_val, os.path.join(output_path, "sample%s.jpg" % i))
        print(x_gen_val)
        # Draw a fresh random noise sample for this epoch
        z_random_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
        # Generate digits from the fresh noise and save that grid as well
        x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_random_sample_val})
        show_result(x_gen_val, os.path.join(output_path, "random_sample%s.jpg" % i))
        # Save a checkpoint for this epoch
        sess.run(tf.assign(global_step, i + 1))
        saver.save(sess, os.path.join(output_path, "model"), global_step=global_step)
if __name__ == '__main__':
    if to_train:
        train()
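
# This script only trains. A minimal sampling-only sketch (hypothetical, not part
# of the original gist) would rebuild the generator, restore its variables from
# the latest checkpoint under output_path, and tile the output with show_result:
#
#   z = tf.placeholder(tf.float32, [batch_size, z_size])
#   x_fake, g_params = generator(z)
#   saver = tf.train.Saver(var_list=g_params)
#   with tf.Session() as sess:
#       saver.restore(sess, tf.train.latest_checkpoint(output_path))
#       noise = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
#       show_result(sess.run(x_fake, feed_dict={z: noise}), "sampled.jpg")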