Last active
November 28, 2017 15:00
-
-
Save nulledge/68a6b9a27a2e8c140b3cb1b01caa2df2 to your computer and use it in GitHub Desktop.
MNIST toy project, run on a Quadro M6000.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/nulledge/pyenv/tf-36/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6\n", | |
" return f(*args, **kwds)\n" | |
] | |
} | |
], | |
"source": [ | |
"'''Load required modules.\n", | |
"Modules:\n", | |
" tqdm: A visualizing tool for loop.\n", | |
" tensorflow: A framework for machine learning.\n", | |
" numpy: An array utility.\n", | |
"'''\n", | |
"from tqdm import tqdm as tqdm\n", | |
"import tensorflow as tf\n", | |
"import numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting /home/nulledge/data/MNIST/train-images-idx3-ubyte.gz\n", | |
"Extracting /home/nulledge/data/MNIST/train-labels-idx1-ubyte.gz\n", | |
"Extracting /home/nulledge/data/MNIST/t10k-images-idx3-ubyte.gz\n", | |
"Extracting /home/nulledge/data/MNIST/t10k-labels-idx1-ubyte.gz\n" | |
] | |
} | |
], | |
"source": [ | |
"'''Trainig data from TensorFlow.\n", | |
"'''\n", | |
"from tensorflow.examples.tutorials.mnist import input_data as mnist_data\n", | |
"mnist = mnist_data.read_data_sets(\"/home/nulledge/data/MNIST/\", one_hot=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"'''Define flags.\n", | |
"'''\n", | |
"flags = tf.app.flags\n", | |
"flags.DEFINE_integer('epoches', 10, 'The number of train epoches.')\n", | |
"flags.DEFINE_integer('mnist_train', 10000, 'The number of train dataset.')\n", | |
"flags.DEFINE_integer('mnist_test', 5000, 'The number of test dataset.')\n", | |
"flags.DEFINE_integer('batch', 25, 'The batch size.')\n", | |
"flags.DEFINE_integer('labels', 10, 'The number of labels.')\n", | |
"flags.DEFINE_integer('resolution', 28*28, 'The resolution of input image in flatten shape.')\n", | |
"flags.DEFINE_boolean('phase', False, 'Whether train mode or not.')\n", | |
"flags.DEFINE_string('checkpoint', '/home/nulledge/ckpt/MNIST/mnist.ckpt', 'The path of checkpoint.')\n", | |
"flags.DEFINE_string('summary', '/home/nulledge/log/MNIST/', 'The path of log.')\n", | |
"\n", | |
"FLAGS = flags.FLAGS" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"'''Placeholders\n", | |
"\n", | |
"Tensors:\n", | |
" images: The input images of MNIST in flatten shape.\n", | |
" labels_gt: The groundtruth labels of MNIST in one-hot encoding.\n", | |
" phase: The boolean tensor for batch-normalization is_training parameter.\n", | |
"'''\n", | |
"with tf.variable_scope('placeholder'):\n", | |
" images = tf.placeholder(\n", | |
" name = 'images',\n", | |
" shape = [None, FLAGS.resolution],\n", | |
" dtype = tf.float32\n", | |
" )\n", | |
" labels_gt = tf.placeholder(\n", | |
" name = 'labels_gt',\n", | |
" shape = [None, FLAGS.labels],\n", | |
" dtype = tf.float32\n", | |
" )\n", | |
" phase = tf.placeholder(\n", | |
" name = 'train',\n", | |
" shape = (),\n", | |
" dtype = tf.bool\n", | |
" )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"'''Build the network.\n", | |
"\n", | |
"Input:\n", | |
" images: The input tensor from the placeholder scope.\n", | |
"\n", | |
"Structure:\n", | |
" layer_01\n", | |
" fully_connected\n", | |
" batch_normalization\n", | |
" relu\n", | |
" layer_02\n", | |
" fully_connected\n", | |
" batch_normalization\n", | |
" relu\n", | |
"\n", | |
"Output:\n", | |
" logits: The logits to be through softmax in shape [None, 10]\n", | |
"'''\n", | |
"net = images\n", | |
"\n", | |
"with tf.variable_scope('layer_01'):\n", | |
" net = tf.contrib.layers.fully_connected(net, 100, activation_fn = None, scope = 'fc')\n", | |
" net = tf.contrib.layers.batch_norm(net, center = True, scale = True, is_training = phase, scope = 'bn')\n", | |
" net = tf.nn.relu(net, 'relu')\n", | |
" \n", | |
"with tf.variable_scope('layer_02'):\n", | |
" net = tf.contrib.layers.fully_connected(net, 100, activation_fn = None, scope = 'fc')\n", | |
" net = tf.contrib.layers.batch_norm(net, center = True, scale = True, is_training = phase, scope = 'bn')\n", | |
" net = tf.nn.relu(net, 'relu')\n", | |
"\n", | |
"with tf.variable_scope('predict'):\n", | |
" logits = tf.contrib.layers.fully_connected(net, FLAGS.labels, activation_fn = None, scope = 'logits')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"'''Build the optional network.\n", | |
"\n", | |
"Tensor:\n", | |
" accuracy: The percentage of right classifications.\n", | |
" loss: The cross-entropy between groundtruth and predicted labels.\n", | |
" optimizer: The optimizer over loss.\n", | |
" \n", | |
" summary_accuracy:\n", | |
" summary_loss:\n", | |
" summary_merged: The log to be saved.\n", | |
"'''\n", | |
"with tf.name_scope('accuracy'):\n", | |
" accuracy = tf.reduce_mean(\n", | |
" tf.cast(\n", | |
" tf.equal(\n", | |
" tf.argmax(labels_gt, 1),\n", | |
" tf.argmax(logits, 1)\n", | |
" ),\n", | |
" dtype = tf.float32\n", | |
" )\n", | |
" )\n", | |
" summary_accuracy = tf.summary.scalar('accuracy', accuracy)\n", | |
" \n", | |
"with tf.name_scope('loss'):\n", | |
" loss = tf.reduce_mean(\n", | |
" tf.nn.softmax_cross_entropy_with_logits(\n", | |
" logits = logits,\n", | |
" labels = labels_gt\n", | |
" )\n", | |
" )\n", | |
" summary_loss = tf.summary.scalar('loss', loss)\n", | |
" \n", | |
"summary_merged = tf.summary.merge_all()\n", | |
" \n", | |
"update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)\n", | |
"with tf.control_dependencies(update_ops):\n", | |
" optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(loss)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Restoring parameters from /home/nulledge/ckpt/MNIST/mnist.ckpt\n", | |
"Load failed. Initialize variables.\n" | |
] | |
} | |
], | |
"source": [ | |
"'''Build graph.\n", | |
"\n", | |
"Open session and load from FLAGS.checkpoint. If failed to load then initialize\n", | |
"all variables in the graph.\n", | |
"\n", | |
"Tensors:\n", | |
" sess: The session to interact.\n", | |
" saver: The saver which saves and loads the checkpoint in FLAGS.checkpoint.\n", | |
" writer: The log file writer.\n", | |
"'''\n", | |
"sess = tf.InteractiveSession()\n", | |
"saver = tf.train.Saver()\n", | |
"writer = tf.summary.FileWriter(FLAGS.summary, sess.graph)\n", | |
"\n", | |
"try:\n", | |
" saver.restore(sess, FLAGS.checkpoint)\n", | |
"except:\n", | |
" print('Load failed. Initialize variables.')\n", | |
" sess.run(tf.global_variables_initializer())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"epoch(1/10): 100%|██████████| 10000/10000 [00:01<00:00, 5812.23it/s, accuracy=0.96]\n", | |
"epoch(2/10): 100%|██████████| 10000/10000 [00:01<00:00, 9009.94it/s, accuracy=0.88]\n", | |
"epoch(3/10): 100%|██████████| 10000/10000 [00:01<00:00, 8882.72it/s, accuracy=0.92]\n", | |
"epoch(4/10): 100%|██████████| 10000/10000 [00:01<00:00, 9003.28it/s, accuracy=0.92]\n", | |
"epoch(5/10): 100%|██████████| 10000/10000 [00:01<00:00, 8941.69it/s, accuracy=0.92]\n", | |
"epoch(6/10): 100%|██████████| 10000/10000 [00:01<00:00, 8502.59it/s, accuracy=1] \n", | |
"epoch(7/10): 100%|██████████| 10000/10000 [00:01<00:00, 9013.86it/s, accuracy=0.92]\n", | |
"epoch(8/10): 100%|██████████| 10000/10000 [00:01<00:00, 8821.30it/s, accuracy=0.96]\n", | |
"epoch(9/10): 100%|██████████| 10000/10000 [00:01<00:00, 9028.97it/s, accuracy=0.96]\n", | |
"epoch(10/10): 100%|██████████| 10000/10000 [00:01<00:00, 9018.71it/s, accuracy=0.92]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"save path: /home/nulledge/ckpt/MNIST/mnist.ckpt\n" | |
] | |
} | |
], | |
"source": [ | |
"'''Train and save the network.\n", | |
"'''\n", | |
"idx = 0\n", | |
"for epoch in range(FLAGS.epoches):\n", | |
" train_iterator = tqdm(total = FLAGS.mnist_train)\n", | |
" train_iterator.set_description('epoch(' + str(epoch+1) + '/' + str(FLAGS.epoches) + ')')\n", | |
" for _ in range(FLAGS.mnist_train // FLAGS.batch):\n", | |
" train_images, train_labels = mnist.train.next_batch(FLAGS.batch)\n", | |
" _, train_accuracy, train_summary = sess.run([optimizer, accuracy, summary_merged],\n", | |
" feed_dict = {\n", | |
" images: train_images,\n", | |
" labels_gt: train_labels,\n", | |
" phase: True\n", | |
" })\n", | |
" train_iterator.set_postfix(accuracy = train_accuracy)\n", | |
" train_iterator.update(FLAGS.batch)\n", | |
" writer.add_summary(train_summary, idx)\n", | |
" idx += 1\n", | |
" train_iterator.close()\n", | |
"\n", | |
"saved_path = saver.save(sess, FLAGS.checkpoint)\n", | |
"print('save path:', saved_path)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"test: 100%|██████████| 5000/5000 [00:00<00:00, 28902.36it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mean accuracy: 0.9572\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"'''Test the network.\n", | |
"'''\n", | |
"\n", | |
"test_result = []\n", | |
"test_iterator = tqdm(total = FLAGS.mnist_test, desc = 'test')\n", | |
"for index in range(FLAGS.mnist_test // FLAGS.batch):\n", | |
" test_images, test_labels = mnist.test.next_batch(FLAGS.batch)\n", | |
" test_accuracy = sess.run([accuracy],\n", | |
" feed_dict = {\n", | |
" images: test_images,\n", | |
" labels_gt: test_labels,\n", | |
" phase: False\n", | |
" })\n", | |
" test_iterator.update(FLAGS.batch)\n", | |
" test_result.append(test_accuracy)\n", | |
"test_iterator.close()\n", | |
"\n", | |
"print('mean accuracy:', np.mean(test_result))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "tf-36", | |
"language": "python", | |
"name": "th-36" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
tensorboard result.