# Created June 25, 2011 03:24
# GitHub gist zhuowater/1046096, titled "python BP ANN neural network"
# (translated from Chinese): a back-propagation (BP) artificial neural
# network in Python.
# GitHub note: this file contains hidden or bidirectional Unicode text that
# may be interpreted or compiled differently than it appears; review it in
# an editor that reveals hidden Unicode characters.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function

import math
import random

import numpy as np
def rand(a, b):
    """Return a uniformly distributed random float between ``a`` and ``b``."""
    width = b - a
    return a + width * random.random()
def dtanh(y):
    """Derivative of tanh written in terms of its output.

    If ``y = tanh(x)`` then ``d/dx tanh(x) = 1 - y*y``, so callers pass the
    already-computed activation rather than ``x``.
    """
    squared = y * y
    return 1.0 - squared
def mat(x):
    """Hyperbolic tangent; the default activation assigned to each ``cell``."""
    activation = math.tanh
    return activation(x)
def arc(x):
    """Return the arctangent of ``x`` in radians.

    An alternative activation function; it is not referenced elsewhere in
    this file.
    """
    result = math.atan(x)
    return result
class cell(object):
    """A single neuron: records weighted inputs and applies an activation.

    Each entry of ``input`` is a one-item dict ``{input_value: weight}``.
    ``ANN_bp`` reads these entries back during back-propagation, so the
    representation is kept unchanged.
    """

    def __init__(self):
        self.input = []    # one-item dicts {input_value: weight}, in feed order
        self.thert = 0.0   # threshold subtracted before activation (not modified by the visible training code)
        self.mac = mat     # activation function; defaults to the module-level tanh wrapper
        self.xgma = 0.0    # error term (delta), written by ANN_bp.backProate

    def add_input(self, num, weight=0):
        """Record one incoming value ``num`` with its connection ``weight``."""
        self.input.append({num: weight})

    def get_pro(self):
        """Return the weighted input sum: sum of ``value * weight`` over all inputs."""
        # Dict views cannot be indexed on Python 3 (the old ``x.items()[0]``
        # raised TypeError there), so iterate the items instead. Each dict
        # holds exactly one item, so the sum is unchanged.
        return sum(value * weight
                   for entry in self.input
                   for value, weight in entry.items())

    @property
    def output(self):
        """Activation of the weighted input sum minus the threshold ``thert``."""
        return self.mac(self.get_pro() - self.thert)

    def clean(self):
        """Forget all recorded inputs (done before every forward pass)."""
        self.input = []
class ANN_bp(object):
    """Fully connected feed-forward neural network trained by back-propagation.

    ``ANN_bp(num_in, st_h, num_ou)`` creates a network with:

    * ``num_in`` inputs;
    * one hidden layer per entry of the list ``st_h``, where each value is
      that layer's neuron count;
    * ``num_ou`` output neurons.

    For example, ``ANN_bp(2, [4, 2, 2], 1)``.

    Weights are stored in ``numpy.matrix`` objects indexed
    ``[source_index, target_index]``. ``self.h[i]`` feeds hidden layer ``i``
    and ``self.output_w`` feeds the output layer.

    Every neuron is a ``cell``. Back-propagation uses ``dtanh``, so it assumes
    the cells keep their default tanh activation. Cell thresholds
    (``cell.thert``) are not trained here.

    NOTE: the current input vector is stored in the instance attribute
    ``input``, which shadows the ``input()`` method once it has been set.
    """

    def __init__(self, num_in, st_h, num_ou):
        self.exce = []          # expected (target) output vector
        self.num_in = num_in    # number of input values
        self.st_h = st_h        # list of hidden-layer sizes
        self.num_ou = num_ou    # number of output neurons
        self.h = []             # weight matrices feeding each hidden layer
        self.create_hide()
        self.create_cell_ou()

    def create_maxt(self, m, n, fill=None):
        """Return an ``m`` x ``n`` matrix of random weights drawn by ``rand(-2, 2)``.

        ``fill`` is accepted for backward compatibility but ignored. Its old
        default, ``random.random()``, was evaluated once at class definition
        and never used.
        """
        return np.matrix([[rand(-2, 2) for _ in range(n)] for _ in range(m)])

    def create_hide(self):
        """Create one weight matrix and one list of neurons per hidden layer."""
        self.hide_cell = []
        fan_in = self.num_in
        for size in self.st_h:
            self.h.append(self.create_maxt(fan_in, size))
            self.hide_cell.append([cell() for _ in range(size)])
            fan_in = size

    def create_cell_ou(self):
        """Create the output weight matrix and the output neurons.

        The output layer is fed by the last hidden layer, or directly by the
        inputs when ``st_h`` is empty.
        """
        fan_in = self.st_h[-1] if self.st_h else self.num_in
        self.output_w = self.create_maxt(fan_in, self.num_ou)
        self.output_cell = [cell() for _ in range(self.num_ou)]

    @staticmethod
    def _as_vector(data):
        """Unwrap ``f([a, b])`` to ``[a, b]`` so that ``f(a, b)`` works the same."""
        if len(data) == 1 and isinstance(data[0], (list, tuple, np.ndarray)):
            return data[0]
        return data

    def input(self, *data):
        """Set the input vector, given as separate values or as one sequence.

        ``net.input(0.4, 0.3)`` and ``net.input([0.4, 0.3])`` are equivalent.
        A vector whose length differs from ``num_in`` is silently ignored,
        as before.

        The vector is stored in the instance attribute ``input``, which
        shadows this method after the first successful call. Later calls must
        use ``ANN_bp.input(net, ...)`` or assign ``net.input = [...]``.
        """
        data = self._as_vector(data)
        if len(data) != self.num_in:
            return
        self.input = list(data)

    def format(self):
        """Clear the recorded inputs of every hidden and output neuron."""
        for layer in self.hide_cell:
            for neuron in layer:
                neuron.clean()
        for neuron in self.output_cell:
            neuron.clean()

    @staticmethod
    def _feed(values, cells, weights):
        """Connect value ``j`` to cell ``l`` with weight ``weights[j, l]``."""
        for j, value in enumerate(values):
            for l, neuron in enumerate(cells):
                neuron.add_input(value, weights[j, l])

    def update(self):
        """Run a forward pass over ``self.input`` and store the result in ``self.output``."""
        self.format()
        input_line = self.input
        for weights, cells in zip(self.h, self.hide_cell):
            self._feed(input_line, cells, weights)
            input_line = [k.output for k in cells]
        self._feed(input_line, self.output_cell, self.output_w)
        self.output = [k.output for k in self.output_cell]

    def cau_e(self):
        """Return ``0.5 * sum((target - output)**2)`` for the last forward pass."""
        error = 0.0
        for k, actual in enumerate(self.output):
            delta = self.exce[k] - actual
            error += 0.5 * delta * delta
        return error

    @staticmethod
    def _entry(connection):
        """Return ``(input_value, weight)`` from a cell's one-item input dict."""
        # next(iter(...)) works on both Python 2 and 3, unlike .items()[0].
        return next(iter(connection.items()))

    @staticmethod
    def _adjust(weights, cells, N):
        """Apply ``weights[j, l] += N * delta_l * input_j`` to every connection into ``cells``."""
        for l, neuron in enumerate(cells):
            for j, connection in enumerate(neuron.input):
                value = ANN_bp._entry(connection)[0]
                weights[j, l] += N * neuron.xgma * value

    def backProate(self, N=0.5):
        """Back-propagate the error of the last forward pass and update all weights.

        Call this after ``update()``, with ``self.exce`` holding the target
        vector. ``N`` is the learning rate.

        All deltas (``cell.xgma``) are computed first, from the weights each
        cell captured during ``update()``. Only then are the weight matrices
        changed.
        """
        # Output layer: delta = tanh'(output) * (target - output).
        for k, out_cell in enumerate(self.output_cell):
            out_cell.xgma = dtanh(out_cell.output) * (self.exce[k] - self.output[k])

        # Hidden layers, last to first. The delta of neuron j is the weighted
        # sum of the deltas of the neurons it feeds, times tanh'(its output).
        downstream = self.output_cell
        for layer in reversed(self.hide_cell):
            for j, neuron in enumerate(layer):
                back = sum(d.xgma * self._entry(d.input[j])[1] for d in downstream)
                neuron.xgma = back * dtanh(neuron.output)
            downstream = layer

        # Gradient step on every weight matrix.
        for weights, layer in zip(self.h, self.hide_cell):
            self._adjust(weights, layer, N)
        self._adjust(self.output_w, self.output_cell, N)

    def excep(self, *data):
        """Set the expected output, given as separate values or as one sequence.

        A vector whose length differs from ``num_ou`` is silently ignored.
        """
        data = self._as_vector(data)
        if len(data) != self.num_ou:
            return
        self.exce = tuple(data)

    def train(self, patterns, iterations=2000, N=0.5):
        """Train by online back-propagation with learning rate ``N``.

        ``patterns`` is a sequence of ``[input_vector, target_vector]`` pairs.
        Each iteration does one forward and one backward pass per pattern.
        Every 100 iterations, the error summed over that iteration's patterns
        is printed. After training, each pattern's output and error are
        printed.

        Returns the summed error of the final iteration, with each pattern
        measured before its own weight update. Returns 0.0 if no iteration
        ran.
        """
        error = 0.0
        for i in range(iterations):
            error = 0.0
            for p in patterns:
                self.input = p[0]
                self.exce = p[1]
                self.update()
                error += self.cau_e()
                # Pass the learning rate on; it used to be silently ignored.
                self.backProate(N)
            if i % 100 == 0:
                print('error', error)
        for p in patterns:
            self.input = p[0]
            self.exce = p[1]
            self.update()
            print(p[0], '-----', self.output, '---', self.cau_e())
        return error
# Demo: ANN_bp(number of input nodes, [nodes per hidden layer], number of
# output nodes). This builds 2 inputs, hidden layers of 4, 2 and 2 neurons,
# and 1 output. It trains on four samples, then evaluates an unseen input.
c = ANN_bp(2, [4, 2, 2], 1)
pat = [
    [[2.0, 0.1], [0.0]],
    [[3.0, 0.2], [0.3]],
    [[1.0, 0.2], [-0.1]],
    [[2.0, 0.2], [0.1]],
]
c.train(pat)
# Assign the input vector directly; the instance attribute shadows the
# ANN_bp.input method.
c.input = [10, 0.3]
c.update()
print(c.output)
# (End of the GitHub gist page: "Sign up for free to join this conversation
# on GitHub. Already have an account? Sign in to comment")