Created
May 12, 2019 22:42
-
-
Save Cartman0/f9c63d94c3c73c8dc9c9c6ff9d6668d2 to your computer and use it in GitHub Desktop.
backpropで微分値を求める
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import scipy as sp
class FB:
    '''
    Namespace of forward/backward ("FB") units for scalar backpropagation.

    Every unit exposes two methods:
      forward(input)         -- computes and caches the unit's output.
      backward(diff_input=1) -- returns diff_input multiplied by the unit's
                                local derivative at the most recent forward
                                input (one step of the chain rule).
    backward() therefore has to be called after forward().
    Inputs may be Python scalars; NumPy arrays also work (all operations
    used are elementwise).
    '''

    def __init__(self):
        pass

    class Sigmoid:
        '''sigmoid(x) = 1 / (1 + exp(-x))'''

        def __init__(self):
            self._output = None

        def forward(self, input):
            # np.exp rather than sp.exp: NumPy functions re-exported in the
            # scipy namespace were deprecated in SciPy 1.4 and since removed.
            self._output = 1 / (1 + np.exp(-input))
            return self._output

        def backward(self, diff_input=1):
            # sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)), from the cached output.
            sig = self._output
            return sig * (1 - sig) * diff_input

    class X_nthPower:
        '''x^n for a fixed exponent n (nthPower).'''

        def __init__(self, nthPower = 1):
            self._input = None
            self._output = None
            self._nthPower = nthPower

        def forward(self, input):
            # Cache the input: backward needs x itself, and x cannot be
            # reliably recovered from x^n.
            self._input = input
            self._output = input**self._nthPower
            return self._output

        def backward(self, diff_input=1):
            '''
            (x^n)' = n x^(n-1), evaluated at the cached forward input.

            Recovering x from the output as (x^n)^(1/n) would lose the sign
            of x for even n, give a complex/NaN value for negative x with
            odd n, and divide by zero at x = 0 or for n = 0.
            '''
            n = self._nthPower
            x = self._input
            if n == 0:
                # x^0 is constant; multiplying by x keeps array shapes.
                return 0 * x * diff_input
            return n * x ** (n - 1) * diff_input

    class X_addCons:
        '''x + c for a fixed constant c.'''

        def __init__(self, constant = 0):
            self._output = None
            self._constant = constant

        def forward(self, input):
            '''
            X+c
            '''
            self._output = input + self._constant
            return self._output

        def backward(self, diff_input=1):
            # d/dx (x + c) = 1
            return 1 * diff_input

    class X_scaleCons:
        '''c * x for a fixed constant c.'''

        def __init__(self, constant = 1):
            self._output = None
            self._constant = constant

        def forward(self, input):
            '''
            c x
            '''
            self._output = self._constant * input
            return self._output

        def backward(self, diff_input=1):
            # d/dx (c x) = c
            return self._constant * diff_input

    class Exp:
        '''exp(x)'''

        def __init__(self):
            self._output = None

        def forward(self, input):
            '''
            exp(x)
            '''
            self._output = np.exp(input)
            return self._output

        def backward(self, diff_input=1):
            # d/dx exp(x) = exp(x), i.e. the cached output.
            return self._output * diff_input

    class _Network:
        '''
        Sequential composition of units: forward applies units[0], units[1],
        ... in order; backward applies their backward() in reverse order.
        A network has the same forward/backward interface as a unit, so
        networks can be nested inside other networks.
        '''

        def __init__(self, units):
            '''
            units: list (or other reversible sequence) of units.
            '''
            self._units = units

        def forward(self, input):
            x = input
            for u in self._units:
                x = u.forward(x)
            # Return the accumulator (not a loop variable) so an empty
            # network is well-defined: it is the identity.
            return x

        def backward(self, diff_input=1):
            '''
            Back-propagate diff_input (default 1, i.e. d out / d out)
            through the units in reverse order; returns d out / d input
            scaled by diff_input.
            '''
            diff_x = diff_input
            for u in reversed(self._units):
                diff_x = u.backward(diff_x)
            return diff_x

    def createNetwork(self, units):
        '''Return a sequential network built from the given units.'''
        return self._Network(units)
# Demo: compute sigmoid(x) and its derivative at x = 1 in two ways.
#   1. With the dedicated Sigmoid unit.
#   2. By chaining elementary units
#        x -> -x -> exp(-x) -> 1 + exp(-x) -> (1 + exp(-x))^(-1)
#      and back-propagating through that chain in reverse order.
# Both ways should agree.
fb = FB()

x = 1

# Reference values from the Sigmoid unit.
sig = fb.Sigmoid()
sig_value = sig.forward(x)
sig_grad = sig.backward()

# forward: push x through the elementary units.
fb_xScaleMinus1 = fb.X_scaleCons(-1)
fw1 = fb_xScaleMinus1.forward(x)          # -x
fb_exp = fb.Exp()
fw2 = fb_exp.forward(fw1)                 # exp(-x)
fb_xadd1 = fb.X_addCons(constant=1)
fw3 = fb_xadd1.forward(fw2)               # 1 + exp(-x)
fb_xpowerMinus1 = fb.X_nthPower(nthPower=-1)
fw4 = fb_xpowerMinus1.forward(fw3)        # sigmoid(x)

# backward: apply each unit's local derivative in reverse order (chain rule).
bw1 = fb_xpowerMinus1.backward(1)
bw2 = fb_xadd1.backward(bw1)
bw3 = fb_exp.backward(bw2)
bw4 = fb_xScaleMinus1.backward(bw3)       # d sigmoid / dx at x

if __name__ == "__main__":
    # In a notebook the bare expressions displayed these values; in a
    # script they must be printed explicitly.
    print("Sigmoid unit  : value =", sig_value, " derivative =", sig_grad)
    print("chained units : value =", fw4, " derivative =", bw4)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment