Skip to content

Instantly share code, notes, and snippets.

@butsugiri
Created May 17, 2016 00:18
Show Gist options
  • Save butsugiri/dbd5d55e2d09e4a3a43633b46498d83a to your computer and use it in GitHub Desktop.
Save butsugiri/dbd5d55e2d09e4a3a43633b46498d83a to your computer and use it in GitHub Desktop.
チュートリアルから作った多層パーセプトロン
# -*- coding: utf-8 -*-
"""
お手製のニューラルネットワーク
トレーニング編
"""
import sys
import numpy as np
from collections import defaultdict
def create_features(line, ids):
    """Build a bag-of-unigrams count vector for one line of text.

    Args:
        line: whitespace-separated text (trailing newline allowed).
        ids: mapping from "UNI:<word>" feature names to column indices,
            e.g. the mapping produced by word2ids.

    Returns:
        A list of len(ids) integer counts, where entry ids["UNI:" + w]
        counts occurrences of w in line. Words absent from ids are ignored.
    """
    phi = [0] * len(ids)
    for word in line.split():
        key = "UNI:" + word
        # Test membership first: indexing a defaultdict such as the one
        # word2ids returns would insert an unseen key with an index past the
        # end of phi (IndexError) and grow the vocabulary as a side effect.
        if key in ids:
            phi[ids[key]] += 1
    return phi
def word2ids(fi):
    """Assign a column index to every unigram feature in a labeled corpus.

    Each line of fi must look like "<label>\\t<title>". Every word of the
    title is registered as the feature "UNI:<word>". Indices are handed out
    as 0, 1, 2, ... in first-seen order.

    Args:
        fi: iterable of text lines (e.g. an open file).

    Returns:
        A defaultdict mapping feature name to index. Looking up a new key
        assigns it the next free index.
    """
    ids = defaultdict(lambda: len(ids))
    for row in fi:
        _, text = row.rstrip().split("\t")
        for token in text.split():
            # Merely touching the key makes the defaultdict assign its index.
            ids["UNI:" + token]
    return ids
def forward_nn(net, phi0):
    """Run a forward pass through a stack of tanh layers.

    Args:
        net: list of layers, input side first. Each layer is a dict with a
            weight matrix "w" and a bias vector "b".
        phi0: input feature vector (list or 1-D array).

    Returns:
        A list of length len(net) + 1 holding the value of every layer:
        phi0 itself, followed by tanh(w . previous + b) for each layer in order.
    """
    # Grow the list layer by layer so networks of any depth are supported
    # (a fixed-size list only worked for exactly two layers).
    phi_list = [phi0]
    for layer in net:
        phi_list.append(np.tanh(np.dot(layer["w"], phi_list[-1]) + layer["b"]))
    return phi_list
def backward_nn(net, phi, true_label):
    """Back-propagate the output error through the tanh layers.

    The output error is true_label - phi[-1][0], which assumes a single
    output unit. The returned terms keep that sign, so update_weights
    *adds* them. This is consistent with gradient descent on squared error.

    Args:
        net: list of layer dicts ("w", "b"), input side first.
        phi: layer values from forward_nn; its length is len(net) + 1.
        true_label: target value for the single output unit.

    Returns:
        delta_prime, a list of length len(net) + 1. delta_prime[i + 1] is
        the error signal at the pre-activation of layer i. delta_prime[0]
        is an unused placeholder (0).
    """
    n_layers = len(net)
    # delta[j] is the error w.r.t. phi[j]. It is sized from the network
    # depth; a fixed three-slot list put the output error in the wrong slot
    # for any net that was not exactly two layers deep.
    delta = [0] * n_layers + [np.array([true_label - phi[n_layers][0]])]
    delta_prime = [0] * (n_layers + 1)
    for i in reversed(range(n_layers)):
        # tanh'(x) = 1 - tanh(x)**2, and phi[i + 1] already holds tanh(x).
        delta_prime[i + 1] = delta[i + 1] * (1 - phi[i + 1] ** 2)
        # The error w.r.t. the raw input features (delta[0]) is never used.
        if i > 0:
            delta[i] = np.dot(delta_prime[i + 1], net[i]["w"])
    return delta_prime
def update_weights(net, phi, delta, _lambda):
    """Apply one gradient step to every layer of net, in place.

    Args:
        net: list of layer dicts. Their "w" and "b" arrays are modified in
            place; nothing is returned.
        phi: per-layer inputs, where phi[k] is the input to layer k
            (the shape of the list forward_nn returns).
        delta: error terms, where delta[k + 1] belongs to layer k
            (the shape of the list backward_nn returns).
        _lambda: learning rate.
    """
    for k, layer in enumerate(net):
        weights, bias = layer["w"], layer["b"]
        err = delta[k + 1]
        # Augmented assignment on the fetched arrays mutates the arrays that
        # net holds, rather than creating new ones.
        weights += _lambda * np.outer(err, phi[k])
        bias += _lambda * err
if __name__ == "__main__":
    # Training hyper-parameters.
    n_hidden = 2  # units in the single hidden layer
    learning_rate = 0.1

    with open("./titles-en-train.labeled", "r") as fi:
        # First pass: scan the whole file to assign an index to every
        # feature. The vocabulary size fixes the length of every vector.
        ids = word2ids(fi)
        fi.seek(0)
        # Second pass: build the list of (feature vector, label) pairs.
        feat_lab = []
        for line in fi:
            label, title = line.rstrip().split("\t")
            feat_lab.append((create_features(title, ids), float(label)))

    # Initialise the network with random values in [-0.5, 0.5).
    # Layers are listed input side first: [layer 0, layer 1, ...].
    # The input width comes from the vocabulary, not a hard-coded count,
    # so the weight matrix always matches the feature vectors.
    n_features = len(ids)
    net = [{"w": np.random.rand(n_hidden, n_features) - 0.5,
            "b": np.random.rand(n_hidden) - 0.5},
           {"w": np.random.rand(1, n_hidden) - 0.5,
            "b": np.random.rand(1) - 0.5}]
    print(net)

    # Train: a single pass of per-example (online) updates.
    for phi, label in feat_lab:
        phi_list = forward_nn(net, phi)
        delta_prime = backward_nn(net, phi_list, label)
        # update_weights needs the input of every layer, i.e. the whole
        # phi_list. Passing only the raw features made each layer multiply
        # its error by a single feature count.
        update_weights(net, phi_list, delta_prime, learning_rate)
    print(net)
    print(ids)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment