Created
May 21, 2016 05:46
-
-
Save butsugiri/4cc68f4a56ee190c057d4e6267a6a695 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" | |
Chainerを用いた多層パーセプトロンの実装 | |
多分に参考にしました: https://github.com/ichiroex/xor-mlp-chainer | |
""" | |
import sys | |
import numpy as np | |
import chainer | |
from chainer import cuda, Function, gradient_check, Variable, optimizers, serializers, utils | |
from chainer import Link, Chain, ChainList | |
import chainer.functions as F | |
import chainer.links as L | |
# XOR truth table: every row of `source` is an input pair (a, b) and the
# matching row of `target` is the single expected output a XOR b.
source = [[0, 0], [1, 0], [0, 1], [1, 1]]
target = [[0], [1], [1], [0]]

# Training data as float32 arrays, keyed by role.
dataset = {
    'source': np.array(source, dtype=np.float32),
    'target': np.array(target, dtype=np.float32),
}
class MyMLP(Chain): | |
def __init__(self): | |
super(MyMLP, self).__init__( | |
l1 = L.Linear(2,2), | |
l2 = L.Linear(2,1), | |
) | |
def forward(self, x, t): | |
h1 = F.sigmoid(self.l1(x)) | |
return F.sigmoid(self.l2(h1)) | |
if __name__ == "__main__":
    # Train MyMLP on the XOR dataset with Adam, minimizing mean squared error.
    model = MyMLP()
    optimizer = optimizers.Adam()
    optimizer.setup(model)

    max_epoch = 100000      # hard cap on the number of training epochs
    report_interval = 1000  # print progress every this many epochs
    tolerance = 1e-5        # stop once the training loss is at or below this

    epoch = 0
    while True:
        # dataset arrays are already float32 ndarrays; wrap them directly.
        x = chainer.Variable(dataset['source'])  # inputs
        t = chainer.Variable(dataset['target'])  # targets

        model.zerograds()  # reset parameter gradients to zero
        y = model.forward(x, t)
        loss = F.mean_squared_error(y, t)
        loss.backward()  # backpropagate the error
        optimizer.update()  # apply one optimization step

        # Refresh the loss on EVERY epoch, so the stopping test is never stale.
        loss_val = loss.data
        done = loss_val <= tolerance or epoch >= max_epoch

        # Report periodically, and always for the final epoch.
        # Single-argument print() behaves the same on Python 2 and 3.
        if epoch % report_interval == 0 or done:
            print('epoch: {}'.format(epoch))
            print('x:\n{}'.format(x.data))
            print('t:\n{}'.format(t.data))
            print('y:\n{}'.format(y.data))
            print('train mean loss={}'.format(loss_val))
            print(' - - - - - - - - - ')

        if done:
            break
        epoch += 1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment