Created
September 29, 2015 08:03
-
-
Save canard0328/be986d937e850d02e91c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import chainer.functions as F\n", | |
"from chainer import Variable, FunctionSet, optimizers, function" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Toy inputs: 11 samples x 3 features. Every value is strictly positive,\n", | |
"# which the np.log transform applied to X further down requires.\n", | |
"X = np.array(\n", | |
"    [\n", | |
"        [1, 3, 0.001],\n", | |
"        [2, 0.001, 5],\n", | |
"        [0.3, 1, 0.2],\n", | |
"        [4, 10, 1],\n", | |
"        [5, 9, 5],\n", | |
"        [0.1, 0.3, 0.9],\n", | |
"        [10, 1, 0.1],\n", | |
"        [2, 10, 10],\n", | |
"        [0.2, 0.5, 11],\n", | |
"        [9, 0.1, 9],\n", | |
"        [0.9, 11, 0.1],\n", | |
"    ],\n", | |
"    dtype=np.float32)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Seed numpy's global RNG once, before the first random draw, so the noise below\n", | |
"# and the per-epoch np.random.permutation shuffling during training are\n", | |
"# reproducible on Restart & Run All (presumably chainer's weight initialisation\n", | |
"# draws from it too -- TODO confirm).\n", | |
"np.random.seed(0)\n", | |
"\n", | |
"# Target on the raw (pre-log) inputs:\n", | |
"#   y = 2 * x0^2 * x1 - sqrt(x1) * sqrt(x2) + noise,  noise ~ N(0, 0.1^2)\n", | |
"noise = np.random.normal(scale=0.1, size=X.shape[0]).astype(np.float32)\n", | |
"y = 2 * X[:, 0]**2 * X[:, 1] - X[:, 1]**0.5 * X[:, 2]**0.5 + noise\n", | |
"# Column vector (n_samples, 1) to match the model's single output unit.\n", | |
"y = y.reshape(-1, 1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Fit in log-space: a hidden unit exp(W . log(x) + b) equals e^b * prod_i x_i^W_i,\n", | |
"# so the exp layer in forward() can express products of powers of the raw inputs.\n", | |
"# NOTE: this rebinds X, so re-running this cell alone would take the log twice;\n", | |
"# restart and run from the top instead.\n", | |
"X = np.asarray(np.log(X), dtype=np.float32)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# 3 log-inputs -> num_units hidden units -> 1 output; the exp activation\n", | |
"# between the two linear layers is applied in forward().\n", | |
"num_units = 2\n", | |
"model = FunctionSet(l1=F.Linear(3, num_units),\n", | |
"                   l2=F.Linear(num_units, 1))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"class Expo(function.Function):\n", | |
"    \"\"\"Element-wise exponential y = exp(x); only the CPU (numpy) path is implemented.\"\"\"\n", | |
"\n", | |
"    def forward_cpu(self, x):\n", | |
"        # x holds the input arrays; this function uses only x[0].\n", | |
"        # The output is cached because d/dx exp(x) = exp(x), so backward reuses it.\n", | |
"        self.y = np.exp(x[0])\n", | |
"        return (self.y,)\n", | |
"\n", | |
"    def backward_cpu(self, x, gy):\n", | |
"        # Chain rule: dL/dx = dL/dy * exp(x) = gy * y.\n", | |
"        return (gy[0] * self.y,)\n", | |
"\n", | |
"\n", | |
"def expo(x):\n", | |
"    \"\"\"Apply a fresh Expo function node to x.\"\"\"\n", | |
"    return Expo()(x)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def forward(x_data, y_data):\n", | |
"    \"\"\"Return the mean-squared-error loss of the model on one batch.\n", | |
"\n", | |
"    x_data are log-transformed inputs and y_data the matching targets\n", | |
"    (numpy arrays); the result is a chainer Variable to call backward() on.\n", | |
"    \"\"\"\n", | |
"    x = Variable(x_data)\n", | |
"    t = Variable(y_data)\n", | |
"    # exp of an affine map of log-inputs: products of powers of the raw inputs.\n", | |
"    hidden = expo(model.l1(x))\n", | |
"    prediction = model.l2(hidden)\n", | |
"    return F.mean_squared_error(prediction, t)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Adam with its default hyper-parameters, set up on the model's parameters.\n", | |
"optimizer = optimizers.Adam()\n", | |
"optimizer.setup(model)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"epoch:100 train mean loss=31372.7792969\n", | |
"epoch:200 train mean loss=31234.3144531\n", | |
"epoch:300 train mean loss=31152.5253906\n", | |
"epoch:400 train mean loss=31094.3867188\n", | |
"epoch:500 train mean loss=31048.0175781\n", | |
"epoch:600 train mean loss=31007.4199219\n", | |
"epoch:700 train mean loss=30967.8574219\n", | |
"epoch:800 train mean loss=30920.8417969\n", | |
"epoch:900 train mean loss=30803.2382812\n", | |
"epoch:1000 train mean loss=28303.3183594\n", | |
"epoch:1100 train mean loss=14691.9912109\n", | |
"epoch:1200 train mean loss=9743.24511719\n", | |
"epoch:1300 train mean loss=5964.33398438\n", | |
"epoch:1400 train mean loss=2970.22753906\n", | |
"epoch:1500 train mean loss=974.683654785\n", | |
"epoch:1600 train mean loss=357.173095703\n", | |
"epoch:1700 train mean loss=278.009765625\n", | |
"epoch:1800 train mean loss=253.25994873\n", | |
"epoch:1900 train mean loss=235.991958618\n", | |
"epoch:2000 train mean loss=221.130355835\n", | |
"epoch:2100 train mean loss=207.482162476\n", | |
"epoch:2200 train mean loss=194.632583618\n", | |
"epoch:2300 train mean loss=182.404312134\n", | |
"epoch:2400 train mean loss=170.71421814\n", | |
"epoch:2500 train mean loss=159.516830444\n", | |
"epoch:2600 train mean loss=148.791152954\n", | |
"epoch:2700 train mean loss=138.524398804\n", | |
"epoch:2800 train mean loss=128.708847046\n", | |
"epoch:2900 train mean loss=119.339553833\n", | |
"epoch:3000 train mean loss=110.41217804\n", | |
"epoch:3100 train mean loss=101.923614502\n", | |
"epoch:3200 train mean loss=93.8698654175\n", | |
"epoch:3300 train mean loss=86.245262146\n", | |
"epoch:3400 train mean loss=79.0436172485\n", | |
"epoch:3500 train mean loss=72.2586212158\n", | |
"epoch:3600 train mean loss=65.8811569214\n", | |
"epoch:3700 train mean loss=59.9036750793\n", | |
"epoch:3800 train mean loss=54.3155403137\n", | |
"epoch:3900 train mean loss=49.1061859131\n", | |
"epoch:4000 train mean loss=44.2637748718\n", | |
"epoch:4100 train mean loss=39.7768783569\n", | |
"epoch:4200 train mean loss=35.6319160461\n", | |
"epoch:4300 train mean loss=31.815990448\n", | |
"epoch:4400 train mean loss=28.3146286011\n", | |
"epoch:4500 train mean loss=25.1136341095\n", | |
"epoch:4600 train mean loss=22.1983680725\n", | |
"epoch:4700 train mean loss=19.5539932251\n", | |
"epoch:4800 train mean loss=17.1654510498\n", | |
"epoch:4900 train mean loss=15.0172872543\n", | |
"epoch:5000 train mean loss=13.0945644379\n", | |
"epoch:5100 train mean loss=11.3816661835\n", | |
"epoch:5200 train mean loss=9.86406230927\n", | |
"epoch:5300 train mean loss=8.52664470673\n", | |
"epoch:5400 train mean loss=7.3549823761\n", | |
"epoch:5500 train mean loss=6.3347454071\n", | |
"epoch:5600 train mean loss=5.45214128494\n", | |
"epoch:5700 train mean loss=4.69383335114\n", | |
"epoch:5800 train mean loss=4.04701328278\n", | |
"epoch:5900 train mean loss=3.4996817112\n", | |
"epoch:6000 train mean loss=3.0401597023\n", | |
"epoch:6100 train mean loss=2.6576256752\n", | |
"epoch:6200 train mean loss=2.34209609032\n", | |
"epoch:6300 train mean loss=2.08418512344\n", | |
"epoch:6400 train mean loss=1.87544584274\n", | |
"epoch:6500 train mean loss=1.70815765858\n", | |
"epoch:6600 train mean loss=1.57541358471\n", | |
"epoch:6700 train mean loss=1.47113156319\n", | |
"epoch:6800 train mean loss=1.389950037\n", | |
"epoch:6900 train mean loss=1.32728886604\n", | |
"epoch:7000 train mean loss=1.27922046185\n", | |
"epoch:7100 train mean loss=1.24244046211\n", | |
"epoch:7200 train mean loss=1.21423864365\n", | |
"epoch:7300 train mean loss=1.19241094589\n", | |
"epoch:7400 train mean loss=1.17518079281\n", | |
"epoch:7500 train mean loss=1.16118133068\n", | |
"epoch:7600 train mean loss=1.1493421793\n", | |
"epoch:7700 train mean loss=1.138854146\n", | |
"epoch:7800 train mean loss=1.12911236286\n", | |
"epoch:7900 train mean loss=1.11967742443\n", | |
"epoch:8000 train mean loss=1.11018323898\n", | |
"epoch:8100 train mean loss=1.10039067268\n", | |
"epoch:8200 train mean loss=1.09007024765\n", | |
"epoch:8300 train mean loss=1.07904744148\n", | |
"epoch:8400 train mean loss=1.06715524197\n", | |
"epoch:8500 train mean loss=1.05417370796\n", | |
"epoch:8600 train mean loss=1.03992474079\n", | |
"epoch:8700 train mean loss=1.02412843704\n", | |
"epoch:8800 train mean loss=1.006529212\n", | |
"epoch:8900 train mean loss=0.986808001995\n", | |
"epoch:9000 train mean loss=0.964643120766\n", | |
"epoch:9100 train mean loss=0.939798891544\n", | |
"epoch:9200 train mean loss=0.912258565426\n", | |
"epoch:9300 train mean loss=0.882494330406\n", | |
"epoch:9400 train mean loss=0.851729571819\n", | |
"epoch:9500 train mean loss=0.822065412998\n", | |
"epoch:9600 train mean loss=0.795871198177\n", | |
"epoch:9700 train mean loss=0.77457010746\n", | |
"epoch:9800 train mean loss=0.757784485817\n", | |
"epoch:9900 train mean loss=0.744022369385\n", | |
"epoch:10000 train mean loss=0.731878876686\n", | |
"epoch:10100 train mean loss=0.720541238785\n", | |
"epoch:10200 train mean loss=0.709630548954\n", | |
"epoch:10300 train mean loss=0.698968350887\n", | |
"epoch:10400 train mean loss=0.688422620296\n", | |
"epoch:10500 train mean loss=0.677916944027\n", | |
"epoch:10600 train mean loss=0.667325556278\n", | |
"epoch:10700 train mean loss=0.656499803066\n", | |
"epoch:10800 train mean loss=0.645299911499\n", | |
"epoch:10900 train mean loss=0.633540391922\n", | |
"epoch:11000 train mean loss=0.621025323868\n", | |
"epoch:11100 train mean loss=0.607669651508\n", | |
"epoch:11200 train mean loss=0.593504369259\n", | |
"epoch:11300 train mean loss=0.578767538071\n", | |
"epoch:11400 train mean loss=0.564007997513\n", | |
"epoch:11500 train mean loss=0.549864828587\n", | |
"epoch:11600 train mean loss=0.540891170502\n", | |
"epoch:11700 train mean loss=0.525401651859\n", | |
"epoch:11800 train mean loss=0.514214456081\n", | |
"epoch:11900 train mean loss=0.505761563778\n", | |
"epoch:12000 train mean loss=0.495530039072\n", | |
"epoch:12100 train mean loss=0.487383395433\n", | |
"epoch:12200 train mean loss=0.47992375493\n", | |
"epoch:12300 train mean loss=0.473052024841\n", | |
"epoch:12400 train mean loss=0.466601073742\n", | |
"epoch:12500 train mean loss=0.46064427495\n", | |
"epoch:12600 train mean loss=0.457746893167\n", | |
"epoch:12700 train mean loss=0.449678599834\n", | |
"epoch:12800 train mean loss=0.444524407387\n", | |
"epoch:12900 train mean loss=0.441940665245\n", | |
"epoch:13000 train mean loss=0.438120096922\n", | |
"epoch:13100 train mean loss=0.429325580597\n", | |
"epoch:13200 train mean loss=0.424200803041\n", | |
"epoch:13300 train mean loss=0.421023637056\n", | |
"epoch:13400 train mean loss=0.413643360138\n", | |
"epoch:13500 train mean loss=0.408192455769\n", | |
"epoch:13600 train mean loss=0.417691737413\n", | |
"epoch:13700 train mean loss=0.397109866142\n", | |
"epoch:13800 train mean loss=0.391566067934\n", | |
"epoch:13900 train mean loss=0.386039912701\n", | |
"epoch:14000 train mean loss=0.380708485842\n", | |
"epoch:14100 train mean loss=0.379563778639\n", | |
"epoch:14200 train mean loss=0.370126068592\n", | |
"epoch:14300 train mean loss=0.365221321583\n", | |
"epoch:14400 train mean loss=0.360549241304\n", | |
"epoch:14500 train mean loss=0.356159001589\n", | |
"epoch:14600 train mean loss=0.352070212364\n", | |
"epoch:14700 train mean loss=0.348999798298\n", | |
"epoch:14800 train mean loss=0.356187701225\n", | |
"epoch:14900 train mean loss=0.341304928064\n", | |
"epoch:15000 train mean loss=0.338270157576\n", | |
"epoch:15100 train mean loss=0.335548371077\n", | |
"epoch:15200 train mean loss=0.337527573109\n", | |
"epoch:15300 train mean loss=0.330541372299\n", | |
"epoch:15400 train mean loss=0.328361898661\n", | |
"epoch:15500 train mean loss=0.329148501158\n", | |
"epoch:15600 train mean loss=0.324667870998\n", | |
"epoch:15700 train mean loss=0.32468059659\n", | |
"epoch:15800 train mean loss=0.321730703115\n", | |
"epoch:15900 train mean loss=0.320273518562\n", | |
"epoch:16000 train mean loss=0.319922745228\n", | |
"epoch:16100 train mean loss=0.318046182394\n", | |
"epoch:16200 train mean loss=0.317025691271\n", | |
"epoch:16300 train mean loss=0.316511541605\n", | |
"epoch:16400 train mean loss=0.31578925252\n", | |
"epoch:16500 train mean loss=0.322080284357\n", | |
"epoch:16600 train mean loss=0.313970118761\n", | |
"epoch:16700 train mean loss=0.313417315483\n", | |
"epoch:16800 train mean loss=0.313607424498\n", | |
"epoch:16900 train mean loss=0.312390357256\n", | |
"epoch:17000 train mean loss=0.311956524849\n", | |
"epoch:17100 train mean loss=0.311582326889\n", | |
"epoch:17200 train mean loss=0.3112231493\n", | |
"epoch:17300 train mean loss=0.310920864344\n", | |
"epoch:17400 train mean loss=0.310798853636\n", | |
"epoch:17500 train mean loss=0.310363650322\n", | |
"epoch:17600 train mean loss=0.310130536556\n", | |
"epoch:17700 train mean loss=0.30990344286\n", | |
"epoch:17800 train mean loss=0.30972135067\n", | |
"epoch:17900 train mean loss=0.309542268515\n", | |
"epoch:18000 train mean loss=0.30946713686\n", | |
"epoch:18100 train mean loss=0.309371173382\n", | |
"epoch:18200 train mean loss=0.30907946825\n", | |
"epoch:18300 train mean loss=0.309720695019\n", | |
"epoch:18400 train mean loss=0.308821648359\n", | |
"epoch:18500 train mean loss=0.308723211288\n", | |
"epoch:18600 train mean loss=0.310669630766\n", | |
"epoch:18700 train mean loss=0.3085103333\n", | |
"epoch:18800 train mean loss=0.309707671404\n", | |
"epoch:18900 train mean loss=0.3083370924\n", | |
"epoch:19000 train mean loss=0.308256536722\n", | |
"epoch:19100 train mean loss=0.308253377676\n", | |
"epoch:19200 train mean loss=0.308115869761\n", | |
"epoch:19300 train mean loss=0.308184027672\n", | |
"epoch:19400 train mean loss=0.308013409376\n", | |
"epoch:19500 train mean loss=0.30789411068\n", | |
"epoch:19600 train mean loss=0.307871729136\n", | |
"epoch:19700 train mean loss=0.308270901442\n", | |
"epoch:19800 train mean loss=0.30770689249\n", | |
"epoch:19900 train mean loss=0.307707577944\n", | |
"epoch:20000 train mean loss=0.327270388603\n", | |
"epoch:20100 train mean loss=0.307528465986\n", | |
"epoch:20200 train mean loss=0.307558685541\n", | |
"epoch:20300 train mean loss=0.307413429022\n", | |
"epoch:20400 train mean loss=0.30802705884\n", | |
"epoch:20500 train mean loss=0.3073117733\n", | |
"epoch:20600 train mean loss=0.307250380516\n", | |
"epoch:20700 train mean loss=0.307330220938\n", | |
"epoch:20800 train mean loss=0.307140320539\n", | |
"epoch:20900 train mean loss=0.313841462135\n", | |
"epoch:21000 train mean loss=0.307033151388\n", | |
"epoch:21100 train mean loss=0.30697080493\n", | |
"epoch:21200 train mean loss=0.30835917592\n", | |
"epoch:21300 train mean loss=0.311429440975\n", | |
"epoch:21400 train mean loss=0.306810289621\n", | |
"epoch:21500 train mean loss=0.306930065155\n", | |
"epoch:21600 train mean loss=0.306723326445\n", | |
"epoch:21700 train mean loss=0.306637555361\n", | |
"epoch:21800 train mean loss=0.307242542505\n", | |
"epoch:21900 train mean loss=0.308094322681\n", | |
"epoch:22000 train mean loss=0.306472569704\n", | |
"epoch:22100 train mean loss=0.306425184011\n", | |
"epoch:22200 train mean loss=0.306489676237\n", | |
"epoch:22300 train mean loss=0.306289583445\n", | |
"epoch:22400 train mean loss=0.306257933378\n", | |
"epoch:22500 train mean loss=0.306520462036\n", | |
"epoch:22600 train mean loss=0.306117296219\n", | |
"epoch:22700 train mean loss=0.307666122913\n", | |
"epoch:22800 train mean loss=0.305995106697\n", | |
"epoch:22900 train mean loss=0.306275159121\n", | |
"epoch:23000 train mean loss=0.305876225233\n", | |
"epoch:23100 train mean loss=0.305820375681\n", | |
"epoch:23200 train mean loss=0.306178838015\n", | |
"epoch:23300 train mean loss=0.30633610487\n", | |
"epoch:23400 train mean loss=0.305618166924\n", | |
"epoch:23500 train mean loss=0.327315598726\n", | |
"epoch:23600 train mean loss=0.305489331484\n", | |
"epoch:23700 train mean loss=0.305429697037\n", | |
"epoch:23800 train mean loss=0.305358588696\n", | |
"epoch:23900 train mean loss=0.305315464735\n", | |
"epoch:24000 train mean loss=0.306744635105\n", | |
"epoch:24100 train mean loss=0.305161654949\n", | |
"epoch:24200 train mean loss=0.306147366762\n", | |
"epoch:24300 train mean loss=0.305026143789\n", | |
"epoch:24400 train mean loss=0.305058479309\n", | |
"epoch:24500 train mean loss=0.315707623959\n", | |
"epoch:24600 train mean loss=0.304824322462\n", | |
"epoch:24700 train mean loss=0.304865241051\n", | |
"epoch:24800 train mean loss=0.304689764977\n", | |
"epoch:24900 train mean loss=0.306201130152\n", | |
"epoch:25000 train mean loss=0.304550260305\n", | |
"epoch:25100 train mean loss=0.304488837719\n", | |
"epoch:25200 train mean loss=0.311301916838\n", | |
"epoch:25300 train mean loss=0.304344534874\n", | |
"epoch:25400 train mean loss=0.304272085428\n", | |
"epoch:25500 train mean loss=0.304375559092\n", | |
"epoch:25600 train mean loss=0.304685622454\n", | |
"epoch:25700 train mean loss=0.304057300091\n", | |
"epoch:25800 train mean loss=0.30401173234\n", | |
"epoch:25900 train mean loss=0.303913414478\n", | |
"epoch:26000 train mean loss=0.30386030674\n", | |
"epoch:26100 train mean loss=0.314157217741\n", | |
"epoch:26200 train mean loss=0.303698331118\n", | |
"epoch:26300 train mean loss=0.303641915321\n", | |
"epoch:26400 train mean loss=0.303722947836\n", | |
"epoch:26500 train mean loss=0.303478807211\n", | |
"epoch:26600 train mean loss=0.303917139769\n", | |
"epoch:26700 train mean loss=0.303332239389\n", | |
"epoch:26800 train mean loss=0.303517878056\n", | |
"epoch:26900 train mean loss=0.303733468056\n", | |
"epoch:27000 train mean loss=0.30311319232\n", | |
"epoch:27100 train mean loss=0.303036987782\n", | |
"epoch:27200 train mean loss=0.303193837404\n", | |
"epoch:27300 train mean loss=0.30289003253\n", | |
"epoch:27400 train mean loss=0.302912324667\n", | |
"epoch:27500 train mean loss=0.302889496088\n", | |
"epoch:27600 train mean loss=0.302685052156\n", | |
"epoch:27700 train mean loss=0.302594870329\n", | |
"epoch:27800 train mean loss=0.302679657936\n", | |
"epoch:27900 train mean loss=0.302458643913\n", | |
"epoch:28000 train mean loss=0.302377492189\n", | |
"epoch:28100 train mean loss=0.302309781313\n", | |
"epoch:28200 train mean loss=0.305227398872\n", | |
"epoch:28300 train mean loss=0.302150338888\n", | |
"epoch:28400 train mean loss=0.302149742842\n", | |
"epoch:28500 train mean loss=0.302710354328\n", | |
"epoch:28600 train mean loss=0.302416056395\n", | |
"epoch:28700 train mean loss=0.305107146502\n", | |
"epoch:28800 train mean loss=0.313704073429\n", | |
"epoch:28900 train mean loss=0.301713943481\n", | |
"epoch:29000 train mean loss=0.30163320899\n", | |
"epoch:29100 train mean loss=0.30195248127\n", | |
"epoch:29200 train mean loss=0.301480799913\n", | |
"epoch:29300 train mean loss=0.301409363747\n", | |
"epoch:29400 train mean loss=0.311064451933\n", | |
"epoch:29500 train mean loss=0.301260054111\n", | |
"epoch:29600 train mean loss=0.301843434572\n", | |
"epoch:29700 train mean loss=0.309216678143\n", | |
"epoch:29800 train mean loss=0.301059812307\n", | |
"epoch:29900 train mean loss=0.300995200872\n", | |
"epoch:30000 train mean loss=0.302341520786\n", | |
"epoch:30100 train mean loss=0.300906777382\n", | |
"epoch:30200 train mean loss=0.300751328468\n", | |
"epoch:30300 train mean loss=0.300980657339\n", | |
"epoch:30400 train mean loss=0.300626903772\n", | |
"epoch:30500 train mean loss=0.300533741713\n", | |
"epoch:30600 train mean loss=0.301971584558\n", | |
"epoch:30700 train mean loss=0.30038946867\n", | |
"epoch:30800 train mean loss=0.300339192152\n", | |
"epoch:30900 train mean loss=0.302649110556\n", | |
"epoch:31000 train mean loss=0.300175487995\n", | |
"epoch:31100 train mean loss=0.300122857094\n", | |
"epoch:31200 train mean loss=0.310733556747\n", | |
"epoch:31300 train mean loss=0.299963474274\n", | |
"epoch:31400 train mean loss=0.299942463636\n", | |
"epoch:31500 train mean loss=0.303368359804\n", | |
"epoch:31600 train mean loss=0.299751162529\n", | |
"epoch:31700 train mean loss=0.300648093224\n", | |
"epoch:31800 train mean loss=0.299620181322\n", | |
"epoch:31900 train mean loss=0.299550026655\n", | |
"epoch:32000 train mean loss=0.299581199884\n", | |
"epoch:32100 train mean loss=0.299404680729\n", | |
"epoch:32200 train mean loss=0.300023168325\n", | |
"epoch:32300 train mean loss=0.299285322428\n", | |
"epoch:32400 train mean loss=0.299195408821\n", | |
"epoch:32500 train mean loss=0.300106972456\n", | |
"epoch:32600 train mean loss=0.299063831568\n", | |
"epoch:32700 train mean loss=0.298996120691\n", | |
"epoch:32800 train mean loss=0.301807135344\n", | |
"epoch:32900 train mean loss=0.29896324873\n", | |
"epoch:33000 train mean loss=0.298790425062\n", | |
"epoch:33100 train mean loss=0.298866122961\n", | |
"epoch:33200 train mean loss=0.298656284809\n", | |
"epoch:33300 train mean loss=0.298620909452\n", | |
"epoch:33400 train mean loss=0.301293730736\n", | |
"epoch:33500 train mean loss=0.298484444618\n", | |
"epoch:33600 train mean loss=0.298389077187\n", | |
"epoch:33700 train mean loss=0.298373103142\n", | |
"epoch:33800 train mean loss=0.298268288374\n", | |
"epoch:33900 train mean loss=0.298193216324\n", | |
"epoch:34000 train mean loss=0.298143565655\n", | |
"epoch:34100 train mean loss=0.307447820902\n", | |
"epoch:34200 train mean loss=0.297997444868\n", | |
"epoch:34300 train mean loss=0.297934323549\n", | |
"epoch:34400 train mean loss=0.298451930285\n", | |
"epoch:34500 train mean loss=0.297804027796\n", | |
"epoch:34600 train mean loss=0.297748446465\n", | |
"epoch:34700 train mean loss=0.297688066959\n", | |
"epoch:34800 train mean loss=0.308348745108\n", | |
"epoch:34900 train mean loss=0.297550559044\n", | |
"epoch:35000 train mean loss=0.297623038292\n", | |
"epoch:35100 train mean loss=0.299623310566\n", | |
"epoch:35200 train mean loss=0.297395825386\n", | |
"epoch:35300 train mean loss=0.297301203012\n", | |
"epoch:35400 train mean loss=0.297259509563\n", | |
"epoch:35500 train mean loss=0.298073500395\n", | |
"epoch:35600 train mean loss=0.297115474939\n", | |
"epoch:35700 train mean loss=0.297309368849\n", | |
"epoch:35800 train mean loss=0.297997355461\n", | |
"epoch:35900 train mean loss=0.30123513937\n", | |
"epoch:36000 train mean loss=0.302880972624\n", | |
"epoch:36100 train mean loss=0.296812802553\n", | |
"epoch:36200 train mean loss=0.29674872756\n", | |
"epoch:36300 train mean loss=0.307402461767\n", | |
"epoch:36400 train mean loss=0.296707242727\n", | |
"epoch:36500 train mean loss=0.296569257975\n", | |
"epoch:36600 train mean loss=0.296524018049\n", | |
"epoch:36700 train mean loss=0.297457069159\n", | |
"epoch:36800 train mean loss=0.296395897865\n", | |
"epoch:36900 train mean loss=0.296792954206\n", | |
"epoch:37000 train mean loss=0.298014909029\n", | |
"epoch:37100 train mean loss=0.296293199062\n", | |
"epoch:37200 train mean loss=0.296202689409\n", | |
"epoch:37300 train mean loss=0.296123474836\n", | |
"epoch:37400 train mean loss=0.296047776937\n", | |
"epoch:37500 train mean loss=0.29599031806\n", | |
"epoch:37600 train mean loss=0.295957833529\n", | |
"epoch:37700 train mean loss=0.307167738676\n", | |
"epoch:37800 train mean loss=0.295820564032\n", | |
"epoch:37900 train mean loss=0.295789271593\n", | |
"epoch:38000 train mean loss=0.295714735985\n", | |
"epoch:38100 train mean loss=0.295713573694\n", | |
"epoch:38200 train mean loss=0.296147704124\n", | |
"epoch:38300 train mean loss=0.295808196068\n", | |
"epoch:38400 train mean loss=0.295494288206\n", | |
"epoch:38500 train mean loss=0.295455008745\n", | |
"epoch:38600 train mean loss=0.295552790165\n", | |
"epoch:38700 train mean loss=0.295327395201\n", | |
"epoch:38800 train mean loss=0.317661732435\n", | |
"epoch:38900 train mean loss=0.295218706131\n", | |
"epoch:39000 train mean loss=0.299768358469\n", | |
"epoch:39100 train mean loss=0.295127242804\n", | |
"epoch:39200 train mean loss=0.295069098473\n", | |
"epoch:39300 train mean loss=0.295096039772\n", | |
"epoch:39400 train mean loss=0.29495254159\n", | |
"epoch:39500 train mean loss=0.294953644276\n", | |
"epoch:39600 train mean loss=0.294852524996\n", | |
"epoch:39700 train mean loss=0.299119263887\n", | |
"epoch:39800 train mean loss=0.294748276472\n", | |
"epoch:39900 train mean loss=0.294712215662\n", | |
"epoch:40000 train mean loss=0.294731557369\n", | |
"epoch:40100 train mean loss=0.294594615698\n", | |
"epoch:40200 train mean loss=0.294544219971\n", | |
"epoch:40300 train mean loss=0.314175754786\n", | |
"epoch:40400 train mean loss=0.294441640377\n", | |
"epoch:40500 train mean loss=0.294942080975\n", | |
"epoch:40600 train mean loss=0.295029520988\n", | |
"epoch:40700 train mean loss=0.294298529625\n", | |
"epoch:40800 train mean loss=0.294265091419\n", | |
"epoch:40900 train mean loss=0.294349759817\n", | |
"epoch:41000 train mean loss=0.306304693222\n", | |
"epoch:41100 train mean loss=0.294096827507\n", | |
"epoch:41200 train mean loss=0.294087916613\n", | |
"epoch:41300 train mean loss=0.300193369389\n", | |
"epoch:41400 train mean loss=0.293954223394\n", | |
"epoch:41500 train mean loss=0.293906062841\n", | |
"epoch:41600 train mean loss=0.29413741827\n", | |
"epoch:41700 train mean loss=0.295879274607\n", | |
"epoch:41800 train mean loss=0.293755233288\n", | |
"epoch:41900 train mean loss=0.293728470802\n", | |
"epoch:42000 train mean loss=0.298869729042\n", | |
"epoch:42100 train mean loss=0.293612837791\n", | |
"epoch:42200 train mean loss=0.293574541807\n", | |
"epoch:42300 train mean loss=0.294453859329\n", | |
"epoch:42400 train mean loss=0.293473601341\n", | |
"epoch:42500 train mean loss=0.297279566526\n", | |
"epoch:42600 train mean loss=0.293381243944\n", | |
"epoch:42700 train mean loss=0.294555008411\n", | |
"epoch:42800 train mean loss=0.293386310339\n", | |
"epoch:42900 train mean loss=0.293245047331\n", | |
"epoch:43000 train mean loss=0.29320409894\n", | |
"epoch:43100 train mean loss=0.293887257576\n", | |
"epoch:43200 train mean loss=0.296142578125\n", | |
"epoch:43300 train mean loss=0.293074965477\n", | |
"epoch:43400 train mean loss=0.293064892292\n", | |
"epoch:43500 train mean loss=0.293043881655\n", | |
"epoch:43600 train mean loss=0.292933791876\n", | |
"epoch:43700 train mean loss=0.293460249901\n", | |
"epoch:43800 train mean loss=0.309539079666\n", | |
"epoch:43900 train mean loss=0.292802125216\n", | |
"epoch:44000 train mean loss=0.293288260698\n", | |
"epoch:44100 train mean loss=0.292713463306\n", | |
"epoch:44200 train mean loss=0.293365508318\n", | |
"epoch:44300 train mean loss=0.292625814676\n", | |
"epoch:44400 train mean loss=0.29267424345\n", | |
"epoch:44500 train mean loss=0.293920695782\n", | |
"epoch:44600 train mean loss=0.292497456074\n", | |
"epoch:44700 train mean loss=0.292532473803\n", | |
"epoch:44800 train mean loss=0.292414575815\n", | |
"epoch:44900 train mean loss=0.292801588774\n", | |
"epoch:45000 train mean loss=0.292331457138\n", | |
"epoch:45100 train mean loss=0.292316317558\n", | |
"epoch:45200 train mean loss=0.292837500572\n", | |
"epoch:45300 train mean loss=0.292220532894\n", | |
"epoch:45400 train mean loss=0.292164713144\n", | |
"epoch:45500 train mean loss=0.29532968998\n", | |
"epoch:45600 train mean loss=0.292092591524\n", | |
"epoch:45700 train mean loss=0.292056530714\n", | |
"epoch:45800 train mean loss=0.292046755552\n", | |
"epoch:45900 train mean loss=0.291977256536\n", | |
"epoch:46000 train mean loss=0.291925936937\n", | |
"epoch:46100 train mean loss=0.293803840876\n", | |
"epoch:46200 train mean loss=0.291845291853\n", | |
"epoch:46300 train mean loss=0.300246447325\n", | |
"epoch:46400 train mean loss=0.291764855385\n", | |
"epoch:46500 train mean loss=0.291726797819\n", | |
"epoch:46600 train mean loss=0.291685223579\n", | |
"epoch:46700 train mean loss=0.291653990746\n", | |
"epoch:46800 train mean loss=0.291615366936\n", | |
"epoch:46900 train mean loss=0.291570425034\n", | |
"epoch:47000 train mean loss=0.292142659426\n", | |
"epoch:47100 train mean loss=0.303349345922\n", | |
"epoch:47200 train mean loss=0.291460245848\n", | |
"epoch:47300 train mean loss=0.29142421484\n", | |
"epoch:47400 train mean loss=0.313145369291\n", | |
"epoch:47500 train mean loss=0.291339725256\n", | |
"epoch:47600 train mean loss=0.291306138039\n", | |
"epoch:47700 train mean loss=0.291265368462\n", | |
"epoch:47800 train mean loss=0.292409271002\n", | |
"epoch:47900 train mean loss=0.291192978621\n", | |
"epoch:48000 train mean loss=0.291167706251\n", | |
"epoch:48100 train mean loss=0.293959289789\n", | |
"epoch:48200 train mean loss=0.291084080935\n", | |
"epoch:48300 train mean loss=0.291054308414\n", | |
"epoch:48400 train mean loss=0.302320361137\n", | |
"epoch:48500 train mean loss=0.290974408388\n", | |
"epoch:48600 train mean loss=0.291046977043\n", | |
"epoch:48700 train mean loss=0.290899753571\n", | |
"epoch:48800 train mean loss=0.29087126255\n", | |
"epoch:48900 train mean loss=0.301152408123\n", | |
"epoch:49000 train mean loss=0.290869742632\n", | |
"epoch:49100 train mean loss=0.290761172771\n", | |
"epoch:49200 train mean loss=0.291226267815\n", | |
"epoch:49300 train mean loss=0.290697187185\n", | |
"epoch:49400 train mean loss=0.290658682585\n", | |
"epoch:49500 train mean loss=0.290621548891\n", | |
"epoch:49600 train mean loss=0.290820926428\n", | |
"epoch:49700 train mean loss=0.293107151985\n", | |
"epoch:49800 train mean loss=0.290510058403\n", | |
"epoch:49900 train mean loss=0.29054531455\n", | |
"epoch:50000 train mean loss=0.292022317648\n" | |
] | |
} | |
], | |
"source": [ | |
"# Mini-batch size for stochastic gradient descent (11 = the whole data set here).\n", | |
"batchsize = 11\n", | |
"# Number of passes over the training data.\n", | |
"n_epoch = 50000\n", | |
"# L2 weight-decay coefficient applied before every parameter update.\n", | |
"decay_rate = 0.01\n", | |
"# Print the mean training loss every log_interval epochs.\n", | |
"log_interval = 100\n", | |
"N = X.shape[0]\n", | |
"\n", | |
"for epoch in xrange(1, n_epoch + 1):\n", | |
"    # Visit the N samples in a fresh random order every epoch.\n", | |
"    perm = np.random.permutation(N)\n", | |
"    sum_loss = 0\n", | |
"    for i in xrange(0, N, batchsize):\n", | |
"        batch_idx = perm[i:i + batchsize]\n", | |
"        x_batch = X[batch_idx]\n", | |
"        y_batch = y[batch_idx]\n", | |
"\n", | |
"        # Reset gradients, run the forward pass to get the loss, then backpropagate.\n", | |
"        optimizer.zero_grads()\n", | |
"        loss = forward(x_batch, y_batch)\n", | |
"        loss.backward()\n", | |
"        optimizer.weight_decay(decay_rate)\n", | |
"        optimizer.update()\n", | |
"        # loss is the mean over this batch; weight it by the batch's real size,\n", | |
"        # since the last batch is smaller whenever batchsize does not divide N.\n", | |
"        sum_loss += loss.data * len(batch_idx)\n", | |
"\n", | |
"    # Report the mean training loss (this regression has no accuracy metric).\n", | |
"    if epoch % log_interval == 0:\n", | |
"        print 'epoch:{} train mean loss={}'.format(epoch, sum_loss / N)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[[ 6.07095146]\n", | |
" [ -0.90871316]\n", | |
" [ -0.47295213]\n", | |
" [ 316.75616455]\n", | |
" [ 443.15646362]\n", | |
" [ -0.95780635]\n", | |
" [ 199.78878784]\n", | |
" [ 70.07402039]\n", | |
" [ -0.80237651]\n", | |
" [ 15.18830585]\n", | |
" [ 16.46421432]]\n", | |
"[[ 5.80544519e+00]\n", | |
" [ -2.21579187e-02]\n", | |
" [ -3.42641264e-01]\n", | |
" [ 3.16766022e+02]\n", | |
" [ 4.43226501e+02]\n", | |
" [ -6.06474221e-01]\n", | |
" [ 1.99794052e+02]\n", | |
" [ 6.99967651e+01]\n", | |
" [ -2.24294662e+00]\n", | |
" [ 1.52209854e+01]\n", | |
" [ 1.68022785e+01]]\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"array(0.29077136516571045, dtype=float32)" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Fit on the training data itself (there is no held-out set): print the\n", | |
"# predictions next to the targets, then display the mean squared error.\n", | |
"# All rows are used, so this stays correct if the number of samples changes.\n", | |
"x_test = Variable(X.copy())\n", | |
"t = Variable(y.copy())\n", | |
"# Same computation as forward(): exp hidden layer, then the linear output layer.\n", | |
"h1 = expo(model.l1(x_test))\n", | |
"y_test = model.l2(h1)\n", | |
"print y_test.data\n", | |
"print t.data\n", | |
"F.mean_squared_error(y_test, t).data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 2.24651504e+00, 1.13109624e+00, -2.16289815e-02],\n", | |
" [ 1.10743952e+00, 4.99774486e-01, 8.01726012e-04]], dtype=float32)" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# First-layer weights, shape (num_units, 3). Because the inputs are log(x) and the\n", | |
"# hidden activation is exp, W[j, i] acts as the exponent of input i in hidden unit j.\n", | |
"model.l1.W" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 0.59255338, 0.89250612]], dtype=float32)" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Output-layer weights, shape (1, num_units): the coefficient applied to each\n", | |
"# hidden unit's product-of-powers term in the prediction.\n", | |
"model.l2.W" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.10" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment