Skip to content

Instantly share code, notes, and snippets.

@pgtwitter
Created November 8, 2015 05:09
Show Gist options
  • Save pgtwitter/2eadbd8829ae1a3501b8 to your computer and use it in GitHub Desktop.
Save pgtwitter/2eadbd8829ae1a3501b8 to your computer and use it in GitHub Desktop.
chainer で XOR の出力の推移を ffmpeg で mp4 にする (subplot 有り)
#! /usr/bin/env python
#encoding: utf-8
import numpy as np
import chainer.functions as F
from chainer import FunctionSet, Variable, optimizers
import matplotlib.pyplot as plt
import matplotlib.animation as animation
# Figure that every animation frame produced by draw() is drawn into.
fig = plt.figure()
# ims: per-frame artist lists collected for ArtistAnimation.
# err: summed training error of each epoch, plotted in the right subplot.
ims, err = [], []
# Total number of training epochs; also the x-axis limit of the error plot.
T = 5000
# Two-layer network: l1 maps the 2 inputs to 2 hidden units, l2 maps those
# hidden units to a single output.
model = FunctionSet(l1=F.Linear(2, 2), l2=F.Linear(2, 1))
def forward(x):
    """Run the network on ``x`` and return the output Variable.

    A sigmoid follows each of the two linear layers, so every output value
    lies in the open interval (0, 1).
    """
    hidden = F.sigmoid(model.l1(x))
    return F.sigmoid(model.l2(hidden))
def calc(x_data):
    """Evaluate the network on a single two-element input.

    ``x_data`` is reshaped into a (1, 2) float32 batch and wrapped in a
    Variable before being passed to forward(); the output Variable is returned.
    """
    batch = x_data.reshape(1, 2).astype(np.float32)
    return forward(Variable(batch, volatile=False))
def train(x_data, y_data):
    """Take one optimizer step on a single (input, target) sample.

    Returns the ``.data`` of the mean-squared-error loss computed for this
    sample before the parameter update.
    """
    prediction = calc(x_data)
    target = Variable(y_data.reshape(1, 1).astype(np.float32), volatile=False)
    # Clear gradients accumulated by the previous step before backprop.
    optimizer.zero_grads()
    loss = F.mean_squared_error(prediction, target)
    loss.backward()
    optimizer.update()
    return loss.data
def show(data):
    """Print the network's output for every (input, target) pair in ``data``.

    Each printed line has the form ``<input> -> <output array> : <target>``.
    """
    # Iterate the pairs directly instead of indexing with range(len(...)).
    for x, t in data:
        h = calc(x)
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print("{} -> {} : {}".format(x, h.data, t))
def save():
    """Encode the frames collected in ``ims`` into ``im.mp4`` with ffmpeg."""
    # Fetch the ffmpeg-backed writer class registered with matplotlib.
    writer_cls = animation.writers['ffmpeg']
    writer = writer_cls(fps=30, metadata={'artist': 'Me'}, bitrate=1800)
    movie = animation.ArtistAnimation(fig, ims, interval=25, repeat_delay=1000)
    # Save as a video file.
    movie.save('im.mp4', writer=writer)
    # To preview interactively instead, uncomment:
    # plt.show()
def draw():
    """Draw one animation frame and return the artists it created.

    Left subplot: the network output evaluated on a 6x6 grid of points
    spanning [0, 1] x [0, 1]. Right subplot: a scatter of the per-epoch
    training error recorded so far in ``err``.

    Returns:
        ``[heatmap, scatter]``, the two artists for this frame, in the form
        ArtistAnimation expects for one entry of ``ims``.
    """
    # linspace pins both endpoints exactly; a float-step np.arange can gain
    # or lose the last point through rounding.
    x = np.linspace(0.0, 1.0, 6)
    y = np.linspace(0.0, 1.0, 6)
    X, Y = np.meshgrid(x, y)
    # With meshgrid's default 'xy' indexing, rows follow y and columns
    # follow x, so Z must be laid out as Z[row=y, col=x]. Building it with
    # x in the outer loop would transpose the heatmap.
    Z = np.array([[calc(np.array([ix, iy])).data[0, 0] for ix in x]
                  for iy in y])
    plt.subplot(121, aspect='equal')
    ima = plt.pcolor(X, Y, Z)
    plt.subplot(122)
    plt.ylim(0, 1.3)
    plt.xlim(0, T)
    imb = plt.scatter(np.arange(len(err)), np.array(err))
    return [ima, imb]
# Optimizer selection. RMSprop is active; to experiment, replace it with one
# of these alternative configurations:
#   optimizers.AdaDelta(rho=0.95, eps=1e-06)
#   optimizers.AdaGrad(lr=0.001, eps=1e-08)
#   optimizers.Adam(alpha=0.001, beta1=0.9, beta2=0.999, eps=1e-08)
#   optimizers.MomentumSGD(lr=0.01, momentum=0.9)
#   optimizers.NesterovAG(lr=0.01, momentum=0.9)
#   optimizers.SGD(lr=0.01)
optimizer = optimizers.RMSprop(
    lr=0.01,
    alpha=0.99,
    eps=1e-8,
)
# Attach the optimizer to the model's parameters.
optimizer.setup(model)
def _xor_pair(point, label):
    """Build one ``[input, target]`` sample as two numpy arrays."""
    return [np.array(point), np.array([label])]


# Training inputs: the centres of the four quadrants of the unit square,
# labelled with XOR of "coordinate > 0.5".
data_xor = [
    _xor_pair([0.25, 0.25], 0),
    _xor_pair([0.25, 0.75], 1),
    _xor_pair([0.75, 0.25], 1),
    _xor_pair([0.75, 0.75], 0),
]
# Test inputs: the four corners of the unit square with their XOR labels.
test_xor = [
    _xor_pair([0, 0], 0),
    _xor_pair([0, 1], 1),
    _xor_pair([1, 0], 1),
    _xor_pair([1, 1], 0),
]
# Main script: train on data_xor, record one animation frame every 10 epochs,
# report results before/after training and on the corner test set, then
# write the movie.
print("###学習前###")  # "before training"
show(data_xor)
ims.append(draw())
# Training: T epochs in total, grouped into T // 10 frames of 10 epochs each.
# Floor division keeps the range() bound an int on Python 3 as well
# (T / 10 is a float there and range() raises TypeError).
N = len(data_xor)
for k in range(0, T // 10):
    print("frame {}".format(k))
    for i in range(0, 10):
        # One epoch: visit every training sample once in a fresh random order
        # and sum the per-sample losses.
        s = 0
        for j in np.random.permutation(N):
            x, t = data_xor[j]
            s += train(x, t)
        # One summed-error point per epoch for the right-hand plot.
        err.append(s)
    ims.append(draw())
print("###学習後###")  # "after training"
show(data_xor)
print("###テスト###")  # "test"
show(test_xor)
save()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment