[Python] Autoencoder with chainer [Relu, 1000units, Dropout:activate]
# The code on this page is based on the following Chainer example. Thanks!
# https://github.com/pfnet/chainer/tree/master/examples/mnist/train_mnist.py
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import fetch_mldata
from chainer import cuda, Variable, FunctionSet, optimizers
import chainer.functions as F
import sys, time, math

plt.style.use('ggplot')
# draw an image of a handwritten digit
def draw_digit_ae(data, n, row, col, _type):
    size = 28
    plt.subplot(row, col, n)
    Z = data.reshape(size, size)   # convert from vector to 28x28 matrix
    Z = Z[::-1, :]                 # flip vertically
    plt.xlim(0, 28)
    plt.ylim(0, 28)
    plt.pcolor(Z)
    plt.title("type=%s" % (_type), size=8)
    plt.gray()
    plt.tick_params(labelbottom="off")
    plt.tick_params(labelleft="off")
# minibatch size for one step of stochastic gradient descent
batchsize = 100

# number of training epochs
n_epoch = 30

# number of hidden units
n_units = 1000

# whether to add noise to the inputs (denoising autoencoder)
noised = False
# Download the MNIST handwritten digit data
# (cached at $HOME/scikit_learn_data/mldata/mnist-original.mat)
print 'fetch MNIST dataset'
mnist = fetch_mldata('MNIST original')
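# --- note added for this writeup (not part of the original gist) ---
# fetch_mldata has been removed from recent scikit-learn releases. On newer
# versions the same data can be loaded with fetch_openml instead, e.g.:
#   from sklearn.datasets import fetch_openml
#   mnist = fetch_openml('mnist_784', version=1)
# (depending on the scikit-learn version, as_frame=False may be needed so
# that mnist.data comes back as a NumPy array). The rest of the script only
# uses mnist.data and mnist.target.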
# mnist.data : 70,000 samples of 784-dimensional vectors
mnist.data = mnist.data.astype(np.float32)
mnist.data /= 255     # normalize pixel values to [0, 1]

# mnist.target : ground-truth labels (not used by the autoencoder below)
mnist.target = mnist.target.astype(np.int32)
# Use the first N samples for training and the rest for testing.
# The reconstruction targets are a clean copy taken before any noise is added.
N = 60000
y_train, y_test = np.split(mnist.data.copy(), [N])
N_test = y_test.shape[0]
if noised:
    # Add noise: zero out a random noise_ratio fraction of the pixels of each
    # input image (the targets above stay clean, so this becomes a denoising
    # autoencoder)
    noise_ratio = 0.2
    for data in mnist.data:
        perm = np.random.permutation(mnist.data.shape[1])[:int(mnist.data.shape[1]*noise_ratio)]
        data[perm] = 0.0

x_train, x_test = np.split(mnist.data, [N])
# Autoencoder model: 784-dimensional input, 784-dimensional output, 2 layers
model = FunctionSet(l1=F.Linear(784, n_units),
                    l2=F.Linear(n_units, 784))
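# Note (added commentary, not part of the original gist): the encoder (l1)
# and decoder (l2) weights are learned independently here; no weight tying
# between the two layers is applied.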
# Neural net architecture
def forward(x_data, y_data, train=True):
    x, t = Variable(x_data), Variable(y_data)
    y = F.dropout(F.relu(model.l1(x)), train=train)
    x_hat = F.dropout(model.l2(y), train=train)
    # use mean squared error as the reconstruction loss
    return F.mean_squared_error(x_hat, t)
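# --- illustrative sketch added for this writeup (not part of the original
# gist) --- After training, the same two layers can be used to reconstruct a
# single image with dropout disabled (train=False); the visualization code
# further below does this inline. The helper name `reconstruct` is
# hypothetical.
def reconstruct(x_vec):
    x = Variable(x_vec.reshape(1, 784).astype(np.float32))
    h = F.dropout(F.relu(model.l1(x)), train=False)
    return model.l2(h).data.reshape(784)   # reconstructed 784-dim image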
# Setup optimizer
optimizer = optimizers.Adam()
optimizer.setup(model.collect_parameters())

l1_W = []
l2_W = []
l1_b = []
l2_b = []
train_loss = []
test_loss = []
test_mean_loss = []
prev_loss = -1
loss_std = 0
loss_rate = []
# Learning loop
for epoch in xrange(1, n_epoch+1):
    print 'epoch', epoch
    start_time = time.clock()

    # training
    perm = np.random.permutation(N)
    sum_loss = 0
    for i in xrange(0, N, batchsize):
        x_batch = x_train[perm[i:i+batchsize]]
        y_batch = y_train[perm[i:i+batchsize]]

        optimizer.zero_grads()
        loss = forward(x_batch, y_batch)
        loss.backward()
        optimizer.update()

        train_loss.append(loss.data)
        sum_loss += float(cuda.to_cpu(loss.data)) * batchsize

    print '\ttrain mean loss={} '.format(sum_loss / N)

    # evaluation
    sum_loss = 0
    for i in xrange(0, N_test, batchsize):
        x_batch = x_test[i:i+batchsize]
        y_batch = y_test[i:i+batchsize]

        loss = forward(x_batch, y_batch, train=False)
        test_loss.append(loss.data)
        sum_loss += float(cuda.to_cpu(loss.data)) * batchsize

    loss_val = sum_loss / N_test
    print '\ttest mean loss={}'.format(loss_val)

    if epoch == 1:
        loss_std = loss_val
        loss_rate.append(100)
    else:
        print '\tratio :%.3f' % (loss_val/loss_std * 100)
        loss_rate.append(loss_val/loss_std * 100)

    if prev_loss >= 0:
        diff = loss_val - prev_loss
        ratio = diff/prev_loss * 100
        print '\timpr rate:%.3f' % (-ratio)

    prev_loss = sum_loss / N_test
    test_mean_loss.append(loss_val)
    # keep a per-epoch snapshot of the weights (copied, so later updates do
    # not overwrite earlier snapshots in place)
    l1_W.append(model.l1.W.copy())
    l2_W.append(model.l2.W.copy())

    end_time = time.clock()
    print "\ttime = %.3f" % (end_time-start_time)
# Draw mean loss graph
plt.style.use('ggplot')
plt.figure(figsize=(10,7))
plt.plot(test_mean_loss, lw=1)
plt.title("")
plt.xlabel("epoch")
plt.ylabel("mean loss")
plt.show()
# Visualize inputs and their reconstructions
plt.style.use('fivethirtyeight')
plt.figure(figsize=(15,25))

num = 100
cnt = 0

ans_list = []
pred_list = []
for idx in np.random.permutation(N_test)[:num]:
    xxx = x_test[idx].astype(np.float32)
    h1 = F.dropout(F.relu(model.l1(Variable(xxx.reshape(1,784)))), train=False)
    y = model.l2(h1)

    cnt += 1
    ans_list.append(x_test[idx])
    pred_list.append(y)
cnt = 0
for i in range(int(num/10)):
    for j in range(10):
        img_no = i*10 + j
        pos = (2*i)*10 + j
        draw_digit_ae(ans_list[img_no], pos+1, 20, 10, "ans")
    for j in range(10):
        img_no = i*10 + j
        pos = (2*i+1)*10 + j
        draw_digit_ae(pred_list[img_no].data, pos+1, 20, 10, "pred")
plt.show()
# Visualize W(1), the encoder weights
plt.style.use('fivethirtyeight')

# draw digit images
def draw_digit_w1(data, n, i, length):
    size = 28
    plt.subplot(int(math.ceil(length/15.0)), 15, n)
    Z = data.reshape(size, size)   # convert from vector to 28x28 matrix
    Z = Z[::-1, :]                 # flip vertically
    plt.xlim(0, size)
    plt.ylim(0, size)
    plt.pcolor(Z)
    plt.title("%d" % i, size=9)
    plt.gray()
    plt.tick_params(labelbottom="off")
    plt.tick_params(labelleft="off")
plt.figure(figsize=(15,70))
cnt = 1
for i in range(len(l1_W[9])):
    # size the subplot grid by the number of hidden units being drawn,
    # not by the 784-pixel length of each weight vector
    draw_digit_w1(l1_W[9][i], cnt, i, len(l1_W[9]))
    cnt += 1
plt.show()
# Visualize W(2).T, the decoder weights (transposed)
plt.style.use('fivethirtyeight')

# draw digit images
def draw_digit2(data, i, length):
    size = 28
    plt.subplot(int(math.ceil(length/15.0))+1, 15, i+1)
    Z = data.reshape(size, size)   # convert from vector to 28x28 matrix
    Z = Z[::-1, :]                 # flip vertically
    plt.xlim(0, 27)
    plt.ylim(0, 27)
    plt.pcolor(Z)
    plt.title("%d" % i, size=9)
    plt.gray()
    plt.tick_params(labelbottom="off")
    plt.tick_params(labelleft="off")
W_T = np.array(l2_W[9]).T
plt.figure(figsize=(15,30))
for i in range(W_T.shape[0]):
    draw_digit2(W_T[i], i, W_T.shape[0])
plt.show()