Skip to content

Instantly share code, notes, and snippets.

@pgtwitter
Created November 7, 2015 20:17
Show Gist options
  • Save pgtwitter/cb18c50b531f4f9daea1 to your computer and use it in GitHub Desktop.
Save pgtwitter/cb18c50b531f4f9daea1 to your computer and use it in GitHub Desktop.
Chainer で XOR を学習する
#! /usr/bin/env python
#encoding: utf-8
import numpy as np
import chainer.functions as F
from chainer import FunctionSet, Variable, optimizers
# Two-layer fully connected network: 2 inputs -> 2 hidden units -> 1 output.
# The layers are looked up by name (model.l1 / model.l2) in forward().
model = FunctionSet(l1=F.Linear(2, 2), l2=F.Linear(2, 1))
def forward(x):
    """Run ``x`` through the network and return the output activation.

    The input goes through ``model.l1``, a sigmoid, ``model.l2`` and a
    final sigmoid, in that order.
    """
    hidden = F.sigmoid(model.l1(x))
    return F.sigmoid(model.l2(hidden))
def calc(x_data):
    """Evaluate the network on one sample.

    ``x_data`` is a NumPy array holding exactly two values; it is reshaped
    to a (1, 2) float32 batch, wrapped in a ``Variable`` and passed to
    ``forward``. The result of ``forward`` is returned unchanged.
    """
    batch = x_data.reshape(1, 2).astype(np.float32)
    return forward(Variable(batch, volatile=False))
def train(x_data, y_data):
    """Perform one optimizer step on a single (input, target) pair.

    ``x_data`` is forwarded through ``calc``; ``y_data`` (one value) is
    reshaped to a (1, 1) float32 target. The mean squared error between the
    two is back-propagated and the module-level ``optimizer`` is updated.

    Returns the ``data`` attribute of the loss for this sample.
    """
    prediction = calc(x_data)
    target = Variable(y_data.reshape(1, 1).astype(np.float32), volatile=False)

    # Clear gradients left over from the previous step before backprop.
    optimizer.zero_grads()
    loss = F.mean_squared_error(prediction, target)
    loss.backward()
    optimizer.update()

    return loss.data
# Optimizer selection. RMSprop is the active choice; the alternatives below
# are kept for reference and can be swapped in by replacing the assignment.
#   optimizers.AdaDelta(rho=0.95, eps=1e-06)
#   optimizers.AdaGrad(lr=0.001, eps=1e-08)
#   optimizers.Adam(alpha=0.001, beta1=0.9, beta2=0.999, eps=1e-08)
#   optimizers.MomentumSGD(lr=0.01, momentum=0.9)
#   optimizers.NesterovAG(lr=0.01, momentum=0.9)
#   optimizers.SGD(lr=0.01)
optimizer = optimizers.RMSprop(
    lr=0.01,
    alpha=0.99,
    eps=1e-08,
)
# Attach the optimizer to the model defined earlier in this file.
optimizer.setup(model)
# XOR training set: each entry is [input array, target array]. Inputs use
# 0.25 / 0.75 rather than 0 / 1; the target is 1 when the two inputs differ.
data_xor = [
    [np.array(inputs), np.array([label])]
    for inputs, label in (
        ((0.25, 0.25), 0),
        ((0.25, 0.75), 1),
        ((0.75, 0.25), 1),
        ((0.75, 0.75), 0),
    )
]
# Number of training samples.
N = len(data_xor)
# Show the untrained network's output for every training sample, printed as
# "input -> network output : expected target".
# print(...) with a single argument produces the same output on Python 2 and
# is also valid on Python 3, unlike the bare print statement.
print("###学習前###")
for x, t in data_xor:
    h = calc(x)
    print("{} -> {} : {}".format(x, h.data, t))
# Training: 5000 passes over the data set. Each pass visits the samples in a
# freshly shuffled order and records the summed per-sample loss in ``err``.
err = []
for _epoch in range(5000):
    epoch_loss = sum(
        train(*data_xor[k]) for k in np.random.permutation(N)
    )
    err.append(epoch_loss)
# Show the network's output for every training sample at this point in the
# script, printed as "input -> network output : expected target".
# print(...) with a single argument produces the same output on Python 2 and
# is also valid on Python 3, unlike the bare print statement.
print("###学習後###")
for x, t in data_xor:
    h = calc(x)
    print("{} -> {} : {}".format(x, h.data, t))
# Evaluate on the exact 0/1 XOR corners, printed as
# "input -> network output : expected target".
# print(...) with a single argument produces the same output on Python 2 and
# is also valid on Python 3, unlike the bare print statement.
print("###テスト###")
test_xor = [
    [np.array([0, 0]), np.array([0])],
    [np.array([0, 1]), np.array([1])],
    [np.array([1, 0]), np.array([1])],
    [np.array([1, 1]), np.array([0])],
]
for x, t in test_xor:
    h = calc(x)
    print("{} -> {} : {}".format(x, h.data, t))
# Error history: plot the values collected in ``err``, one per training pass.
# Set the flag to False to skip plotting (and the matplotlib import).
SHOW_ERROR_PLOT = True
if SHOW_ERROR_PLOT:
    import matplotlib.pyplot as plt

    plt.plot(err)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment