Last active
April 23, 2020 01:53
-
-
Save abap34/a4a96d8abd052b7f86ee0503fb6bd774 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import weakref\n", | |
"import contextlib" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def as_array(x):\n", | |
" if np.isscalar(x):\n", | |
" return np.array(x)\n", | |
" return x" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def as_tuple(x):\n", | |
" if not isinstance(x,tuple):\n", | |
" return (x,)\n", | |
" return x" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def as_variable(obj):\n", | |
" if isinstance(obj,Variable):\n", | |
" return obj\n", | |
" return Variable(obj)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Config:\n", | |
" enable_backprop = True" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@contextlib.contextmanager\n", | |
"def using_config(name,value):\n", | |
" old_value = getattr(Config,name)\n", | |
" setattr(Config,name,value)\n", | |
" try:\n", | |
" yield\n", | |
" finally:\n", | |
" setattr(Config,name,old_value)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def no_grad():\n", | |
" return using_config('enable_backprop',False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Variable:\n", | |
" __array_priority__ = 200 # operator priority: lets Variable's operators win over np.ndarray's\n", | |
" def __init__(self,data,name=None):\n", | |
" if data is not None:\n", | |
" if not isinstance(data,np.ndarray):\n", | |
" raise TypeError('{} is not supported'.format(type(data)))\n", | |
" self.data = data # holds the variable's value\n", | |
" self.name = name\n", | |
" self.grad = None\n", | |
" self.creator = None # keep a reference to the function that created this variable\n", | |
" self.generation = 0\n", | |
" \n", | |
" def __len__(self):\n", | |
" return len(self.data)\n", | |
" \n", | |
" def __repr__(self):\n", | |
" if self.data is None:\n", | |
" return 'variable(None)'\n", | |
" p = str(self.data).replace('\\n', '\\n' + ' ' * 9)\n", | |
" return 'variable(' + p + ')'\n", | |
" \n", | |
"\n", | |
" @property\n", | |
" def shape(self):\n", | |
" return self.data.shape\n", | |
" \n", | |
" @property\n", | |
" def size(self):\n", | |
" return self.data.size\n", | |
" \n", | |
" @property\n", | |
" def dtype(self):\n", | |
" return self.data.dtype\n", | |
" \n", | |
" def set_creator(self,func):\n", | |
" self.creator = func\n", | |
" self.generation = func.generation + 1 # priority used to order functions in backward\n", | |
" \n", | |
" def clearglad(self):\n", | |
" self.grad = None\n", | |
" \n", | |
" def backward(self,retaion_grad=True):\n", | |
" if self.grad is None:\n", | |
" self.grad = np.ones_like(self.data)\n", | |
" #print('grad is None. create',self.grad)\n", | |
" funcs = []\n", | |
" seen = set()\n", | |
" def add_func(f):\n", | |
" if f not in seen:\n", | |
" funcs.append(f)\n", | |
" seen.add(f)\n", | |
" funcs.sort(key=lambda x : x.generation)\n", | |
" if self.creator is not None: add_func(self.creator) # a leaf variable has no creator to traverse\n", | |
" while funcs:\n", | |
" f = funcs.pop()\n", | |
" gys = [output().grad for output in f.outputs]\n", | |
" gxs = f.backward(*gys)\n", | |
" gxs = as_tuple(gxs)\n", | |
" print(f.name +\".backward(\",gys,\") = \",gxs)\n", | |
" #if f.name == \"Add\":\n", | |
" #print(f.name + '\\'s input is :',f.inputs[0].data ,\"and\",f.inputs[1].data)\n", | |
" #else:\n", | |
" #print(f.name + '\\'s input is :',f.inputs[0].data)\n", | |
" for x,gx in zip(f.inputs,gxs):\n", | |
" if x.grad is None:\n", | |
" x.grad = gx\n", | |
" else:\n", | |
" x.grad = x.grad + gx\n", | |
" if x.creator is not None: # x.creator is None means x is an input variable, so stop there\n", | |
" add_func(x.creator)\n", | |
" if not retaion_grad:\n", | |
" for y in f.outputs:\n", | |
" y().grad = None # free memory; y is a weak reference, so it must be called with () to reach the object" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Function:\n", | |
" def __call__(self,*inputs:Variable) -> Variable:\n", | |
" inputs = [as_variable(x) for x in inputs]\n", | |
" xs = [x.data for x in inputs]\n", | |
" ys = as_tuple(self.forward(*xs))\n", | |
" outputs = [Variable(as_array(y)) for y in ys]\n", | |
" if Config.enable_backprop:\n", | |
" self.generation = max([x.generation for x in inputs]) # use the largest generation among the input variables\n", | |
" for output in outputs:\n", | |
" output.set_creator(self)\n", | |
" self.inputs = inputs # keep the input variables\n", | |
" self.outputs = [weakref.ref(output) for output in outputs] # weak references\n", | |
" return outputs if len(outputs) > 1 else outputs[0] \n", | |
" \n", | |
" def forward(self,xs):\n", | |
" raise NotImplementedError()\n", | |
" \n", | |
" def backward(self,gys):\n", | |
" raise NotImplementedError()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Add(Function):\n", | |
" name = \"Add\"\n", | |
" def forward(self,x0,x1):\n", | |
" return x0 + x1\n", | |
" \n", | |
" def backward(self,gy):\n", | |
" return gy,gy\n", | |
"\n", | |
"def add(x0,x1):\n", | |
" x1 = as_array(x1)\n", | |
" return Add()(x0,x1)\n", | |
"\n", | |
"Variable.__add__ = add\n", | |
"Variable.__radd__ = add" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Sub(Function):\n", | |
" name = \"Sub\"\n", | |
" def forward(self,x0,x1):\n", | |
" return x0 - x1\n", | |
" \n", | |
" def backward(self,gy):\n", | |
" return gy, -gy\n", | |
" \n", | |
"def sub(x0,x1):\n", | |
" x1 = as_array(x1)\n", | |
" return Sub()(x0,x1)\n", | |
"\n", | |
"\n", | |
"def rsub(x0,x1): \n", | |
" x1 = as_array(x1)\n", | |
" return Sub()(x1,x0)\n", | |
"\n", | |
"Variable.__sub__ = sub\n", | |
"Variable.__rsub__ = rsub" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Neg(Function):\n", | |
" name = \"Neg\"\n", | |
" def forward(self,x):\n", | |
" return -x\n", | |
" def backward(self,gy):\n", | |
" return -gy\n", | |
"\n", | |
"def neg(x):\n", | |
" return Neg()(x)\n", | |
"\n", | |
"Variable.__neg__ = neg" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Mul(Function):\n", | |
" name = \"Mul\"\n", | |
" def forward(self,x0,x1):\n", | |
" return x0 * x1\n", | |
" \n", | |
" def backward(self,gy):\n", | |
" x0, x1 = self.inputs[0].data, self.inputs[1].data\n", | |
" return gy * x1, gy * x0\n", | |
" \n", | |
"def mul(x0,x1):\n", | |
" x1 = as_array(x1)\n", | |
" return Mul()(x0,x1)\n", | |
"\n", | |
"\n", | |
"Variable.__mul__ = mul\n", | |
"Variable.__rmul__ = mul" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Div(Function):\n", | |
" name = \"Div\"\n", | |
" def forward(self,x0,x1):\n", | |
" return x0 / x1\n", | |
" \n", | |
" def backward(self,gy):\n", | |
" x0, x1 = self.inputs[0].data, self.inputs[1].data\n", | |
" gx0 = gy / x1\n", | |
" gx1 = gy * (-x0 / x1 ** 2)\n", | |
" return gx0,gx1\n", | |
" \n", | |
" \n", | |
"def div(x0, x1):\n", | |
" x1 = as_array(x1)\n", | |
" return Div()(x0, x1)\n", | |
"\n", | |
"def rdiv(x0, x1):\n", | |
" x1 = as_array(x1)\n", | |
" return Div()(x1,x0)\n", | |
"\n", | |
"Variable.__truediv__ = div\n", | |
"Variable.__rtruediv__ = rdiv" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Pow(Function):\n", | |
" name = \"Pow\"\n", | |
" def __init__(self,c):\n", | |
" self.c = c\n", | |
" \n", | |
" def forward(self,x):\n", | |
" return x ** self.c\n", | |
" \n", | |
" def backward(self,gy):\n", | |
" x = self.inputs[0].data\n", | |
" c = self.c\n", | |
" gx = c * x ** (c - 1) * gy\n", | |
" return gx\n", | |
" \n", | |
"def pow(x,c):\n", | |
" return Pow(c)(x)\n", | |
"\n", | |
"Variable.__pow__ = pow" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class EXP(Function):\n", | |
" name = \"EXP\"\n", | |
" def forward(self,x):\n", | |
" return np.exp(x)\n", | |
" \n", | |
" def backward(self,gy):\n", | |
" return np.exp(self.inputs[0].data) * gy\n", | |
" \n", | |
" \n", | |
"def exp(x):\n", | |
" return EXP()(x)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Square(Function):\n", | |
" name = \"Square\"\n", | |
" def forward(self,x):\n", | |
" return x ** 2\n", | |
" \n", | |
" def backward(self,gy):\n", | |
" return 2 * self.inputs[0].data * gy \n", | |
"\n", | |
" \n", | |
"def square(x):\n", | |
" return Square()(x)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Add.backward( [array(1.)] ) = (array(1.), array(1.))\n", | |
"Square.backward( [array(1.)] ) = (6.0,)\n", | |
"Square.backward( [array(1.)] ) = (4.0,)\n" | |
] | |
} | |
], | |
"source": [ | |
"x = Variable(np.array(2.0))\n", | |
"y = Variable(np.array(3.0))\n", | |
"z = add(square(x),square(y))\n", | |
"z.backward()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"4.0\n", | |
"6.0\n", | |
"13.0\n" | |
] | |
} | |
], | |
"source": [ | |
"print(x.grad)\n", | |
"print(y.grad)\n", | |
"print(z.data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Add.backward( [array(1.)] ) = (array(1.), array(1.))\n", | |
"Square.backward( [array(1.)] ) = (8.0,)\n", | |
"Square.backward( [array(1.)] ) = (8.0,)\n", | |
"Square.backward( [16.0] ) = (64.0,)\n" | |
] | |
} | |
], | |
"source": [ | |
"x = Variable(np.array(2.0))\n", | |
"a = square(x)\n", | |
"y = add(square(a),square(a))\n", | |
"y.backward()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"32.0\n", | |
"64.0\n" | |
] | |
} | |
], | |
"source": [ | |
"print(y.data)\n", | |
"print(x.grad)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"with no_grad():\n", | |
" x = Variable(np.array(2.0))\n", | |
" y = square(x)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"x = Variable(np.array(2.0))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"variable(5.0)" | |
] | |
}, | |
"execution_count": 24, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x + 3" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"variable(-2.0)" | |
] | |
}, | |
"execution_count": 25, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"-x" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"variable(0.0)" | |
] | |
}, | |
"execution_count": 26, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"2 - x" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"variable(4.0)" | |
] | |
}, | |
"execution_count": 27, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"2 * x" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"variable(1.0)" | |
] | |
}, | |
"execution_count": 28, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x / 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"variable(4.0)" | |
] | |
}, | |
"execution_count": 29, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x ** 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def sphere(x,y):\n", | |
" return x ** 2 + y ** 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Add.backward( [array(1.)] ) = (array(1.), array(1.))\n", | |
"Pow.backward( [array(1.)] ) = (2.0,)\n", | |
"Pow.backward( [array(1.)] ) = (2.0,)\n" | |
] | |
} | |
], | |
"source": [ | |
"x = Variable(np.array(1.0))\n", | |
"y = Variable(np.array(1.0))\n", | |
"z = sphere(x,y)\n", | |
"z.backward()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"2.0" | |
] | |
}, | |
"execution_count": 32, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x.grad" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.2" | |
}, | |
"widgets": { | |
"application/vnd.jupyter.widget-state+json": { | |
"state": {}, | |
"version_major": 2, | |
"version_minor": 0 | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment