Skip to content

Instantly share code, notes, and snippets.

@regonn
Created October 30, 2017 13:24
Show Gist options
  • Save regonn/f7db4fe43559110f5e3b2fd65e70cc7f to your computer and use it in GitHub Desktop.
Save regonn/f7db4fe43559110f5e3b2fd65e70cc7f to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# AWS の Deep learning 用 AMI で MXNet を動かす\n",
"\n",
"## モチベーション\n",
"- ローカルのパソコンだと処理速度に限界を感じてきた\n",
"- かといって、自作PCを作るモチベーションはあまりない\n",
"- AWSなどのクラウドサーバーのインスタンスでできると良さそう\n",
"- Amazon は MXNet という機械学習ライブラリを公式にサポートしており、AWSとの相性は良い\n",
" - [MXNet とは \\| AWS](https://aws.amazon.com/jp/mxnet/)\n",
" \n",
"## 今回参考にした記事とコード\n",
"- [MXNetをAmazon Deep Learning AMIとスポットインスタンスのGPUで試してみる | Developers\\.IO](https://dev.classmethod.jp/cloud/aws/mxnet-on-amazon-deep-learning-ami-with-gpu-instance/)\n",
"- [MXNetのmnistチュートリアル dmlc/mxnet\\-notebooks](https://github.com/dmlc/mxnet-notebooks/blob/master/python/tutorials/mnist.ipynb)\n",
"\n",
"## インスタンスを構築する\n",
"AWS上で行う\n",
"\n",
"## sshでログイン\n",
"\n",
"```\n",
"$ ssh -i your-key.pem [email protected] -L 8888:localhost:8888\n",
"```\n",
"\n",
"`-L` オプションを付けることで、ローカルの8888番ポートからアクセスできるようになる。"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import os\n",
"import urllib\n",
"import gzip\n",
"import struct\n",
"\n",
"# mnist データをダウンロードしてくる\n",
"def download_data(url, force_download=True): \n",
" fname = url.split(\"/\")[-1]\n",
" if force_download or not os.path.exists(fname):\n",
" urllib.urlretrieve(url, fname)\n",
" return fname\n",
"\n",
"def read_data(label_url, image_url):\n",
" with gzip.open(download_data(label_url)) as flbl:\n",
" magic, num = struct.unpack(\">II\", flbl.read(8))\n",
" label = np.fromstring(flbl.read(), dtype=np.int8)\n",
" with gzip.open(download_data(image_url), 'rb') as fimg:\n",
" magic, num, rows, cols = struct.unpack(\">IIII\", fimg.read(16))\n",
" image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)\n",
" return (label, image)\n",
"\n",
"path='http://yann.lecun.com/exdb/mnist/'\n",
"(train_lbl, train_img) = read_data(\n",
" path+'train-labels-idx1-ubyte.gz', path+'train-images-idx3-ubyte.gz')\n",
"(val_lbl, val_img) = read_data(\n",
" path+'t10k-labels-idx1-ubyte.gz', path+'t10k-images-idx3-ubyte.gz')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import mxnet as mx\n",
"\n",
"def to4d(img):\n",
" return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32)/255\n",
"\n",
"batch_size = 100\n",
"train_iter = mx.io.NDArrayIter(to4d(train_img), train_lbl, batch_size, shuffle=True)\n",
"val_iter = mx.io.NDArrayIter(to4d(val_img), val_lbl, batch_size)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## まずは CPU で MLP(Multi Layer Perceptron:多層パーセプトロン)を使った学習"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\n",
" -->\n",
"<!-- Title: plot Pages: 1 -->\n",
"<svg width=\"214pt\" height=\"829pt\"\n",
" viewBox=\"0.00 0.00 214.00 829.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 825)\">\n",
"<title>plot</title>\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-825 210,-825 210,4 -4,4\"/>\n",
"<!-- data -->\n",
"<g id=\"node1\" class=\"node\"><title>data</title>\n",
"<ellipse fill=\"#8dd3c7\" stroke=\"black\" cx=\"47\" cy=\"-29\" rx=\"47\" ry=\"29\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-25.3\" font-family=\"Times,serif\" font-size=\"14.00\">data</text>\n",
"</g>\n",
"<!-- flatten0 -->\n",
"<g id=\"node2\" class=\"node\"><title>flatten0</title>\n",
"<polygon fill=\"#fdb462\" stroke=\"black\" points=\"94,-167 -7.10543e-15,-167 -7.10543e-15,-109 94,-109 94,-167\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-134.3\" font-family=\"Times,serif\" font-size=\"14.00\">flatten0</text>\n",
"</g>\n",
"<!-- flatten0&#45;&gt;data -->\n",
"<g id=\"edge1\" class=\"edge\"><title>flatten0&#45;&gt;data</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M47,-98.5824C47,-85.2841 47,-70.632 47,-58.2967\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-108.887 42.5001,-98.887 47,-103.887 47.0001,-98.887 47.0001,-98.887 47.0001,-98.887 47,-103.887 51.5001,-98.8871 47,-108.887 47,-108.887\"/>\n",
"<text text-anchor=\"middle\" x=\"71\" y=\"-79.8\" font-family=\"Times,serif\" font-size=\"14.00\">1x28x28</text>\n",
"</g>\n",
"<!-- fc1 -->\n",
"<g id=\"node3\" class=\"node\"><title>fc1</title>\n",
"<polygon fill=\"#fb8072\" stroke=\"black\" points=\"94,-276 -7.10543e-15,-276 -7.10543e-15,-218 94,-218 94,-276\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-250.8\" font-family=\"Times,serif\" font-size=\"14.00\">FullyConnected</text>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-235.8\" font-family=\"Times,serif\" font-size=\"14.00\">128</text>\n",
"</g>\n",
"<!-- fc1&#45;&gt;flatten0 -->\n",
"<g id=\"edge2\" class=\"edge\"><title>fc1&#45;&gt;flatten0</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M47,-207.582C47,-194.284 47,-179.632 47,-167.297\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-217.887 42.5001,-207.887 47,-212.887 47.0001,-207.887 47.0001,-207.887 47.0001,-207.887 47,-212.887 51.5001,-207.887 47,-217.887 47,-217.887\"/>\n",
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-188.8\" font-family=\"Times,serif\" font-size=\"14.00\">784</text>\n",
"</g>\n",
"<!-- relu1 -->\n",
"<g id=\"node4\" class=\"node\"><title>relu1</title>\n",
"<polygon fill=\"#ffffb3\" stroke=\"black\" points=\"94,-385 -7.10543e-15,-385 -7.10543e-15,-327 94,-327 94,-385\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-359.8\" font-family=\"Times,serif\" font-size=\"14.00\">Activation</text>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-344.8\" font-family=\"Times,serif\" font-size=\"14.00\">relu</text>\n",
"</g>\n",
"<!-- relu1&#45;&gt;fc1 -->\n",
"<g id=\"edge3\" class=\"edge\"><title>relu1&#45;&gt;fc1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M47,-316.582C47,-303.284 47,-288.632 47,-276.297\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-326.887 42.5001,-316.887 47,-321.887 47.0001,-316.887 47.0001,-316.887 47.0001,-316.887 47,-321.887 51.5001,-316.887 47,-326.887 47,-326.887\"/>\n",
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-297.8\" font-family=\"Times,serif\" font-size=\"14.00\">128</text>\n",
"</g>\n",
"<!-- fc2 -->\n",
"<g id=\"node5\" class=\"node\"><title>fc2</title>\n",
"<polygon fill=\"#fb8072\" stroke=\"black\" points=\"94,-494 -7.10543e-15,-494 -7.10543e-15,-436 94,-436 94,-494\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-468.8\" font-family=\"Times,serif\" font-size=\"14.00\">FullyConnected</text>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-453.8\" font-family=\"Times,serif\" font-size=\"14.00\">64</text>\n",
"</g>\n",
"<!-- fc2&#45;&gt;relu1 -->\n",
"<g id=\"edge4\" class=\"edge\"><title>fc2&#45;&gt;relu1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M47,-425.582C47,-412.284 47,-397.632 47,-385.297\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-435.887 42.5001,-425.887 47,-430.887 47.0001,-425.887 47.0001,-425.887 47.0001,-425.887 47,-430.887 51.5001,-425.887 47,-435.887 47,-435.887\"/>\n",
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-406.8\" font-family=\"Times,serif\" font-size=\"14.00\">128</text>\n",
"</g>\n",
"<!-- relu2 -->\n",
"<g id=\"node6\" class=\"node\"><title>relu2</title>\n",
"<polygon fill=\"#ffffb3\" stroke=\"black\" points=\"94,-603 -7.10543e-15,-603 -7.10543e-15,-545 94,-545 94,-603\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-577.8\" font-family=\"Times,serif\" font-size=\"14.00\">Activation</text>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-562.8\" font-family=\"Times,serif\" font-size=\"14.00\">relu</text>\n",
"</g>\n",
"<!-- relu2&#45;&gt;fc2 -->\n",
"<g id=\"edge5\" class=\"edge\"><title>relu2&#45;&gt;fc2</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M47,-534.582C47,-521.284 47,-506.632 47,-494.297\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-544.887 42.5001,-534.887 47,-539.887 47.0001,-534.887 47.0001,-534.887 47.0001,-534.887 47,-539.887 51.5001,-534.887 47,-544.887 47,-544.887\"/>\n",
"<text text-anchor=\"middle\" x=\"54\" y=\"-515.8\" font-family=\"Times,serif\" font-size=\"14.00\">64</text>\n",
"</g>\n",
"<!-- fc3 -->\n",
"<g id=\"node7\" class=\"node\"><title>fc3</title>\n",
"<polygon fill=\"#fb8072\" stroke=\"black\" points=\"94,-712 -7.10543e-15,-712 -7.10543e-15,-654 94,-654 94,-712\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-686.8\" font-family=\"Times,serif\" font-size=\"14.00\">FullyConnected</text>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-671.8\" font-family=\"Times,serif\" font-size=\"14.00\">10</text>\n",
"</g>\n",
"<!-- fc3&#45;&gt;relu2 -->\n",
"<g id=\"edge6\" class=\"edge\"><title>fc3&#45;&gt;relu2</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M47,-643.582C47,-630.284 47,-615.632 47,-603.297\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-653.887 42.5001,-643.887 47,-648.887 47.0001,-643.887 47.0001,-643.887 47.0001,-643.887 47,-648.887 51.5001,-643.887 47,-653.887 47,-653.887\"/>\n",
"<text text-anchor=\"middle\" x=\"54\" y=\"-624.8\" font-family=\"Times,serif\" font-size=\"14.00\">64</text>\n",
"</g>\n",
"<!-- softmax_label -->\n",
"<g id=\"node8\" class=\"node\"><title>softmax_label</title>\n",
"<ellipse fill=\"#8dd3c7\" stroke=\"black\" cx=\"159\" cy=\"-683\" rx=\"47\" ry=\"29\"/>\n",
"<text text-anchor=\"middle\" x=\"159\" y=\"-679.3\" font-family=\"Times,serif\" font-size=\"14.00\">softmax_label</text>\n",
"</g>\n",
"<!-- softmax -->\n",
"<g id=\"node9\" class=\"node\"><title>softmax</title>\n",
"<polygon fill=\"#fccde5\" stroke=\"black\" points=\"170,-821 76,-821 76,-763 170,-763 170,-821\"/>\n",
"<text text-anchor=\"middle\" x=\"123\" y=\"-788.3\" font-family=\"Times,serif\" font-size=\"14.00\">softmax</text>\n",
"</g>\n",
"<!-- softmax&#45;&gt;fc3 -->\n",
"<g id=\"edge7\" class=\"edge\"><title>softmax&#45;&gt;fc3</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M97.1082,-754.547C87.3017,-740.741 76.2938,-725.243 67.0986,-712.297\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"103.032,-762.887 93.5724,-757.34 100.137,-758.811 97.2411,-754.734 97.2411,-754.734 97.2411,-754.734 100.137,-758.811 100.91,-752.128 103.032,-762.887 103.032,-762.887\"/>\n",
"<text text-anchor=\"middle\" x=\"97\" y=\"-733.8\" font-family=\"Times,serif\" font-size=\"14.00\">10</text>\n",
"</g>\n",
"<!-- softmax&#45;&gt;softmax_label -->\n",
"<g id=\"edge8\" class=\"edge\"><title>softmax&#45;&gt;softmax_label</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M135.713,-753.215C140.333,-739.483 145.463,-724.236 149.729,-711.555\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"132.459,-762.887 131.382,-751.974 134.053,-758.148 135.647,-753.409 135.647,-753.409 135.647,-753.409 134.053,-758.148 139.913,-754.844 132.459,-762.887 132.459,-762.887\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.dot.Digraph at 0x7f9c0fe18690>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create a place holder variable for the input data\n",
"data = mx.sym.Variable('data')\n",
"data = mx.sym.Flatten(data=data)\n",
"\n",
"# The first fully-connected layer\n",
"fc1 = mx.sym.FullyConnected(data=data, name='fc1', num_hidden=128)\n",
"# Apply relu to the output of the first fully-connnected layer\n",
"act1 = mx.sym.Activation(data=fc1, name='relu1', act_type=\"relu\")\n",
"\n",
"# The second fully-connected layer and the according activation function\n",
"fc2 = mx.sym.FullyConnected(data=act1, name='fc2', num_hidden = 64)\n",
"act2 = mx.sym.Activation(data=fc2, name='relu2', act_type=\"relu\")\n",
"\n",
"# The thrid fully-connected layer, note that the hidden size should be 10, which is the number of unique digits\n",
"fc3 = mx.sym.FullyConnected(data=act2, name='fc3', num_hidden=10)\n",
"# The softmax and loss layer\n",
"mlp = mx.sym.SoftmaxOutput(data=fc3, name='softmax')\n",
"\n",
"# We visualize the network structure with output size (the batch_size is ignored.)\n",
"shape = {\"data\" : (batch_size, 1, 28, 28)}\n",
"mx.viz.plot_network(symbol=mlp, shape=shape)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/lib/python2.7/dist-packages/ipykernel_launcher.py:7: DeprecationWarning: \u001b[91mmxnet.model.FeedForward has been deprecated. Please use mxnet.mod.Module instead.\u001b[0m\n",
" import sys\n",
"/usr/lib/python2.7/dist-packages/mxnet-0.11.0-py2.7.egg/mxnet/model.py:547: DeprecationWarning: \u001b[91mCalling initializer with init(str, NDArray) has been deprecated.please use init(mx.init.InitDesc(...), NDArray) instead.\u001b[0m\n",
" self.initializer(k, v)\n",
"INFO:root:Start training with [cpu(0)]\n",
"INFO:root:Epoch[0] Batch [200]\tSpeed: 39874.55 samples/sec\taccuracy=0.110900\n",
"INFO:root:Epoch[0] Batch [400]\tSpeed: 38556.08 samples/sec\taccuracy=0.112250\n",
"INFO:root:Epoch[0] Batch [600]\tSpeed: 40741.83 samples/sec\taccuracy=0.138200\n",
"INFO:root:Epoch[0] Resetting Data Iterator\n",
"INFO:root:Epoch[0] Time cost=1.530\n",
"INFO:root:Epoch[0] Validation-accuracy=0.217600\n",
"INFO:root:Epoch[1] Batch [200]\tSpeed: 39664.76 samples/sec\taccuracy=0.419750\n",
"INFO:root:Epoch[1] Batch [400]\tSpeed: 37210.13 samples/sec\taccuracy=0.747450\n",
"INFO:root:Epoch[1] Batch [600]\tSpeed: 38671.70 samples/sec\taccuracy=0.829000\n",
"INFO:root:Epoch[1] Resetting Data Iterator\n",
"INFO:root:Epoch[1] Time cost=1.569\n",
"INFO:root:Epoch[1] Validation-accuracy=0.857800\n",
"INFO:root:Epoch[2] Batch [200]\tSpeed: 37269.50 samples/sec\taccuracy=0.861850\n",
"INFO:root:Epoch[2] Batch [400]\tSpeed: 39785.30 samples/sec\taccuracy=0.888600\n",
"INFO:root:Epoch[2] Batch [600]\tSpeed: 38696.32 samples/sec\taccuracy=0.904550\n",
"INFO:root:Epoch[2] Resetting Data Iterator\n",
"INFO:root:Epoch[2] Time cost=1.566\n",
"INFO:root:Epoch[2] Validation-accuracy=0.916000\n",
"INFO:root:Epoch[3] Batch [200]\tSpeed: 37400.46 samples/sec\taccuracy=0.917950\n",
"INFO:root:Epoch[3] Batch [400]\tSpeed: 37174.52 samples/sec\taccuracy=0.930800\n",
"INFO:root:Epoch[3] Batch [600]\tSpeed: 36195.49 samples/sec\taccuracy=0.935700\n",
"INFO:root:Epoch[3] Resetting Data Iterator\n",
"INFO:root:Epoch[3] Time cost=1.635\n",
"INFO:root:Epoch[3] Validation-accuracy=0.940000\n",
"INFO:root:Epoch[4] Batch [200]\tSpeed: 36301.26 samples/sec\taccuracy=0.940950\n",
"INFO:root:Epoch[4] Batch [400]\tSpeed: 31858.09 samples/sec\taccuracy=0.949800\n",
"INFO:root:Epoch[4] Batch [600]\tSpeed: 36146.50 samples/sec\taccuracy=0.949650\n",
"INFO:root:Epoch[4] Resetting Data Iterator\n",
"INFO:root:Epoch[4] Time cost=1.742\n",
"INFO:root:Epoch[4] Validation-accuracy=0.951600\n",
"INFO:root:Epoch[5] Batch [200]\tSpeed: 36306.79 samples/sec\taccuracy=0.952900\n",
"INFO:root:Epoch[5] Batch [400]\tSpeed: 37419.27 samples/sec\taccuracy=0.959650\n",
"INFO:root:Epoch[5] Batch [600]\tSpeed: 36265.92 samples/sec\taccuracy=0.959500\n",
"INFO:root:Epoch[5] Resetting Data Iterator\n",
"INFO:root:Epoch[5] Time cost=1.647\n",
"INFO:root:Epoch[5] Validation-accuracy=0.956200\n",
"INFO:root:Epoch[6] Batch [200]\tSpeed: 36383.02 samples/sec\taccuracy=0.961200\n",
"INFO:root:Epoch[6] Batch [400]\tSpeed: 35857.33 samples/sec\taccuracy=0.966100\n",
"INFO:root:Epoch[6] Batch [600]\tSpeed: 36210.63 samples/sec\taccuracy=0.965350\n",
"INFO:root:Epoch[6] Resetting Data Iterator\n",
"INFO:root:Epoch[6] Time cost=1.670\n",
"INFO:root:Epoch[6] Validation-accuracy=0.957700\n",
"INFO:root:Epoch[7] Batch [200]\tSpeed: 36412.97 samples/sec\taccuracy=0.966700\n",
"INFO:root:Epoch[7] Batch [400]\tSpeed: 36228.08 samples/sec\taccuracy=0.970800\n",
"INFO:root:Epoch[7] Batch [600]\tSpeed: 36257.63 samples/sec\taccuracy=0.969250\n",
"INFO:root:Epoch[7] Resetting Data Iterator\n",
"INFO:root:Epoch[7] Time cost=1.663\n",
"INFO:root:Epoch[7] Validation-accuracy=0.960400\n",
"INFO:root:Epoch[8] Batch [200]\tSpeed: 36243.76 samples/sec\taccuracy=0.971000\n",
"INFO:root:Epoch[8] Batch [400]\tSpeed: 36209.45 samples/sec\taccuracy=0.974450\n",
"INFO:root:Epoch[8] Batch [600]\tSpeed: 36219.94 samples/sec\taccuracy=0.973150\n",
"INFO:root:Epoch[8] Resetting Data Iterator\n",
"INFO:root:Epoch[8] Time cost=1.666\n",
"INFO:root:Epoch[8] Validation-accuracy=0.961900\n",
"INFO:root:Epoch[9] Batch [200]\tSpeed: 36498.22 samples/sec\taccuracy=0.974500\n",
"INFO:root:Epoch[9] Batch [400]\tSpeed: 37476.41 samples/sec\taccuracy=0.977200\n",
"INFO:root:Epoch[9] Batch [600]\tSpeed: 36123.13 samples/sec\taccuracy=0.976100\n",
"INFO:root:Epoch[9] Resetting Data Iterator\n",
"INFO:root:Epoch[9] Time cost=1.645\n",
"INFO:root:Epoch[9] Validation-accuracy=0.963100\n"
]
}
],
"source": [
"import logging\n",
"logging.getLogger().setLevel(logging.DEBUG)\n",
"\n",
"model = mx.model.FeedForward(\n",
" symbol = mlp, # network structure\n",
" num_epoch = 10, # number of data passes for training \n",
" learning_rate = 0.1 # learning rate of SGD \n",
")\n",
"\n",
"model.fit(\n",
" X=train_iter, # training data\n",
" eval_data=val_iter, # validation data\n",
" batch_end_callback = mx.callback.Speedometer(batch_size, 200) # output progress for each 200 data batches\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 続いて CNN の LeNet でGPUを使って計算"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### CNNとLeNetについては下の2つを読むと理解しやすい\n",
"- [\\[ディープラーニング\\] LeNet – Tech Memo](http://tecmemo.wpblog.jp/2017/03/19/dl_lenet/)\n",
"- [定番のConvolutional Neural Networkをゼロから理解する \\- DeepAge](https://deepage.net/deep_learning/2016/11/07/convolutional_neural_network.html)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\n",
" -->\n",
"<!-- Title: plot Pages: 1 -->\n",
"<svg width=\"214pt\" height=\"1265pt\"\n",
" viewBox=\"0.00 0.00 214.00 1265.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 1261)\">\n",
"<title>plot</title>\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-1261 210,-1261 210,4 -4,4\"/>\n",
"<!-- data -->\n",
"<g id=\"node1\" class=\"node\"><title>data</title>\n",
"<ellipse fill=\"#8dd3c7\" stroke=\"black\" cx=\"47\" cy=\"-29\" rx=\"47\" ry=\"29\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-25.3\" font-family=\"Times,serif\" font-size=\"14.00\">data</text>\n",
"</g>\n",
"<!-- convolution0 -->\n",
"<g id=\"node2\" class=\"node\"><title>convolution0</title>\n",
"<polygon fill=\"#fb8072\" stroke=\"black\" points=\"94,-167 -7.10543e-15,-167 -7.10543e-15,-109 94,-109 94,-167\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-141.8\" font-family=\"Times,serif\" font-size=\"14.00\">Convolution</text>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-126.8\" font-family=\"Times,serif\" font-size=\"14.00\">5x5/1, 20</text>\n",
"</g>\n",
"<!-- convolution0&#45;&gt;data -->\n",
"<g id=\"edge1\" class=\"edge\"><title>convolution0&#45;&gt;data</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M47,-98.5824C47,-85.2841 47,-70.632 47,-58.2967\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-108.887 42.5001,-98.887 47,-103.887 47.0001,-98.887 47.0001,-98.887 47.0001,-98.887 47,-103.887 51.5001,-98.8871 47,-108.887 47,-108.887\"/>\n",
"<text text-anchor=\"middle\" x=\"71\" y=\"-79.8\" font-family=\"Times,serif\" font-size=\"14.00\">1x28x28</text>\n",
"</g>\n",
"<!-- activation0 -->\n",
"<g id=\"node3\" class=\"node\"><title>activation0</title>\n",
"<polygon fill=\"#ffffb3\" stroke=\"black\" points=\"94,-276 -7.10543e-15,-276 -7.10543e-15,-218 94,-218 94,-276\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-250.8\" font-family=\"Times,serif\" font-size=\"14.00\">Activation</text>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-235.8\" font-family=\"Times,serif\" font-size=\"14.00\">tanh</text>\n",
"</g>\n",
"<!-- activation0&#45;&gt;convolution0 -->\n",
"<g id=\"edge2\" class=\"edge\"><title>activation0&#45;&gt;convolution0</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M47,-207.582C47,-194.284 47,-179.632 47,-167.297\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-217.887 42.5001,-207.887 47,-212.887 47.0001,-207.887 47.0001,-207.887 47.0001,-207.887 47,-212.887 51.5001,-207.887 47,-217.887 47,-217.887\"/>\n",
"<text text-anchor=\"middle\" x=\"74.5\" y=\"-188.8\" font-family=\"Times,serif\" font-size=\"14.00\">20x24x24</text>\n",
"</g>\n",
"<!-- pooling0 -->\n",
"<g id=\"node4\" class=\"node\"><title>pooling0</title>\n",
"<polygon fill=\"#80b1d3\" stroke=\"black\" points=\"94,-385 -7.10543e-15,-385 -7.10543e-15,-327 94,-327 94,-385\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-359.8\" font-family=\"Times,serif\" font-size=\"14.00\">Pooling</text>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-344.8\" font-family=\"Times,serif\" font-size=\"14.00\">max, 2x2/2x2</text>\n",
"</g>\n",
"<!-- pooling0&#45;&gt;activation0 -->\n",
"<g id=\"edge3\" class=\"edge\"><title>pooling0&#45;&gt;activation0</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M47,-316.582C47,-303.284 47,-288.632 47,-276.297\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-326.887 42.5001,-316.887 47,-321.887 47.0001,-316.887 47.0001,-316.887 47.0001,-316.887 47,-321.887 51.5001,-316.887 47,-326.887 47,-326.887\"/>\n",
"<text text-anchor=\"middle\" x=\"74.5\" y=\"-297.8\" font-family=\"Times,serif\" font-size=\"14.00\">20x24x24</text>\n",
"</g>\n",
"<!-- convolution1 -->\n",
"<g id=\"node5\" class=\"node\"><title>convolution1</title>\n",
"<polygon fill=\"#fb8072\" stroke=\"black\" points=\"94,-494 -7.10543e-15,-494 -7.10543e-15,-436 94,-436 94,-494\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-468.8\" font-family=\"Times,serif\" font-size=\"14.00\">Convolution</text>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-453.8\" font-family=\"Times,serif\" font-size=\"14.00\">5x5/1, 50</text>\n",
"</g>\n",
"<!-- convolution1&#45;&gt;pooling0 -->\n",
"<g id=\"edge4\" class=\"edge\"><title>convolution1&#45;&gt;pooling0</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M47,-425.582C47,-412.284 47,-397.632 47,-385.297\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-435.887 42.5001,-425.887 47,-430.887 47.0001,-425.887 47.0001,-425.887 47.0001,-425.887 47,-430.887 51.5001,-425.887 47,-435.887 47,-435.887\"/>\n",
"<text text-anchor=\"middle\" x=\"74.5\" y=\"-406.8\" font-family=\"Times,serif\" font-size=\"14.00\">20x12x12</text>\n",
"</g>\n",
"<!-- activation1 -->\n",
"<g id=\"node6\" class=\"node\"><title>activation1</title>\n",
"<polygon fill=\"#ffffb3\" stroke=\"black\" points=\"94,-603 -7.10543e-15,-603 -7.10543e-15,-545 94,-545 94,-603\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-577.8\" font-family=\"Times,serif\" font-size=\"14.00\">Activation</text>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-562.8\" font-family=\"Times,serif\" font-size=\"14.00\">tanh</text>\n",
"</g>\n",
"<!-- activation1&#45;&gt;convolution1 -->\n",
"<g id=\"edge5\" class=\"edge\"><title>activation1&#45;&gt;convolution1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M47,-534.582C47,-521.284 47,-506.632 47,-494.297\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-544.887 42.5001,-534.887 47,-539.887 47.0001,-534.887 47.0001,-534.887 47.0001,-534.887 47,-539.887 51.5001,-534.887 47,-544.887 47,-544.887\"/>\n",
"<text text-anchor=\"middle\" x=\"67.5\" y=\"-515.8\" font-family=\"Times,serif\" font-size=\"14.00\">50x8x8</text>\n",
"</g>\n",
"<!-- pooling1 -->\n",
"<g id=\"node7\" class=\"node\"><title>pooling1</title>\n",
"<polygon fill=\"#80b1d3\" stroke=\"black\" points=\"94,-712 -7.10543e-15,-712 -7.10543e-15,-654 94,-654 94,-712\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-686.8\" font-family=\"Times,serif\" font-size=\"14.00\">Pooling</text>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-671.8\" font-family=\"Times,serif\" font-size=\"14.00\">max, 2x2/2x2</text>\n",
"</g>\n",
"<!-- pooling1&#45;&gt;activation1 -->\n",
"<g id=\"edge6\" class=\"edge\"><title>pooling1&#45;&gt;activation1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M47,-643.582C47,-630.284 47,-615.632 47,-603.297\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-653.887 42.5001,-643.887 47,-648.887 47.0001,-643.887 47.0001,-643.887 47.0001,-643.887 47,-648.887 51.5001,-643.887 47,-653.887 47,-653.887\"/>\n",
"<text text-anchor=\"middle\" x=\"67.5\" y=\"-624.8\" font-family=\"Times,serif\" font-size=\"14.00\">50x8x8</text>\n",
"</g>\n",
"<!-- flatten1 -->\n",
"<g id=\"node8\" class=\"node\"><title>flatten1</title>\n",
"<polygon fill=\"#fdb462\" stroke=\"black\" points=\"94,-821 -7.10543e-15,-821 -7.10543e-15,-763 94,-763 94,-821\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-788.3\" font-family=\"Times,serif\" font-size=\"14.00\">flatten1</text>\n",
"</g>\n",
"<!-- flatten1&#45;&gt;pooling1 -->\n",
"<g id=\"edge7\" class=\"edge\"><title>flatten1&#45;&gt;pooling1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M47,-752.582C47,-739.284 47,-724.632 47,-712.297\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-762.887 42.5001,-752.887 47,-757.887 47.0001,-752.887 47.0001,-752.887 47.0001,-752.887 47,-757.887 51.5001,-752.887 47,-762.887 47,-762.887\"/>\n",
"<text text-anchor=\"middle\" x=\"67.5\" y=\"-733.8\" font-family=\"Times,serif\" font-size=\"14.00\">50x4x4</text>\n",
"</g>\n",
"<!-- fullyconnected0 -->\n",
"<g id=\"node9\" class=\"node\"><title>fullyconnected0</title>\n",
"<polygon fill=\"#fb8072\" stroke=\"black\" points=\"94,-930 -7.10543e-15,-930 -7.10543e-15,-872 94,-872 94,-930\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-904.8\" font-family=\"Times,serif\" font-size=\"14.00\">FullyConnected</text>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-889.8\" font-family=\"Times,serif\" font-size=\"14.00\">500</text>\n",
"</g>\n",
"<!-- fullyconnected0&#45;&gt;flatten1 -->\n",
"<g id=\"edge8\" class=\"edge\"><title>fullyconnected0&#45;&gt;flatten1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M47,-861.582C47,-848.284 47,-833.632 47,-821.297\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-871.887 42.5001,-861.887 47,-866.887 47.0001,-861.887 47.0001,-861.887 47.0001,-861.887 47,-866.887 51.5001,-861.887 47,-871.887 47,-871.887\"/>\n",
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-842.8\" font-family=\"Times,serif\" font-size=\"14.00\">800</text>\n",
"</g>\n",
"<!-- activation2 -->\n",
"<g id=\"node10\" class=\"node\"><title>activation2</title>\n",
"<polygon fill=\"#ffffb3\" stroke=\"black\" points=\"94,-1039 -7.10543e-15,-1039 -7.10543e-15,-981 94,-981 94,-1039\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-1013.8\" font-family=\"Times,serif\" font-size=\"14.00\">Activation</text>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-998.8\" font-family=\"Times,serif\" font-size=\"14.00\">tanh</text>\n",
"</g>\n",
"<!-- activation2&#45;&gt;fullyconnected0 -->\n",
"<g id=\"edge9\" class=\"edge\"><title>activation2&#45;&gt;fullyconnected0</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M47,-970.582C47,-957.284 47,-942.632 47,-930.297\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-980.887 42.5001,-970.887 47,-975.887 47.0001,-970.887 47.0001,-970.887 47.0001,-970.887 47,-975.887 51.5001,-970.887 47,-980.887 47,-980.887\"/>\n",
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-951.8\" font-family=\"Times,serif\" font-size=\"14.00\">500</text>\n",
"</g>\n",
"<!-- fullyconnected1 -->\n",
"<g id=\"node11\" class=\"node\"><title>fullyconnected1</title>\n",
"<polygon fill=\"#fb8072\" stroke=\"black\" points=\"94,-1148 -7.10543e-15,-1148 -7.10543e-15,-1090 94,-1090 94,-1148\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-1122.8\" font-family=\"Times,serif\" font-size=\"14.00\">FullyConnected</text>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-1107.8\" font-family=\"Times,serif\" font-size=\"14.00\">10</text>\n",
"</g>\n",
"<!-- fullyconnected1&#45;&gt;activation2 -->\n",
"<g id=\"edge10\" class=\"edge\"><title>fullyconnected1&#45;&gt;activation2</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M47,-1079.58C47,-1066.28 47,-1051.63 47,-1039.3\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-1089.89 42.5001,-1079.89 47,-1084.89 47.0001,-1079.89 47.0001,-1079.89 47.0001,-1079.89 47,-1084.89 51.5001,-1079.89 47,-1089.89 47,-1089.89\"/>\n",
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-1060.8\" font-family=\"Times,serif\" font-size=\"14.00\">500</text>\n",
"</g>\n",
"<!-- softmax_label -->\n",
"<g id=\"node12\" class=\"node\"><title>softmax_label</title>\n",
"<ellipse fill=\"#8dd3c7\" stroke=\"black\" cx=\"159\" cy=\"-1119\" rx=\"47\" ry=\"29\"/>\n",
"<text text-anchor=\"middle\" x=\"159\" y=\"-1115.3\" font-family=\"Times,serif\" font-size=\"14.00\">softmax_label</text>\n",
"</g>\n",
"<!-- softmax -->\n",
"<g id=\"node13\" class=\"node\"><title>softmax</title>\n",
"<polygon fill=\"#fccde5\" stroke=\"black\" points=\"170,-1257 76,-1257 76,-1199 170,-1199 170,-1257\"/>\n",
"<text text-anchor=\"middle\" x=\"123\" y=\"-1224.3\" font-family=\"Times,serif\" font-size=\"14.00\">softmax</text>\n",
"</g>\n",
"<!-- softmax&#45;&gt;fullyconnected1 -->\n",
"<g id=\"edge11\" class=\"edge\"><title>softmax&#45;&gt;fullyconnected1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M97.1082,-1190.55C87.3017,-1176.74 76.2938,-1161.24 67.0986,-1148.3\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"103.032,-1198.89 93.5724,-1193.34 100.137,-1194.81 97.2411,-1190.73 97.2411,-1190.73 97.2411,-1190.73 100.137,-1194.81 100.91,-1188.13 103.032,-1198.89 103.032,-1198.89\"/>\n",
"<text text-anchor=\"middle\" x=\"97\" y=\"-1169.8\" font-family=\"Times,serif\" font-size=\"14.00\">10</text>\n",
"</g>\n",
"<!-- softmax&#45;&gt;softmax_label -->\n",
"<g id=\"edge12\" class=\"edge\"><title>softmax&#45;&gt;softmax_label</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M135.713,-1189.21C140.333,-1175.48 145.463,-1160.24 149.729,-1147.56\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"132.459,-1198.89 131.382,-1187.97 134.053,-1194.15 135.647,-1189.41 135.647,-1189.41 135.647,-1189.41 134.053,-1194.15 139.913,-1190.84 132.459,-1198.89 132.459,-1198.89\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.dot.Digraph at 0x7f9c0ef2e050>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# LeNet-style convolutional network for MNIST, built with the symbolic API.\n",
"# `mx.sym` (the short alias of `mx.symbol`) is used consistently below.\n",
"data = mx.sym.Variable('data')  # placeholder for the input batch\n",
"# first conv layer: 5x5 convolution (20 filters) -> tanh -> 2x2 max-pool, stride 2\n",
"conv1 = mx.sym.Convolution(data=data, kernel=(5,5), num_filter=20)\n",
"tanh1 = mx.sym.Activation(data=conv1, act_type=\"tanh\")\n",
"pool1 = mx.sym.Pooling(data=tanh1, pool_type=\"max\", kernel=(2,2), stride=(2,2))\n",
"# second conv layer: same structure, 50 filters\n",
"conv2 = mx.sym.Convolution(data=pool1, kernel=(5,5), num_filter=50)\n",
"tanh2 = mx.sym.Activation(data=conv2, act_type=\"tanh\")\n",
"pool2 = mx.sym.Pooling(data=tanh2, pool_type=\"max\", kernel=(2,2), stride=(2,2))\n",
"# first fully connected layer: flatten the feature maps, then 500 hidden units + tanh\n",
"flatten = mx.sym.Flatten(data=pool2)\n",
"fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=500)\n",
"tanh3 = mx.sym.Activation(data=fc1, act_type=\"tanh\")\n",
"# second fully connected layer: 10 outputs\n",
"fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10)\n",
"# softmax loss; naming it 'softmax' gives the label input the name 'softmax_label'\n",
"lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax')\n",
"# last expression of the cell: the network graph is rendered as the cell output\n",
"# (`shape` is defined in an earlier cell)\n",
"mx.viz.plot_network(symbol=lenet, shape=shape)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/lib/python2.7/dist-packages/ipykernel_launcher.py:5: DeprecationWarning: \u001b[91mmxnet.model.FeedForward has been deprecated. Please use mxnet.mod.Module instead.\u001b[0m\n",
" \"\"\"\n",
"INFO:root:Start training with [gpu(0)]\n",
"INFO:root:Epoch[0] Batch [200]\tSpeed: 22201.87 samples/sec\taccuracy=0.111250\n",
"INFO:root:Epoch[0] Batch [400]\tSpeed: 25587.48 samples/sec\taccuracy=0.112250\n",
"INFO:root:Epoch[0] Batch [600]\tSpeed: 25564.46 samples/sec\taccuracy=0.113400\n",
"INFO:root:Epoch[0] Resetting Data Iterator\n",
"INFO:root:Epoch[0] Time cost=2.476\n",
"INFO:root:Epoch[0] Validation-accuracy=0.113500\n",
"INFO:root:Epoch[1] Batch [200]\tSpeed: 25543.51 samples/sec\taccuracy=0.144900\n",
"INFO:root:Epoch[1] Batch [400]\tSpeed: 25507.08 samples/sec\taccuracy=0.779950\n",
"INFO:root:Epoch[1] Batch [600]\tSpeed: 25530.16 samples/sec\taccuracy=0.914950\n",
"INFO:root:Epoch[1] Resetting Data Iterator\n",
"INFO:root:Epoch[1] Time cost=2.358\n",
"INFO:root:Epoch[1] Validation-accuracy=0.937700\n",
"INFO:root:Epoch[2] Batch [200]\tSpeed: 25617.25 samples/sec\taccuracy=0.942600\n",
"INFO:root:Epoch[2] Batch [400]\tSpeed: 25514.86 samples/sec\taccuracy=0.956050\n",
"INFO:root:Epoch[2] Batch [600]\tSpeed: 25445.62 samples/sec\taccuracy=0.965300\n",
"INFO:root:Epoch[2] Resetting Data Iterator\n",
"INFO:root:Epoch[2] Time cost=2.361\n",
"INFO:root:Epoch[2] Validation-accuracy=0.966800\n",
"INFO:root:Epoch[3] Batch [200]\tSpeed: 25647.79 samples/sec\taccuracy=0.970050\n",
"INFO:root:Epoch[3] Batch [400]\tSpeed: 25591.42 samples/sec\taccuracy=0.974750\n",
"INFO:root:Epoch[3] Batch [600]\tSpeed: 25567.90 samples/sec\taccuracy=0.977100\n",
"INFO:root:Epoch[3] Resetting Data Iterator\n",
"INFO:root:Epoch[3] Time cost=2.354\n",
"INFO:root:Epoch[3] Validation-accuracy=0.977400\n",
"INFO:root:Epoch[4] Batch [200]\tSpeed: 25633.57 samples/sec\taccuracy=0.977800\n",
"INFO:root:Epoch[4] Batch [400]\tSpeed: 25569.72 samples/sec\taccuracy=0.981450\n",
"INFO:root:Epoch[4] Batch [600]\tSpeed: 25546.67 samples/sec\taccuracy=0.983600\n",
"INFO:root:Epoch[4] Resetting Data Iterator\n",
"INFO:root:Epoch[4] Time cost=2.353\n",
"INFO:root:Epoch[4] Validation-accuracy=0.983600\n",
"INFO:root:Epoch[5] Batch [200]\tSpeed: 25620.73 samples/sec\taccuracy=0.982650\n",
"INFO:root:Epoch[5] Batch [400]\tSpeed: 25531.73 samples/sec\taccuracy=0.985600\n",
"INFO:root:Epoch[5] Batch [600]\tSpeed: 25566.39 samples/sec\taccuracy=0.986900\n",
"INFO:root:Epoch[5] Resetting Data Iterator\n",
"INFO:root:Epoch[5] Time cost=2.354\n",
"INFO:root:Epoch[5] Validation-accuracy=0.985200\n",
"INFO:root:Epoch[6] Batch [200]\tSpeed: 25608.59 samples/sec\taccuracy=0.985700\n",
"INFO:root:Epoch[6] Batch [400]\tSpeed: 25532.58 samples/sec\taccuracy=0.987850\n",
"INFO:root:Epoch[6] Batch [600]\tSpeed: 25603.90 samples/sec\taccuracy=0.988800\n",
"INFO:root:Epoch[6] Resetting Data Iterator\n",
"INFO:root:Epoch[6] Time cost=2.353\n",
"INFO:root:Epoch[6] Validation-accuracy=0.986200\n",
"INFO:root:Epoch[7] Batch [200]\tSpeed: 25635.70 samples/sec\taccuracy=0.987600\n",
"INFO:root:Epoch[7] Batch [400]\tSpeed: 25556.92 samples/sec\taccuracy=0.990150\n",
"INFO:root:Epoch[7] Batch [600]\tSpeed: 25571.29 samples/sec\taccuracy=0.990400\n",
"INFO:root:Epoch[7] Resetting Data Iterator\n",
"INFO:root:Epoch[7] Time cost=2.353\n",
"INFO:root:Epoch[7] Validation-accuracy=0.987300\n",
"INFO:root:Epoch[8] Batch [200]\tSpeed: 25571.07 samples/sec\taccuracy=0.989600\n",
"INFO:root:Epoch[8] Batch [400]\tSpeed: 25529.62 samples/sec\taccuracy=0.991650\n",
"INFO:root:Epoch[8] Batch [600]\tSpeed: 25497.78 samples/sec\taccuracy=0.991800\n",
"INFO:root:Epoch[8] Resetting Data Iterator\n",
"INFO:root:Epoch[8] Time cost=2.358\n",
"INFO:root:Epoch[8] Validation-accuracy=0.987800\n",
"INFO:root:Epoch[9] Batch [200]\tSpeed: 25641.87 samples/sec\taccuracy=0.990700\n",
"INFO:root:Epoch[9] Batch [400]\tSpeed: 25539.56 samples/sec\taccuracy=0.992750\n",
"INFO:root:Epoch[9] Batch [600]\tSpeed: 25563.02 samples/sec\taccuracy=0.992850\n",
"INFO:root:Epoch[9] Resetting Data Iterator\n",
"INFO:root:Epoch[9] Time cost=2.356\n",
"INFO:root:Epoch[9] Validation-accuracy=0.988000\n"
]
}
],
"source": [
"# Train LeNet on GPU 0.\n",
"# mx.model.FeedForward is deprecated (MXNet emits a DeprecationWarning for it),\n",
"# so mx.mod.Module is used instead with the same settings:\n",
"# SGD, learning rate 0.1, 10 epochs, accuracy as the metric.\n",
"model = mx.mod.Module(symbol=lenet, context=mx.gpu(0))\n",
"model.fit(\n",
"    train_data=train_iter,\n",
"    eval_data=val_iter,\n",
"    eval_metric='acc',\n",
"    optimizer='sgd',\n",
"    optimizer_params={'learning_rate': 0.1},\n",
"    num_epoch=10,\n",
"    # log throughput and running accuracy every 200 batches\n",
"    batch_end_callback=mx.callback.Speedometer(batch_size, 200)\n",
")\n",
"# Module.score returns a list of (metric_name, value) pairs rather than a bare\n",
"# float, so extract the accuracy explicitly before comparing it.\n",
"val_acc = dict(model.score(val_iter, 'acc'))['accuracy']\n",
"assert val_acc > 0.98, \"Low validation accuracy.\""
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment