|
# |
|
# mnist_ae2.py date. 7/4/2016 |
|
# |
|
# Autoencoder tutorial code - trial of convolutional AE |
|
# |
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import numpy as np |
|
import matplotlib as mpl |
|
mpl.use('Agg') |
|
import matplotlib.pyplot as plt |
|
import tensorflow as tf |
|
|
|
from tensorflow.examples.tutorials.mnist import input_data |
|
from my_nn_lib import Convolution2D, MaxPooling2D |
|
from my_nn_lib import FullConnected, ReadOutLayer |
|
|
|
# Up-sampling 2-D layer (deconvolutional layer)
|
class Conv2Dtranspose(object):
    '''
    Up-sampling layer built on tf.nn.conv2d_transpose (stride 2, 'SAME'
    padding), followed by a bias add and an optional activation.

    constructor's args:
        input      : input tensor handed to tf.nn.conv2d_transpose; the
                     batch size is read from its first dimension
        output_siz : output image size as (rows, cols)
        in_ch      : number of incoming image channels
        out_ch     : number of outgoing image channels
        patch_siz  : filter (patch) size as (height, width)
        activation : 'relu', 'sigmoid', or any other value for a linear
                     (no activation) output
    '''
    def __init__(self, input, output_siz, in_ch, out_ch, patch_siz, activation='relu'):
        self.input = input
        self.rows = output_siz[0]
        self.cols = output_siz[1]
        self.out_ch = out_ch
        self.activation = activation

        # Note the argument order: for the transposed convolution the filter
        # is laid out as [height, width, out_ch, in_ch].
        wshape = [patch_siz[0], patch_siz[1], out_ch, in_ch]

        w_cvt = tf.Variable(tf.truncated_normal(wshape, stddev=0.1),
                            trainable=True)
        b_cvt = tf.Variable(tf.constant(0.1, shape=[out_ch]),
                            trainable=True)
        # Dynamic batch size, resolved when the graph is run.
        self.batsiz = tf.shape(input)[0]
        self.w = w_cvt
        self.b = b_cvt
        self.params = [self.w, self.b]

        # Output tensor is built lazily by output() and cached here.
        self._cached_output = None

    def output(self):
        '''
        Build (once) and return the layer's output tensor.

        The result is cached in a private attribute rather than assigned to
        self.output: assigning to self.output would replace this method
        with a tensor on the instance and make any later output() call fail.
        '''
        if self._cached_output is not None:
            return self._cached_output

        shape4D = [self.batsiz, self.rows, self.cols, self.out_ch]
        linout = tf.nn.conv2d_transpose(self.input, self.w, output_shape=shape4D,
                                        strides=[1, 2, 2, 1], padding='SAME') + self.b

        if self.activation == 'relu':
            result = tf.nn.relu(linout)
        elif self.activation == 'sigmoid':
            result = tf.sigmoid(linout)
        else:
            # Any other activation name means "no non-linearity".
            result = linout

        self._cached_output = result
        return result
|
|
|
# Create the model |
|
def model(X, w_e, b_e, w_d, b_d):
    '''
    Dense autoencoder: one sigmoid encoding layer and one sigmoid
    decoding layer.

    args:
        X   : input passed to the encoder
        w_e : encoder weights (right operand of the first matmul)
        b_e : encoder bias
        w_d : decoder weights (right operand of the second matmul)
        b_d : decoder bias
    returns:
        (encoded, decoded) : the hidden code and its reconstruction
    '''
    enc_preact = tf.matmul(X, w_e) + b_e
    encoded = tf.sigmoid(enc_preact)

    dec_preact = tf.matmul(encoded, w_d) + b_d
    decoded = tf.sigmoid(dec_preact)

    return encoded, decoded
|
|
|
def mk_nn_model(x, y_):
    '''
    Build the convolutional autoencoder graph and its reconstruction loss.

    args:
        x  : flattened input images; reshaped to [-1, 28, 28, 1] internally
        y_ : label placeholder; not used by the autoencoder, kept so the
             call signature stays unchanged
    returns:
        loss    : mean per-pixel binary cross-entropy between x and the
                  reconstruction
        decoded : reconstruction reshaped to [-1, 784] (not clipped)
    '''
    # Encoding phase
    x_image = tf.reshape(x, [-1, 28, 28, 1])
    conv1 = Convolution2D(x_image, (28, 28), 1, 16,
                          (3, 3), activation='relu')
    conv1_out = conv1.output()

    pool1 = MaxPooling2D(conv1_out)
    pool1_out = pool1.output()

    conv2 = Convolution2D(pool1_out, (14, 14), 16, 8,
                          (3, 3), activation='relu')
    conv2_out = conv2.output()

    pool2 = MaxPooling2D(conv2_out)
    pool2_out = pool2.output()

    conv3 = Convolution2D(pool2_out, (7, 7), 8, 8, (3, 3), activation='relu')
    conv3_out = conv3.output()

    pool3 = MaxPooling2D(conv3_out)
    pool3_out = pool3.output()
    # at this point the representation is (8, 4, 4) i.e. 128-dimensional
    # NOTE(review): this assumes my_nn_lib.MaxPooling2D halves each spatial
    # dimension with 'SAME' padding (28 -> 14 -> 7 -> 4) -- confirm there.

    # Decoding phase: each transposed convolution up-samples with stride 2
    conv_t1 = Conv2Dtranspose(pool3_out, (7, 7), 8, 8,
                              (3, 3), activation='relu')
    conv_t1_out = conv_t1.output()

    conv_t2 = Conv2Dtranspose(conv_t1_out, (14, 14), 8, 8,
                              (3, 3), activation='relu')
    conv_t2_out = conv_t2.output()

    conv_t3 = Conv2Dtranspose(conv_t2_out, (28, 28), 8, 16,
                              (3, 3), activation='relu')
    conv_t3_out = conv_t3.output()

    conv_last = Convolution2D(conv_t3_out, (28, 28), 16, 1, (3, 3),
                              activation='sigmoid')
    decoded = conv_last.output()

    decoded = tf.reshape(decoded, [-1, 784])

    # A saturated sigmoid can return exactly 0 or 1 in float32, which makes
    # log() yield -inf and the loss inf/NaN. Clip only the value used inside
    # the loss; the returned reconstruction stays untouched.
    eps = 1e-7
    decoded_safe = tf.clip_by_value(decoded, eps, 1. - eps)
    cross_entropy = (-1. * x * tf.log(decoded_safe)
                     - (1. - x) * tf.log(1. - decoded_safe))
    loss = tf.reduce_mean(cross_entropy)

    return loss, decoded
|
|
|
|
|
if __name__ == '__main__':
    # Train the convolutional autoencoder on MNIST, report the test loss and
    # save a figure comparing original digits with their reconstructions.
    mnist = input_data.read_data_sets("../MNIST_data/", one_hot=True)

    # Placeholders: flattened images and one-hot labels
    x = tf.placeholder(tf.float32, [None, 784])
    y_ = tf.placeholder(tf.float32, [None, 10])

    loss, decoded = mk_nn_model(x, y_)
    train_step = tf.train.AdagradOptimizer(0.1).minimize(loss)

    init = tf.initialize_all_variables()

    # Train
    with tf.Session() as sess:
        sess.run(init)
        print('Training...')
        for i in range(10001):
            batch_xs, batch_ys = mnist.train.next_batch(128)
            train_fd = {x: batch_xs, y_: batch_ys}
            train_step.run(train_fd)
            if i % 1000 == 0:
                # Loss on the batch just trained on (after the update step).
                train_loss = loss.eval(train_fd)
                print(' step, loss = %6d: %6.3f' % (i, train_loss))

        # Generate decoded images with test data. Fetch the reconstruction
        # and the loss in one run so the test set goes through the network
        # only once.
        test_fd = {x: mnist.test.images, y_: mnist.test.labels}
        decoded_imgs, test_loss = sess.run([decoded, loss], feed_dict=test_fd)
        print('loss (test) = ', test_loss)

    x_test = mnist.test.images
    n = 10  # how many digits we will display
    plt.figure(figsize=(20, 4))
    for i in range(n):
        # Row 0: original digit, row 1: its reconstruction.
        for row, img in enumerate((x_test[i], decoded_imgs[i])):
            ax = plt.subplot(2, n, i + 1 + row * n)
            ax.imshow(img.reshape(28, 28), cmap='gray')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

    # The 'Agg' backend selected at import time is non-interactive, so the
    # figure is written to disk instead of being shown.
    plt.savefig('mnist_ae2.png')
|
|
Hi Kajiyu,
I believe I implemented the unpooling step with tf.nn.conv2d_transpose(), following the information in this Stack Overflow Q&A:
http://stackoverflow.com/questions/37926562/tensorflow-conv2d-transpose-deconv-number-of-rows-of-out-backprop-doesnt-matc/
At the bottom of that Q&A page, you can find a link to a very instructive slide deck.
Tomo