Created
October 30, 2017 13:24
-
-
Save regonn/f7db4fe43559110f5e3b2fd65e70cc7f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# AWS の Deep learning 用 AMI で MXNet を動かす\n", | |
"\n", | |
"## モチベーション\n", | |
"- ローカルのパソコンだと処理速度に限界を感じてきた\n", | |
"- かといって、自作PCを作るモチベーションはあまりない\n", | |
"- AWSなどのクラウドサーバーのインスタンスでできると良さそう\n", | |
"- Amazon は MXNet という機械学習ライブラリを公式にサポートしており、AWSとの相性は良い\n", | |
" - [MXNet とは \\| AWS](https://aws.amazon.com/jp/mxnet/)\n", | |
" \n", | |
"## 今回参考にした記事とコード\n", | |
"- [MXNetをAmazon Deep Learning AMIとスポットインスタンスのGPUで試してみる | Developers\\.IO](https://dev.classmethod.jp/cloud/aws/mxnet-on-amazon-deep-learning-ami-with-gpu-instance/)\n", | |
"- [MXNetのmnistチュートリアル dmlc/mxnet\\-notebooks](https://github.com/dmlc/mxnet-notebooks/blob/master/python/tutorials/mnist.ipynb)\n", | |
"\n", | |
"## インスタンスを構築する\n", | |
"AWS上で行う\n", | |
"\n", | |
"## sshでログイン\n", | |
"\n", | |
"```\n", | |
"$ ssh -i your-key.pem <ユーザー名>@<インスタンスのパブリックDNS> -L 8888:localhost:8888\n", | |
"```\n", | |
"\n", | |
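"`-L 8888:localhost:8888` オプションを付けることで、ローカルの8888番ポートがインスタンス側の8888番ポートに転送され、インスタンス上のサービス(Jupyter Notebook など)にローカルのブラウザから `http://localhost:8888` でアクセスできるようになる。" | |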
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import os\n", | |
"import urllib\n", | |
"import gzip\n", | |
"import struct\n", | |
"\n", | |
"# Download the MNIST data files\n", | |
"def download_data(url, force_download=True):\n", | |
"    \"\"\"Fetch ``url`` into the current working directory and return the local file name.\n", | |
"\n", | |
"    Args:\n", | |
"        url: Address of the file; the text after the last slash becomes the file name.\n", | |
"        force_download: When True (the default) the file is always fetched again;\n", | |
"            when False an already existing local copy is reused.\n", | |
"\n", | |
"    Returns:\n", | |
"        The local file name, relative to the current working directory.\n", | |
"    \"\"\"\n", | |
"    # urlretrieve lives in urllib.request on Python 3 and directly in urllib on Python 2.\n", | |
"    try:\n", | |
"        from urllib.request import urlretrieve\n", | |
"    except ImportError:\n", | |
"        from urllib import urlretrieve\n", | |
"    fname = url.split(\"/\")[-1]\n", | |
"    if force_download or not os.path.exists(fname):\n", | |
"        urlretrieve(url, fname)\n", | |
"    return fname\n", | |
"\n", | |
"def read_data(label_url, image_url):\n", | |
"    \"\"\"Download and decode one gzip-compressed label file and its matching image file.\n", | |
"\n", | |
"    Args:\n", | |
"        label_url: URL of the gzip-compressed label file.\n", | |
"        image_url: URL of the gzip-compressed image file.\n", | |
"\n", | |
"    Returns:\n", | |
"        A ``(label, image)`` tuple: ``label`` is a 1-D int8 array and ``image`` is a\n", | |
"        uint8 array of shape ``(len(label), rows, cols)``.\n", | |
"    \"\"\"\n", | |
"    with gzip.open(download_data(label_url), 'rb') as flbl:\n", | |
"        # 8-byte big-endian header: magic number and item count (neither is used below).\n", | |
"        _magic, _num = struct.unpack(\">II\", flbl.read(8))\n", | |
"        # np.fromstring is deprecated for binary input. np.frombuffer yields a read-only\n", | |
"        # view of the bytes, so copy to keep the result writable as before.\n", | |
"        label = np.frombuffer(flbl.read(), dtype=np.int8).copy()\n", | |
"    with gzip.open(download_data(image_url), 'rb') as fimg:\n", | |
"        # 16-byte big-endian header: magic number, item count, rows and columns per image.\n", | |
"        _magic, _num, rows, cols = struct.unpack(\">IIII\", fimg.read(16))\n", | |
"        pixels = np.frombuffer(fimg.read(), dtype=np.uint8)\n", | |
"        image = pixels.reshape(len(label), rows, cols).copy()\n", | |
"    return (label, image)\n", | |
"\n", | |
"# Base URL of the MNIST files; the file names below are appended to it.\n", | |
"path = 'http://yann.lecun.com/exdb/mnist/'\n", | |
"\n", | |
"# Training labels and images\n", | |
"train_lbl, train_img = read_data(\n", | |
"    label_url=path + 'train-labels-idx1-ubyte.gz',\n", | |
"    image_url=path + 'train-images-idx3-ubyte.gz')\n", | |
"\n", | |
"# Validation labels and images\n", | |
"val_lbl, val_img = read_data(\n", | |
"    label_url=path + 't10k-labels-idx1-ubyte.gz',\n", | |
"    image_url=path + 't10k-images-idx3-ubyte.gz')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import mxnet as mx\n", | |
"\n", | |
"def to4d(img):\n", | |
"    \"\"\"Reshape images to (N, 1, 28, 28), cast to float32 and divide the values by 255.\"\"\"\n", | |
"    batch = img.reshape(img.shape[0], 1, 28, 28)\n", | |
"    return batch.astype(np.float32) / 255\n", | |
"\n", | |
"# Mini-batch size used by both iterators (also referenced by later cells).\n", | |
"batch_size = 100\n", | |
"\n", | |
"# Training batches are shuffled; validation batches keep the original order.\n", | |
"train_iter = mx.io.NDArrayIter(\n", | |
"    data=to4d(train_img), label=train_lbl, batch_size=batch_size, shuffle=True)\n", | |
"val_iter = mx.io.NDArrayIter(\n", | |
"    data=to4d(val_img), label=val_lbl, batch_size=batch_size)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## まずは CPU で MLP(Multi Layer Perceptron:多層パーセプトロン)を使った学習" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/svg+xml": [ | |
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", | |
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", | |
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", | |
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\n", | |
" -->\n", | |
"<!-- Title: plot Pages: 1 -->\n", | |
"<svg width=\"214pt\" height=\"829pt\"\n", | |
" viewBox=\"0.00 0.00 214.00 829.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 825)\">\n", | |
"<title>plot</title>\n", | |
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-825 210,-825 210,4 -4,4\"/>\n", | |
"<!-- data -->\n", | |
"<g id=\"node1\" class=\"node\"><title>data</title>\n", | |
"<ellipse fill=\"#8dd3c7\" stroke=\"black\" cx=\"47\" cy=\"-29\" rx=\"47\" ry=\"29\"/>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-25.3\" font-family=\"Times,serif\" font-size=\"14.00\">data</text>\n", | |
"</g>\n", | |
"<!-- flatten0 -->\n", | |
"<g id=\"node2\" class=\"node\"><title>flatten0</title>\n", | |
"<polygon fill=\"#fdb462\" stroke=\"black\" points=\"94,-167 -7.10543e-15,-167 -7.10543e-15,-109 94,-109 94,-167\"/>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-134.3\" font-family=\"Times,serif\" font-size=\"14.00\">flatten0</text>\n", | |
"</g>\n", | |
"<!-- flatten0->data -->\n", | |
"<g id=\"edge1\" class=\"edge\"><title>flatten0->data</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M47,-98.5824C47,-85.2841 47,-70.632 47,-58.2967\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-108.887 42.5001,-98.887 47,-103.887 47.0001,-98.887 47.0001,-98.887 47.0001,-98.887 47,-103.887 51.5001,-98.8871 47,-108.887 47,-108.887\"/>\n", | |
"<text text-anchor=\"middle\" x=\"71\" y=\"-79.8\" font-family=\"Times,serif\" font-size=\"14.00\">1x28x28</text>\n", | |
"</g>\n", | |
"<!-- fc1 -->\n", | |
"<g id=\"node3\" class=\"node\"><title>fc1</title>\n", | |
"<polygon fill=\"#fb8072\" stroke=\"black\" points=\"94,-276 -7.10543e-15,-276 -7.10543e-15,-218 94,-218 94,-276\"/>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-250.8\" font-family=\"Times,serif\" font-size=\"14.00\">FullyConnected</text>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-235.8\" font-family=\"Times,serif\" font-size=\"14.00\">128</text>\n", | |
"</g>\n", | |
"<!-- fc1->flatten0 -->\n", | |
"<g id=\"edge2\" class=\"edge\"><title>fc1->flatten0</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M47,-207.582C47,-194.284 47,-179.632 47,-167.297\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-217.887 42.5001,-207.887 47,-212.887 47.0001,-207.887 47.0001,-207.887 47.0001,-207.887 47,-212.887 51.5001,-207.887 47,-217.887 47,-217.887\"/>\n", | |
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-188.8\" font-family=\"Times,serif\" font-size=\"14.00\">784</text>\n", | |
"</g>\n", | |
"<!-- relu1 -->\n", | |
"<g id=\"node4\" class=\"node\"><title>relu1</title>\n", | |
"<polygon fill=\"#ffffb3\" stroke=\"black\" points=\"94,-385 -7.10543e-15,-385 -7.10543e-15,-327 94,-327 94,-385\"/>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-359.8\" font-family=\"Times,serif\" font-size=\"14.00\">Activation</text>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-344.8\" font-family=\"Times,serif\" font-size=\"14.00\">relu</text>\n", | |
"</g>\n", | |
"<!-- relu1->fc1 -->\n", | |
"<g id=\"edge3\" class=\"edge\"><title>relu1->fc1</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M47,-316.582C47,-303.284 47,-288.632 47,-276.297\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-326.887 42.5001,-316.887 47,-321.887 47.0001,-316.887 47.0001,-316.887 47.0001,-316.887 47,-321.887 51.5001,-316.887 47,-326.887 47,-326.887\"/>\n", | |
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-297.8\" font-family=\"Times,serif\" font-size=\"14.00\">128</text>\n", | |
"</g>\n", | |
"<!-- fc2 -->\n", | |
"<g id=\"node5\" class=\"node\"><title>fc2</title>\n", | |
"<polygon fill=\"#fb8072\" stroke=\"black\" points=\"94,-494 -7.10543e-15,-494 -7.10543e-15,-436 94,-436 94,-494\"/>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-468.8\" font-family=\"Times,serif\" font-size=\"14.00\">FullyConnected</text>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-453.8\" font-family=\"Times,serif\" font-size=\"14.00\">64</text>\n", | |
"</g>\n", | |
"<!-- fc2->relu1 -->\n", | |
"<g id=\"edge4\" class=\"edge\"><title>fc2->relu1</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M47,-425.582C47,-412.284 47,-397.632 47,-385.297\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-435.887 42.5001,-425.887 47,-430.887 47.0001,-425.887 47.0001,-425.887 47.0001,-425.887 47,-430.887 51.5001,-425.887 47,-435.887 47,-435.887\"/>\n", | |
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-406.8\" font-family=\"Times,serif\" font-size=\"14.00\">128</text>\n", | |
"</g>\n", | |
"<!-- relu2 -->\n", | |
"<g id=\"node6\" class=\"node\"><title>relu2</title>\n", | |
"<polygon fill=\"#ffffb3\" stroke=\"black\" points=\"94,-603 -7.10543e-15,-603 -7.10543e-15,-545 94,-545 94,-603\"/>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-577.8\" font-family=\"Times,serif\" font-size=\"14.00\">Activation</text>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-562.8\" font-family=\"Times,serif\" font-size=\"14.00\">relu</text>\n", | |
"</g>\n", | |
"<!-- relu2->fc2 -->\n", | |
"<g id=\"edge5\" class=\"edge\"><title>relu2->fc2</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M47,-534.582C47,-521.284 47,-506.632 47,-494.297\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-544.887 42.5001,-534.887 47,-539.887 47.0001,-534.887 47.0001,-534.887 47.0001,-534.887 47,-539.887 51.5001,-534.887 47,-544.887 47,-544.887\"/>\n", | |
"<text text-anchor=\"middle\" x=\"54\" y=\"-515.8\" font-family=\"Times,serif\" font-size=\"14.00\">64</text>\n", | |
"</g>\n", | |
"<!-- fc3 -->\n", | |
"<g id=\"node7\" class=\"node\"><title>fc3</title>\n", | |
"<polygon fill=\"#fb8072\" stroke=\"black\" points=\"94,-712 -7.10543e-15,-712 -7.10543e-15,-654 94,-654 94,-712\"/>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-686.8\" font-family=\"Times,serif\" font-size=\"14.00\">FullyConnected</text>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-671.8\" font-family=\"Times,serif\" font-size=\"14.00\">10</text>\n", | |
"</g>\n", | |
"<!-- fc3->relu2 -->\n", | |
"<g id=\"edge6\" class=\"edge\"><title>fc3->relu2</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M47,-643.582C47,-630.284 47,-615.632 47,-603.297\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-653.887 42.5001,-643.887 47,-648.887 47.0001,-643.887 47.0001,-643.887 47.0001,-643.887 47,-648.887 51.5001,-643.887 47,-653.887 47,-653.887\"/>\n", | |
"<text text-anchor=\"middle\" x=\"54\" y=\"-624.8\" font-family=\"Times,serif\" font-size=\"14.00\">64</text>\n", | |
"</g>\n", | |
"<!-- softmax_label -->\n", | |
"<g id=\"node8\" class=\"node\"><title>softmax_label</title>\n", | |
"<ellipse fill=\"#8dd3c7\" stroke=\"black\" cx=\"159\" cy=\"-683\" rx=\"47\" ry=\"29\"/>\n", | |
"<text text-anchor=\"middle\" x=\"159\" y=\"-679.3\" font-family=\"Times,serif\" font-size=\"14.00\">softmax_label</text>\n", | |
"</g>\n", | |
"<!-- softmax -->\n", | |
"<g id=\"node9\" class=\"node\"><title>softmax</title>\n", | |
"<polygon fill=\"#fccde5\" stroke=\"black\" points=\"170,-821 76,-821 76,-763 170,-763 170,-821\"/>\n", | |
"<text text-anchor=\"middle\" x=\"123\" y=\"-788.3\" font-family=\"Times,serif\" font-size=\"14.00\">softmax</text>\n", | |
"</g>\n", | |
"<!-- softmax->fc3 -->\n", | |
"<g id=\"edge7\" class=\"edge\"><title>softmax->fc3</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M97.1082,-754.547C87.3017,-740.741 76.2938,-725.243 67.0986,-712.297\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"103.032,-762.887 93.5724,-757.34 100.137,-758.811 97.2411,-754.734 97.2411,-754.734 97.2411,-754.734 100.137,-758.811 100.91,-752.128 103.032,-762.887 103.032,-762.887\"/>\n", | |
"<text text-anchor=\"middle\" x=\"97\" y=\"-733.8\" font-family=\"Times,serif\" font-size=\"14.00\">10</text>\n", | |
"</g>\n", | |
"<!-- softmax->softmax_label -->\n", | |
"<g id=\"edge8\" class=\"edge\"><title>softmax->softmax_label</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M135.713,-753.215C140.333,-739.483 145.463,-724.236 149.729,-711.555\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"132.459,-762.887 131.382,-751.974 134.053,-758.148 135.647,-753.409 135.647,-753.409 135.647,-753.409 134.053,-758.148 139.913,-754.844 132.459,-762.887 132.459,-762.887\"/>\n", | |
"</g>\n", | |
"</g>\n", | |
"</svg>\n" | |
], | |
"text/plain": [ | |
"<graphviz.dot.Digraph at 0x7f9c0fe18690>" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Placeholder symbol for the input batch, flattened before the fully-connected layers\n", | |
"data = mx.sym.Variable('data')\n", | |
"data = mx.sym.Flatten(data=data)\n", | |
"\n", | |
"# First fully-connected layer (128 hidden units)\n", | |
"fc1 = mx.sym.FullyConnected(data=data, name='fc1', num_hidden=128)\n", | |
"# Apply ReLU to the output of the first fully-connected layer\n", | |
"act1 = mx.sym.Activation(data=fc1, name='relu1', act_type=\"relu\")\n", | |
"\n", | |
"# Second fully-connected layer (64 hidden units) and its ReLU activation\n", | |
"fc2 = mx.sym.FullyConnected(data=act1, name='fc2', num_hidden = 64)\n", | |
"act2 = mx.sym.Activation(data=fc2, name='relu2', act_type=\"relu\")\n", | |
"\n", | |
"# Third fully-connected layer; its size must be 10, the number of distinct digits\n", | |
"fc3 = mx.sym.FullyConnected(data=act2, name='fc3', num_hidden=10)\n", | |
"# Softmax output and loss layer\n", | |
"mlp = mx.sym.SoftmaxOutput(data=fc3, name='softmax')\n", | |
"\n", | |
"# Visualize the network; edge labels show each layer's output size (the batch dimension is omitted)\n", | |
"shape = {\"data\" : (batch_size, 1, 28, 28)}\n", | |
"mx.viz.plot_network(symbol=mlp, shape=shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/usr/lib/python2.7/dist-packages/ipykernel_launcher.py:7: DeprecationWarning: \u001b[91mmxnet.model.FeedForward has been deprecated. Please use mxnet.mod.Module instead.\u001b[0m\n", | |
" import sys\n", | |
"/usr/lib/python2.7/dist-packages/mxnet-0.11.0-py2.7.egg/mxnet/model.py:547: DeprecationWarning: \u001b[91mCalling initializer with init(str, NDArray) has been deprecated.please use init(mx.init.InitDesc(...), NDArray) instead.\u001b[0m\n", | |
" self.initializer(k, v)\n", | |
"INFO:root:Start training with [cpu(0)]\n", | |
"INFO:root:Epoch[0] Batch [200]\tSpeed: 39874.55 samples/sec\taccuracy=0.110900\n", | |
"INFO:root:Epoch[0] Batch [400]\tSpeed: 38556.08 samples/sec\taccuracy=0.112250\n", | |
"INFO:root:Epoch[0] Batch [600]\tSpeed: 40741.83 samples/sec\taccuracy=0.138200\n", | |
"INFO:root:Epoch[0] Resetting Data Iterator\n", | |
"INFO:root:Epoch[0] Time cost=1.530\n", | |
"INFO:root:Epoch[0] Validation-accuracy=0.217600\n", | |
"INFO:root:Epoch[1] Batch [200]\tSpeed: 39664.76 samples/sec\taccuracy=0.419750\n", | |
"INFO:root:Epoch[1] Batch [400]\tSpeed: 37210.13 samples/sec\taccuracy=0.747450\n", | |
"INFO:root:Epoch[1] Batch [600]\tSpeed: 38671.70 samples/sec\taccuracy=0.829000\n", | |
"INFO:root:Epoch[1] Resetting Data Iterator\n", | |
"INFO:root:Epoch[1] Time cost=1.569\n", | |
"INFO:root:Epoch[1] Validation-accuracy=0.857800\n", | |
"INFO:root:Epoch[2] Batch [200]\tSpeed: 37269.50 samples/sec\taccuracy=0.861850\n", | |
"INFO:root:Epoch[2] Batch [400]\tSpeed: 39785.30 samples/sec\taccuracy=0.888600\n", | |
"INFO:root:Epoch[2] Batch [600]\tSpeed: 38696.32 samples/sec\taccuracy=0.904550\n", | |
"INFO:root:Epoch[2] Resetting Data Iterator\n", | |
"INFO:root:Epoch[2] Time cost=1.566\n", | |
"INFO:root:Epoch[2] Validation-accuracy=0.916000\n", | |
"INFO:root:Epoch[3] Batch [200]\tSpeed: 37400.46 samples/sec\taccuracy=0.917950\n", | |
"INFO:root:Epoch[3] Batch [400]\tSpeed: 37174.52 samples/sec\taccuracy=0.930800\n", | |
"INFO:root:Epoch[3] Batch [600]\tSpeed: 36195.49 samples/sec\taccuracy=0.935700\n", | |
"INFO:root:Epoch[3] Resetting Data Iterator\n", | |
"INFO:root:Epoch[3] Time cost=1.635\n", | |
"INFO:root:Epoch[3] Validation-accuracy=0.940000\n", | |
"INFO:root:Epoch[4] Batch [200]\tSpeed: 36301.26 samples/sec\taccuracy=0.940950\n", | |
"INFO:root:Epoch[4] Batch [400]\tSpeed: 31858.09 samples/sec\taccuracy=0.949800\n", | |
"INFO:root:Epoch[4] Batch [600]\tSpeed: 36146.50 samples/sec\taccuracy=0.949650\n", | |
"INFO:root:Epoch[4] Resetting Data Iterator\n", | |
"INFO:root:Epoch[4] Time cost=1.742\n", | |
"INFO:root:Epoch[4] Validation-accuracy=0.951600\n", | |
"INFO:root:Epoch[5] Batch [200]\tSpeed: 36306.79 samples/sec\taccuracy=0.952900\n", | |
"INFO:root:Epoch[5] Batch [400]\tSpeed: 37419.27 samples/sec\taccuracy=0.959650\n", | |
"INFO:root:Epoch[5] Batch [600]\tSpeed: 36265.92 samples/sec\taccuracy=0.959500\n", | |
"INFO:root:Epoch[5] Resetting Data Iterator\n", | |
"INFO:root:Epoch[5] Time cost=1.647\n", | |
"INFO:root:Epoch[5] Validation-accuracy=0.956200\n", | |
"INFO:root:Epoch[6] Batch [200]\tSpeed: 36383.02 samples/sec\taccuracy=0.961200\n", | |
"INFO:root:Epoch[6] Batch [400]\tSpeed: 35857.33 samples/sec\taccuracy=0.966100\n", | |
"INFO:root:Epoch[6] Batch [600]\tSpeed: 36210.63 samples/sec\taccuracy=0.965350\n", | |
"INFO:root:Epoch[6] Resetting Data Iterator\n", | |
"INFO:root:Epoch[6] Time cost=1.670\n", | |
"INFO:root:Epoch[6] Validation-accuracy=0.957700\n", | |
"INFO:root:Epoch[7] Batch [200]\tSpeed: 36412.97 samples/sec\taccuracy=0.966700\n", | |
"INFO:root:Epoch[7] Batch [400]\tSpeed: 36228.08 samples/sec\taccuracy=0.970800\n", | |
"INFO:root:Epoch[7] Batch [600]\tSpeed: 36257.63 samples/sec\taccuracy=0.969250\n", | |
"INFO:root:Epoch[7] Resetting Data Iterator\n", | |
"INFO:root:Epoch[7] Time cost=1.663\n", | |
"INFO:root:Epoch[7] Validation-accuracy=0.960400\n", | |
"INFO:root:Epoch[8] Batch [200]\tSpeed: 36243.76 samples/sec\taccuracy=0.971000\n", | |
"INFO:root:Epoch[8] Batch [400]\tSpeed: 36209.45 samples/sec\taccuracy=0.974450\n", | |
"INFO:root:Epoch[8] Batch [600]\tSpeed: 36219.94 samples/sec\taccuracy=0.973150\n", | |
"INFO:root:Epoch[8] Resetting Data Iterator\n", | |
"INFO:root:Epoch[8] Time cost=1.666\n", | |
"INFO:root:Epoch[8] Validation-accuracy=0.961900\n", | |
"INFO:root:Epoch[9] Batch [200]\tSpeed: 36498.22 samples/sec\taccuracy=0.974500\n", | |
"INFO:root:Epoch[9] Batch [400]\tSpeed: 37476.41 samples/sec\taccuracy=0.977200\n", | |
"INFO:root:Epoch[9] Batch [600]\tSpeed: 36123.13 samples/sec\taccuracy=0.976100\n", | |
"INFO:root:Epoch[9] Resetting Data Iterator\n", | |
"INFO:root:Epoch[9] Time cost=1.645\n", | |
"INFO:root:Epoch[9] Validation-accuracy=0.963100\n" | |
] | |
} | |
], | |
"source": [ | |
"import logging\n", | |
"\n", | |
"# Lower the root logger threshold to DEBUG so the INFO-level training progress is emitted.\n", | |
"logging.getLogger().setLevel(logging.DEBUG)\n", | |
"\n", | |
"# Network structure: the MLP symbol; 10 passes over the training data; SGD learning rate 0.1.\n", | |
"# NOTE: mxnet reports FeedForward as deprecated in favour of mxnet.mod.Module\n", | |
"# (see the warning in this cell's output).\n", | |
"model = mx.model.FeedForward(symbol=mlp, num_epoch=10, learning_rate=0.1)\n", | |
"\n", | |
"# Train on train_iter, validate on val_iter, and report progress every 200 batches.\n", | |
"model.fit(\n", | |
"    X=train_iter,\n", | |
"    eval_data=val_iter,\n", | |
"    batch_end_callback=mx.callback.Speedometer(batch_size, 200))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 続いて CNN の LeNet でGPUを使って計算" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### CNNとLeNetについては下の2つを読むと理解しやすい\n", | |
"- [\\[ディープラーニング\\] LeNet – Tech Memo](http://tecmemo.wpblog.jp/2017/03/19/dl_lenet/)\n", | |
"- [定番のConvolutional Neural Networkをゼロから理解する \\- DeepAge](https://deepage.net/deep_learning/2016/11/07/convolutional_neural_network.html)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/svg+xml": [ | |
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", | |
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", | |
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", | |
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\n", | |
" -->\n", | |
"<!-- Title: plot Pages: 1 -->\n", | |
"<svg width=\"214pt\" height=\"1265pt\"\n", | |
" viewBox=\"0.00 0.00 214.00 1265.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 1261)\">\n", | |
"<title>plot</title>\n", | |
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-1261 210,-1261 210,4 -4,4\"/>\n", | |
"<!-- data -->\n", | |
"<g id=\"node1\" class=\"node\"><title>data</title>\n", | |
"<ellipse fill=\"#8dd3c7\" stroke=\"black\" cx=\"47\" cy=\"-29\" rx=\"47\" ry=\"29\"/>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-25.3\" font-family=\"Times,serif\" font-size=\"14.00\">data</text>\n", | |
"</g>\n", | |
"<!-- convolution0 -->\n", | |
"<g id=\"node2\" class=\"node\"><title>convolution0</title>\n", | |
"<polygon fill=\"#fb8072\" stroke=\"black\" points=\"94,-167 -7.10543e-15,-167 -7.10543e-15,-109 94,-109 94,-167\"/>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-141.8\" font-family=\"Times,serif\" font-size=\"14.00\">Convolution</text>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-126.8\" font-family=\"Times,serif\" font-size=\"14.00\">5x5/1, 20</text>\n", | |
"</g>\n", | |
"<!-- convolution0->data -->\n", | |
"<g id=\"edge1\" class=\"edge\"><title>convolution0->data</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M47,-98.5824C47,-85.2841 47,-70.632 47,-58.2967\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-108.887 42.5001,-98.887 47,-103.887 47.0001,-98.887 47.0001,-98.887 47.0001,-98.887 47,-103.887 51.5001,-98.8871 47,-108.887 47,-108.887\"/>\n", | |
"<text text-anchor=\"middle\" x=\"71\" y=\"-79.8\" font-family=\"Times,serif\" font-size=\"14.00\">1x28x28</text>\n", | |
"</g>\n", | |
"<!-- activation0 -->\n", | |
"<g id=\"node3\" class=\"node\"><title>activation0</title>\n", | |
"<polygon fill=\"#ffffb3\" stroke=\"black\" points=\"94,-276 -7.10543e-15,-276 -7.10543e-15,-218 94,-218 94,-276\"/>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-250.8\" font-family=\"Times,serif\" font-size=\"14.00\">Activation</text>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-235.8\" font-family=\"Times,serif\" font-size=\"14.00\">tanh</text>\n", | |
"</g>\n", | |
"<!-- activation0->convolution0 -->\n", | |
"<g id=\"edge2\" class=\"edge\"><title>activation0->convolution0</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M47,-207.582C47,-194.284 47,-179.632 47,-167.297\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-217.887 42.5001,-207.887 47,-212.887 47.0001,-207.887 47.0001,-207.887 47.0001,-207.887 47,-212.887 51.5001,-207.887 47,-217.887 47,-217.887\"/>\n", | |
"<text text-anchor=\"middle\" x=\"74.5\" y=\"-188.8\" font-family=\"Times,serif\" font-size=\"14.00\">20x24x24</text>\n", | |
"</g>\n", | |
"<!-- pooling0 -->\n", | |
"<g id=\"node4\" class=\"node\"><title>pooling0</title>\n", | |
"<polygon fill=\"#80b1d3\" stroke=\"black\" points=\"94,-385 -7.10543e-15,-385 -7.10543e-15,-327 94,-327 94,-385\"/>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-359.8\" font-family=\"Times,serif\" font-size=\"14.00\">Pooling</text>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-344.8\" font-family=\"Times,serif\" font-size=\"14.00\">max, 2x2/2x2</text>\n", | |
"</g>\n", | |
"<!-- pooling0->activation0 -->\n", | |
"<g id=\"edge3\" class=\"edge\"><title>pooling0->activation0</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M47,-316.582C47,-303.284 47,-288.632 47,-276.297\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-326.887 42.5001,-316.887 47,-321.887 47.0001,-316.887 47.0001,-316.887 47.0001,-316.887 47,-321.887 51.5001,-316.887 47,-326.887 47,-326.887\"/>\n", | |
"<text text-anchor=\"middle\" x=\"74.5\" y=\"-297.8\" font-family=\"Times,serif\" font-size=\"14.00\">20x24x24</text>\n", | |
"</g>\n", | |
"<!-- convolution1 -->\n", | |
"<g id=\"node5\" class=\"node\"><title>convolution1</title>\n", | |
"<polygon fill=\"#fb8072\" stroke=\"black\" points=\"94,-494 -7.10543e-15,-494 -7.10543e-15,-436 94,-436 94,-494\"/>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-468.8\" font-family=\"Times,serif\" font-size=\"14.00\">Convolution</text>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-453.8\" font-family=\"Times,serif\" font-size=\"14.00\">5x5/1, 50</text>\n", | |
"</g>\n", | |
"<!-- convolution1->pooling0 -->\n", | |
"<g id=\"edge4\" class=\"edge\"><title>convolution1->pooling0</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M47,-425.582C47,-412.284 47,-397.632 47,-385.297\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-435.887 42.5001,-425.887 47,-430.887 47.0001,-425.887 47.0001,-425.887 47.0001,-425.887 47,-430.887 51.5001,-425.887 47,-435.887 47,-435.887\"/>\n", | |
"<text text-anchor=\"middle\" x=\"74.5\" y=\"-406.8\" font-family=\"Times,serif\" font-size=\"14.00\">20x12x12</text>\n", | |
"</g>\n", | |
"<!-- activation1 -->\n", | |
"<g id=\"node6\" class=\"node\"><title>activation1</title>\n", | |
"<polygon fill=\"#ffffb3\" stroke=\"black\" points=\"94,-603 -7.10543e-15,-603 -7.10543e-15,-545 94,-545 94,-603\"/>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-577.8\" font-family=\"Times,serif\" font-size=\"14.00\">Activation</text>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-562.8\" font-family=\"Times,serif\" font-size=\"14.00\">tanh</text>\n", | |
"</g>\n", | |
"<!-- activation1->convolution1 -->\n", | |
"<g id=\"edge5\" class=\"edge\"><title>activation1->convolution1</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M47,-534.582C47,-521.284 47,-506.632 47,-494.297\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-544.887 42.5001,-534.887 47,-539.887 47.0001,-534.887 47.0001,-534.887 47.0001,-534.887 47,-539.887 51.5001,-534.887 47,-544.887 47,-544.887\"/>\n", | |
"<text text-anchor=\"middle\" x=\"67.5\" y=\"-515.8\" font-family=\"Times,serif\" font-size=\"14.00\">50x8x8</text>\n", | |
"</g>\n", | |
"<!-- pooling1 -->\n", | |
"<g id=\"node7\" class=\"node\"><title>pooling1</title>\n", | |
"<polygon fill=\"#80b1d3\" stroke=\"black\" points=\"94,-712 -7.10543e-15,-712 -7.10543e-15,-654 94,-654 94,-712\"/>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-686.8\" font-family=\"Times,serif\" font-size=\"14.00\">Pooling</text>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-671.8\" font-family=\"Times,serif\" font-size=\"14.00\">max, 2x2/2x2</text>\n", | |
"</g>\n", | |
"<!-- pooling1->activation1 -->\n", | |
"<g id=\"edge6\" class=\"edge\"><title>pooling1->activation1</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M47,-643.582C47,-630.284 47,-615.632 47,-603.297\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-653.887 42.5001,-643.887 47,-648.887 47.0001,-643.887 47.0001,-643.887 47.0001,-643.887 47,-648.887 51.5001,-643.887 47,-653.887 47,-653.887\"/>\n", | |
"<text text-anchor=\"middle\" x=\"67.5\" y=\"-624.8\" font-family=\"Times,serif\" font-size=\"14.00\">50x8x8</text>\n", | |
"</g>\n", | |
"<!-- flatten1 -->\n", | |
"<g id=\"node8\" class=\"node\"><title>flatten1</title>\n", | |
"<polygon fill=\"#fdb462\" stroke=\"black\" points=\"94,-821 -7.10543e-15,-821 -7.10543e-15,-763 94,-763 94,-821\"/>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-788.3\" font-family=\"Times,serif\" font-size=\"14.00\">flatten1</text>\n", | |
"</g>\n", | |
"<!-- flatten1->pooling1 -->\n", | |
"<g id=\"edge7\" class=\"edge\"><title>flatten1->pooling1</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M47,-752.582C47,-739.284 47,-724.632 47,-712.297\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-762.887 42.5001,-752.887 47,-757.887 47.0001,-752.887 47.0001,-752.887 47.0001,-752.887 47,-757.887 51.5001,-752.887 47,-762.887 47,-762.887\"/>\n", | |
"<text text-anchor=\"middle\" x=\"67.5\" y=\"-733.8\" font-family=\"Times,serif\" font-size=\"14.00\">50x4x4</text>\n", | |
"</g>\n", | |
"<!-- fullyconnected0 -->\n", | |
"<g id=\"node9\" class=\"node\"><title>fullyconnected0</title>\n", | |
"<polygon fill=\"#fb8072\" stroke=\"black\" points=\"94,-930 -7.10543e-15,-930 -7.10543e-15,-872 94,-872 94,-930\"/>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-904.8\" font-family=\"Times,serif\" font-size=\"14.00\">FullyConnected</text>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-889.8\" font-family=\"Times,serif\" font-size=\"14.00\">500</text>\n", | |
"</g>\n", | |
"<!-- fullyconnected0->flatten1 -->\n", | |
"<g id=\"edge8\" class=\"edge\"><title>fullyconnected0->flatten1</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M47,-861.582C47,-848.284 47,-833.632 47,-821.297\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-871.887 42.5001,-861.887 47,-866.887 47.0001,-861.887 47.0001,-861.887 47.0001,-861.887 47,-866.887 51.5001,-861.887 47,-871.887 47,-871.887\"/>\n", | |
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-842.8\" font-family=\"Times,serif\" font-size=\"14.00\">800</text>\n", | |
"</g>\n", | |
"<!-- activation2 -->\n", | |
"<g id=\"node10\" class=\"node\"><title>activation2</title>\n", | |
"<polygon fill=\"#ffffb3\" stroke=\"black\" points=\"94,-1039 -7.10543e-15,-1039 -7.10543e-15,-981 94,-981 94,-1039\"/>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-1013.8\" font-family=\"Times,serif\" font-size=\"14.00\">Activation</text>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-998.8\" font-family=\"Times,serif\" font-size=\"14.00\">tanh</text>\n", | |
"</g>\n", | |
"<!-- activation2->fullyconnected0 -->\n", | |
"<g id=\"edge9\" class=\"edge\"><title>activation2->fullyconnected0</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M47,-970.582C47,-957.284 47,-942.632 47,-930.297\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-980.887 42.5001,-970.887 47,-975.887 47.0001,-970.887 47.0001,-970.887 47.0001,-970.887 47,-975.887 51.5001,-970.887 47,-980.887 47,-980.887\"/>\n", | |
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-951.8\" font-family=\"Times,serif\" font-size=\"14.00\">500</text>\n", | |
"</g>\n", | |
"<!-- fullyconnected1 -->\n", | |
"<g id=\"node11\" class=\"node\"><title>fullyconnected1</title>\n", | |
"<polygon fill=\"#fb8072\" stroke=\"black\" points=\"94,-1148 -7.10543e-15,-1148 -7.10543e-15,-1090 94,-1090 94,-1148\"/>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-1122.8\" font-family=\"Times,serif\" font-size=\"14.00\">FullyConnected</text>\n", | |
"<text text-anchor=\"middle\" x=\"47\" y=\"-1107.8\" font-family=\"Times,serif\" font-size=\"14.00\">10</text>\n", | |
"</g>\n", | |
"<!-- fullyconnected1->activation2 -->\n", | |
"<g id=\"edge10\" class=\"edge\"><title>fullyconnected1->activation2</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M47,-1079.58C47,-1066.28 47,-1051.63 47,-1039.3\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-1089.89 42.5001,-1079.89 47,-1084.89 47.0001,-1079.89 47.0001,-1079.89 47.0001,-1079.89 47,-1084.89 51.5001,-1079.89 47,-1089.89 47,-1089.89\"/>\n", | |
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-1060.8\" font-family=\"Times,serif\" font-size=\"14.00\">500</text>\n", | |
"</g>\n", | |
"<!-- softmax_label -->\n", | |
"<g id=\"node12\" class=\"node\"><title>softmax_label</title>\n", | |
"<ellipse fill=\"#8dd3c7\" stroke=\"black\" cx=\"159\" cy=\"-1119\" rx=\"47\" ry=\"29\"/>\n", | |
"<text text-anchor=\"middle\" x=\"159\" y=\"-1115.3\" font-family=\"Times,serif\" font-size=\"14.00\">softmax_label</text>\n", | |
"</g>\n", | |
"<!-- softmax -->\n", | |
"<g id=\"node13\" class=\"node\"><title>softmax</title>\n", | |
"<polygon fill=\"#fccde5\" stroke=\"black\" points=\"170,-1257 76,-1257 76,-1199 170,-1199 170,-1257\"/>\n", | |
"<text text-anchor=\"middle\" x=\"123\" y=\"-1224.3\" font-family=\"Times,serif\" font-size=\"14.00\">softmax</text>\n", | |
"</g>\n", | |
"<!-- softmax->fullyconnected1 -->\n", | |
"<g id=\"edge11\" class=\"edge\"><title>softmax->fullyconnected1</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M97.1082,-1190.55C87.3017,-1176.74 76.2938,-1161.24 67.0986,-1148.3\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"103.032,-1198.89 93.5724,-1193.34 100.137,-1194.81 97.2411,-1190.73 97.2411,-1190.73 97.2411,-1190.73 100.137,-1194.81 100.91,-1188.13 103.032,-1198.89 103.032,-1198.89\"/>\n", | |
"<text text-anchor=\"middle\" x=\"97\" y=\"-1169.8\" font-family=\"Times,serif\" font-size=\"14.00\">10</text>\n", | |
"</g>\n", | |
"<!-- softmax->softmax_label -->\n", | |
"<g id=\"edge12\" class=\"edge\"><title>softmax->softmax_label</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M135.713,-1189.21C140.333,-1175.48 145.463,-1160.24 149.729,-1147.56\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"132.459,-1198.89 131.382,-1187.97 134.053,-1194.15 135.647,-1189.41 135.647,-1189.41 135.647,-1189.41 134.053,-1194.15 139.913,-1190.84 132.459,-1198.89 132.459,-1198.89\"/>\n", | |
"</g>\n", | |
"</g>\n", | |
"</svg>\n" | |
], | |
"text/plain": [ | |
"<graphviz.dot.Digraph at 0x7f9c0ef2e050>" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"data = mx.symbol.Variable('data')\n", | |
"# first conv layer\n", | |
"conv1 = mx.sym.Convolution(data=data, kernel=(5,5), num_filter=20)\n", | |
"tanh1 = mx.sym.Activation(data=conv1, act_type=\"tanh\")\n", | |
"pool1 = mx.sym.Pooling(data=tanh1, pool_type=\"max\", kernel=(2,2), stride=(2,2))\n", | |
"# second conv layer\n", | |
"conv2 = mx.sym.Convolution(data=pool1, kernel=(5,5), num_filter=50)\n", | |
"tanh2 = mx.sym.Activation(data=conv2, act_type=\"tanh\")\n", | |
"pool2 = mx.sym.Pooling(data=tanh2, pool_type=\"max\", kernel=(2,2), stride=(2,2))\n", | |
"# first fullc layer\n", | |
"flatten = mx.sym.Flatten(data=pool2)\n", | |
"fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)\n", | |
"tanh3 = mx.sym.Activation(data=fc1, act_type=\"tanh\")\n", | |
"# second fullc\n", | |
"fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10)\n", | |
"# softmax loss\n", | |
"lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax')\n", | |
"mx.viz.plot_network(symbol=lenet, shape=shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/usr/lib/python2.7/dist-packages/ipykernel_launcher.py:5: DeprecationWarning: \u001b[91mmxnet.model.FeedForward has been deprecated. Please use mxnet.mod.Module instead.\u001b[0m\n", | |
" \"\"\"\n", | |
"INFO:root:Start training with [gpu(0)]\n", | |
"INFO:root:Epoch[0] Batch [200]\tSpeed: 22201.87 samples/sec\taccuracy=0.111250\n", | |
"INFO:root:Epoch[0] Batch [400]\tSpeed: 25587.48 samples/sec\taccuracy=0.112250\n", | |
"INFO:root:Epoch[0] Batch [600]\tSpeed: 25564.46 samples/sec\taccuracy=0.113400\n", | |
"INFO:root:Epoch[0] Resetting Data Iterator\n", | |
"INFO:root:Epoch[0] Time cost=2.476\n", | |
"INFO:root:Epoch[0] Validation-accuracy=0.113500\n", | |
"INFO:root:Epoch[1] Batch [200]\tSpeed: 25543.51 samples/sec\taccuracy=0.144900\n", | |
"INFO:root:Epoch[1] Batch [400]\tSpeed: 25507.08 samples/sec\taccuracy=0.779950\n", | |
"INFO:root:Epoch[1] Batch [600]\tSpeed: 25530.16 samples/sec\taccuracy=0.914950\n", | |
"INFO:root:Epoch[1] Resetting Data Iterator\n", | |
"INFO:root:Epoch[1] Time cost=2.358\n", | |
"INFO:root:Epoch[1] Validation-accuracy=0.937700\n", | |
"INFO:root:Epoch[2] Batch [200]\tSpeed: 25617.25 samples/sec\taccuracy=0.942600\n", | |
"INFO:root:Epoch[2] Batch [400]\tSpeed: 25514.86 samples/sec\taccuracy=0.956050\n", | |
"INFO:root:Epoch[2] Batch [600]\tSpeed: 25445.62 samples/sec\taccuracy=0.965300\n", | |
"INFO:root:Epoch[2] Resetting Data Iterator\n", | |
"INFO:root:Epoch[2] Time cost=2.361\n", | |
"INFO:root:Epoch[2] Validation-accuracy=0.966800\n", | |
"INFO:root:Epoch[3] Batch [200]\tSpeed: 25647.79 samples/sec\taccuracy=0.970050\n", | |
"INFO:root:Epoch[3] Batch [400]\tSpeed: 25591.42 samples/sec\taccuracy=0.974750\n", | |
"INFO:root:Epoch[3] Batch [600]\tSpeed: 25567.90 samples/sec\taccuracy=0.977100\n", | |
"INFO:root:Epoch[3] Resetting Data Iterator\n", | |
"INFO:root:Epoch[3] Time cost=2.354\n", | |
"INFO:root:Epoch[3] Validation-accuracy=0.977400\n", | |
"INFO:root:Epoch[4] Batch [200]\tSpeed: 25633.57 samples/sec\taccuracy=0.977800\n", | |
"INFO:root:Epoch[4] Batch [400]\tSpeed: 25569.72 samples/sec\taccuracy=0.981450\n", | |
"INFO:root:Epoch[4] Batch [600]\tSpeed: 25546.67 samples/sec\taccuracy=0.983600\n", | |
"INFO:root:Epoch[4] Resetting Data Iterator\n", | |
"INFO:root:Epoch[4] Time cost=2.353\n", | |
"INFO:root:Epoch[4] Validation-accuracy=0.983600\n", | |
"INFO:root:Epoch[5] Batch [200]\tSpeed: 25620.73 samples/sec\taccuracy=0.982650\n", | |
"INFO:root:Epoch[5] Batch [400]\tSpeed: 25531.73 samples/sec\taccuracy=0.985600\n", | |
"INFO:root:Epoch[5] Batch [600]\tSpeed: 25566.39 samples/sec\taccuracy=0.986900\n", | |
"INFO:root:Epoch[5] Resetting Data Iterator\n", | |
"INFO:root:Epoch[5] Time cost=2.354\n", | |
"INFO:root:Epoch[5] Validation-accuracy=0.985200\n", | |
"INFO:root:Epoch[6] Batch [200]\tSpeed: 25608.59 samples/sec\taccuracy=0.985700\n", | |
"INFO:root:Epoch[6] Batch [400]\tSpeed: 25532.58 samples/sec\taccuracy=0.987850\n", | |
"INFO:root:Epoch[6] Batch [600]\tSpeed: 25603.90 samples/sec\taccuracy=0.988800\n", | |
"INFO:root:Epoch[6] Resetting Data Iterator\n", | |
"INFO:root:Epoch[6] Time cost=2.353\n", | |
"INFO:root:Epoch[6] Validation-accuracy=0.986200\n", | |
"INFO:root:Epoch[7] Batch [200]\tSpeed: 25635.70 samples/sec\taccuracy=0.987600\n", | |
"INFO:root:Epoch[7] Batch [400]\tSpeed: 25556.92 samples/sec\taccuracy=0.990150\n", | |
"INFO:root:Epoch[7] Batch [600]\tSpeed: 25571.29 samples/sec\taccuracy=0.990400\n", | |
"INFO:root:Epoch[7] Resetting Data Iterator\n", | |
"INFO:root:Epoch[7] Time cost=2.353\n", | |
"INFO:root:Epoch[7] Validation-accuracy=0.987300\n", | |
"INFO:root:Epoch[8] Batch [200]\tSpeed: 25571.07 samples/sec\taccuracy=0.989600\n", | |
"INFO:root:Epoch[8] Batch [400]\tSpeed: 25529.62 samples/sec\taccuracy=0.991650\n", | |
"INFO:root:Epoch[8] Batch [600]\tSpeed: 25497.78 samples/sec\taccuracy=0.991800\n", | |
"INFO:root:Epoch[8] Resetting Data Iterator\n", | |
"INFO:root:Epoch[8] Time cost=2.358\n", | |
"INFO:root:Epoch[8] Validation-accuracy=0.987800\n", | |
"INFO:root:Epoch[9] Batch [200]\tSpeed: 25641.87 samples/sec\taccuracy=0.990700\n", | |
"INFO:root:Epoch[9] Batch [400]\tSpeed: 25539.56 samples/sec\taccuracy=0.992750\n", | |
"INFO:root:Epoch[9] Batch [600]\tSpeed: 25563.02 samples/sec\taccuracy=0.992850\n", | |
"INFO:root:Epoch[9] Resetting Data Iterator\n", | |
"INFO:root:Epoch[9] Time cost=2.356\n", | |
"INFO:root:Epoch[9] Validation-accuracy=0.988000\n" | |
] | |
} | |
], | |
"source": [ | |
"model = mx.model.FeedForward(\n", | |
" ctx = mx.gpu(0), # use GPU 0 for training, others are same as before\n", | |
" symbol = lenet, \n", | |
" num_epoch = 10, \n", | |
" learning_rate = 0.1)\n", | |
"model.fit(\n", | |
" X=train_iter, \n", | |
" eval_data=val_iter, \n", | |
" batch_end_callback = mx.callback.Speedometer(batch_size, 200)\n", | |
") \n", | |
"assert model.score(val_iter) > 0.98, \"Low validation accuracy.\"" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.12" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment