Skip to content

Instantly share code, notes, and snippets.

@abap34
Last active April 23, 2020 01:53
Show Gist options
  • Save abap34/a4a96d8abd052b7f86ee0503fb6bd774 to your computer and use it in GitHub Desktop.
Save abap34/a4a96d8abd052b7f86ee0503fb6bd774 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import weakref\n",
"import contextlib"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def as_array(x):\n",
" if np.isscalar(x):\n",
" return np.array(x)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def as_tuple(x):\n",
" if not isinstance(x,tuple):\n",
" return (x,)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def as_variable(obj):\n",
" if isinstance(obj,Variable):\n",
" return obj\n",
" return Variable(obj)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class Config:\n",
" enable_backprop = True"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"@contextlib.contextmanager\n",
"def using_config(name,value):\n",
" old_value = getattr(Config,name)\n",
" setattr(Config,name,value)\n",
" try:\n",
" yield\n",
" finally:\n",
" setattr(Config,name,old_value)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def no_grad():\n",
" return using_config('enable_backprop',False)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class Variable:\n",
" __array_priotity__ = 200 # 演算子の優先度\n",
" def __init__(self,data,name=None):\n",
" if data is not None:\n",
" if not isinstance(data,np.ndarray):\n",
" raise TypeError ('type:',type(data),'is not supported')\n",
" self.data = data # 変数の値の格納場所\n",
" self.name = name\n",
" self.grad = None\n",
" self.creator = None # 変数の生成元の関数を保持しておく\n",
" self.generation = 0\n",
" \n",
" def __len__(self):\n",
" return len(self.data)\n",
" \n",
" def __repr__(self):\n",
" if self.data is None:\n",
" return 'variable(None)'\n",
" p = str(self.data).replace('\\n', '\\n' + ' ' * 9)\n",
" return 'variable(' + p + ')'\n",
" \n",
"\n",
" @property\n",
" def shape(self):\n",
" return self.data.shape\n",
" \n",
" @property\n",
" def size(self):\n",
" return self.data.size\n",
" \n",
" @property\n",
" def dtype(self):\n",
" return self.data.dtype\n",
" \n",
" def set_creator(self,func):\n",
" self.creator = func\n",
" self.generation = func.generation + 1 # 優先度\n",
" \n",
" def clearglad(self):\n",
" self.grad = None\n",
" \n",
" def backward(self,retaion_grad=True):\n",
" if self.grad is None:\n",
" self.grad = np.ones_like(self.data)\n",
" #print('grad is None. create',self.grad)\n",
" funcs = []\n",
" seen = set()\n",
" def add_func(f):\n",
" if f not in seen:\n",
" funcs.append(f)\n",
" seen.add(f)\n",
" funcs.sort(key=lambda x : x.generation)\n",
" add_func(self.creator)\n",
" while funcs:\n",
" f = funcs.pop()\n",
" gys = [output().grad for output in f.outputs]\n",
" gxs = f.backward(*gys)\n",
" gxs = as_tuple(gxs)\n",
" print(f.name +\".backward(\",gys,\") = \",gxs)\n",
" #if f.name == \"Add\":\n",
" #print(f.name + '\\'s input is :',f.inputs[0].data ,\"and\",f.inputs[1].data)\n",
" #else:\n",
" #print(f.name + '\\'s input is :',f.inputs[0].data)\n",
" for x,gx in zip(f.inputs,gxs):\n",
" if x.grad is None:\n",
" x.grad = gx\n",
" else:\n",
" x.grad = x.grad + gx\n",
" if x.creator is not None: # (x.cretor is None == x が入力変数) なのでそこでストップする\n",
" add_func(x.creator)\n",
" if not retaion_grad:\n",
" for y in f.outputs:\n",
" y().grad = None # メモリを解放,yは弱参照なのでアクセスには()が必要"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"class Function:\n",
" def __call__(self,*inputs:Variable) -> Variable:\n",
" inputs = [as_variable(x) for x in inputs]\n",
" xs = [x.data for x in inputs]\n",
" ys = as_tuple(self.forward(*xs))\n",
" outputs = [Variable(as_array(y)) for y in ys]\n",
" if Config.enable_backprop:\n",
" self.generation = max([x.generation for x in inputs]) # 入力変数のうち最大の世代数を採用する\n",
" for output in outputs:\n",
" output.set_creator(self)\n",
" self.inputs = inputs # 入力変数を保持しておく\n",
" self.outputs = [weakref.ref(output) for output in outputs] # 弱参照\n",
" return outputs if len(outputs) > 1 else outputs[0] \n",
" \n",
" def forward(self,xs):\n",
" raise NotImplementedError()\n",
" \n",
" def backward(self,gys):\n",
" raise NotImplementedError()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"class Add(Function):\n",
" name = \"Add\"\n",
" def forward(self,x0,x1):\n",
" return x0 + x1\n",
" \n",
" def backward(self,gy):\n",
" return gy,gy\n",
"\n",
"def add(x0,x1):\n",
" x1 = as_array(x1)\n",
" return Add()(x0,x1)\n",
"\n",
"Variable.__add__ = add\n",
"Variable.__radd__ = add"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"class Sub(Function):\n",
" name = \"Sub\"\n",
" def forward(self,x0,x1):\n",
" return x0 - x1\n",
" \n",
" def backward(self,gy):\n",
" return gy, -gy\n",
" \n",
"def sub(x0,x1):\n",
" x1 = as_array(x1)\n",
" return Sub()(x0,x1)\n",
"\n",
"\n",
"def rsub(x0,x1): \n",
" x1 = as_array(x1)\n",
" return Sub()(x1,x0)\n",
"\n",
"Variable.__sub__ = sub\n",
"Variable.__rsub__ = rsub"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"class Neg(Function):\n",
" name = \"Neg\"\n",
" def forward(self,x):\n",
" return -x\n",
" def backward(self,gy):\n",
" return -gy\n",
"\n",
"def neg(x):\n",
" return Neg()(x)\n",
"\n",
"Variable.__neg__ = neg"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"class Mul(Function):\n",
" name = \"Mul\"\n",
" def forward(self,x0,x1):\n",
" return x0 * x1\n",
" \n",
" def backward(self,gy):\n",
" x0, x1 = self.inputs[0].data, self.inputs[1].data\n",
" return gy * x1, gy * x0\n",
" \n",
"def mul(x0,x1):\n",
" x1 = as_array(x1)\n",
" return Mul()(x0,x1)\n",
"\n",
"\n",
"Variable.__mul__ = mul\n",
"Variable.__rmul__ = mul"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"class Div(Function):\n",
" name = \"Div\"\n",
" def forward(self,x0,x1):\n",
" return x0 / x1\n",
" \n",
" def backward(self,gy):\n",
" x0, x1 = self.inputs[0].data, self.inputs[1].data\n",
" gx0 = gy / x1\n",
" gx1 = gy * (-x0 / x1 ** 2)\n",
" return gx0,gx1\n",
" \n",
" \n",
"def div(x0, x1):\n",
" x1 = as_array(x1)\n",
" return Div()(x0, x1)\n",
"\n",
"def rdiv(x0, x1):\n",
" x1 = as_array(x1)\n",
" return Div(x1,x0)\n",
"\n",
"Variable.__truediv__ = div\n",
"Variable.__rtruediv__ = rdiv"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"class Pow(Function):\n",
" name = \"Pow\"\n",
" def __init__(self,c):\n",
" self.c = c\n",
" \n",
" def forward(self,x):\n",
" return x ** self.c\n",
" \n",
" def backward(self,gy):\n",
" x = self.inputs[0].data\n",
" c = self.c\n",
" gx = c * x ** (c - 1) * gy\n",
" return gx\n",
" \n",
"def pow(x,c):\n",
" return Pow(c)(x)\n",
"\n",
"Variable.__pow__ = pow"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"class EXP(Function):\n",
" name = \"EXP\"\n",
" def forward(self,x):\n",
" return np.exp(x)\n",
" \n",
" def backward(self,gy):\n",
" return np.exp(self.inputs[0].data) * gy\n",
" \n",
" \n",
"def exp(x):\n",
" return exp()(x)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"class Square(Function):\n",
" name = \"Square\"\n",
" def forward(self,x):\n",
" return x ** 2\n",
" \n",
" def backward(self,gy):\n",
" return 2 * self.inputs[0].data * gy \n",
"\n",
" \n",
"def square(x):\n",
" return Square()(x)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Add.backward( [array(1.)] ) = (array(1.), array(1.))\n",
"Square.backward( [array(1.)] ) = (6.0,)\n",
"Square.backward( [array(1.)] ) = (4.0,)\n"
]
}
],
"source": [
"x = Variable(np.array(2.0))\n",
"y = Variable(np.array(3.0))\n",
"z = add(square(x),square(y))\n",
"z.backward()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.0\n",
"6.0\n",
"13.0\n"
]
}
],
"source": [
"print(x.grad)\n",
"print(y.grad)\n",
"print(z.data)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Add.backward( [array(1.)] ) = (array(1.), array(1.))\n",
"Square.backward( [array(1.)] ) = (8.0,)\n",
"Square.backward( [array(1.)] ) = (8.0,)\n",
"Square.backward( [16.0] ) = (64.0,)\n"
]
}
],
"source": [
"x = Variable(np.array(2.0))\n",
"a = square(x)\n",
"y = add(square(a),square(a))\n",
"y.backward()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"32.0\n",
"64.0\n"
]
}
],
"source": [
"print(y.data)\n",
"print(x.grad)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"with no_grad():\n",
" x = Variable(np.array(2.0))\n",
" y = square(x)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"x = Variable(np.array(2.0))"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"variable(5.0)"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x + 3"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"variable(-2.0)"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"-x"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"variable(0.0)"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"2 - x"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"variable(4.0)"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"2 * x"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"variable(1.0)"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x / 2"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"variable(4.0)"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x ** 2"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"def sphere(x,y):\n",
" return x ** 2 + y ** 2"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Add.backward( [array(1.)] ) = (array(1.), array(1.))\n",
"Pow.backward( [array(1.)] ) = (2.0,)\n",
"Pow.backward( [array(1.)] ) = (2.0,)\n"
]
}
],
"source": [
"x = Variable(np.array(1.0))\n",
"y = Variable(np.array(1.0))\n",
"z = sphere(x,y)\n",
"z.backward()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2.0"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x.grad"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.2"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment