Last active
November 6, 2016 09:56
-
-
Save akimach/f1ffba2d1bf9c88e564322ad5a1fc90d to your computer and use it in GitHub Desktop.
Command-line version of the "MNIST double layer CNN classification" notebook: https://github.com/enakai00/jupyter_tfbook/blob/master/Chapter05/MNIST%20double%20layer%20CNN%20classification.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
# coding: utf-8
#
# Command-line export of the "MNIST double layer CNN classification" notebook.
# Builds a network of two convolution + max-pooling stages, one fully
# connected layer with dropout, and a softmax output; trains it on MNIST;
# and every 500 steps reports test-set loss/accuracy and saves a checkpoint.

# **[CNN-01]** Import the required modules and set the random seeds.

# In[1]:

import time

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

# Seed both NumPy's and TensorFlow's random number generators.
np.random.seed(20160704)
tf.set_random_seed(20160704)


# **[CNN-02]** Prepare the MNIST dataset.

# In[2]:

# Labels are requested in one-hot form; data files live under /tmp/data/.
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)


# **[CNN-03]** Define the first-stage convolution filters and pooling layer.

# In[3]:

num_filters1 = 32

# Each input row holds 784 values; it is reshaped to a 28x28 single-channel
# image so it can be fed to conv2d.
x = tf.placeholder(tf.float32, [None, 784])
x_image = tf.reshape(x, [-1, 28, 28, 1])

# 5x5 filters: 1 input channel -> num_filters1 output channels.
W_conv1 = tf.Variable(tf.truncated_normal([5, 5, 1, num_filters1],
                                          stddev=0.1))
h_conv1 = tf.nn.conv2d(x_image, W_conv1,
                       strides=[1, 1, 1, 1], padding='SAME')

b_conv1 = tf.Variable(tf.constant(0.1, shape=[num_filters1]))
h_conv1_cutoff = tf.nn.relu(h_conv1 + b_conv1)

# 2x2 max pooling with stride 2 halves each spatial dimension (28 -> 14).
h_pool1 = tf.nn.max_pool(h_conv1_cutoff, ksize=[1, 2, 2, 1],
                         strides=[1, 2, 2, 1], padding='SAME')


# **[CNN-04]** Define the second-stage convolution filters and pooling layer.

# In[4]:

num_filters2 = 64

# 5x5 filters: num_filters1 input channels -> num_filters2 output channels.
W_conv2 = tf.Variable(
    tf.truncated_normal([5, 5, num_filters1, num_filters2],
                        stddev=0.1))
h_conv2 = tf.nn.conv2d(h_pool1, W_conv2,
                       strides=[1, 1, 1, 1], padding='SAME')

b_conv2 = tf.Variable(tf.constant(0.1, shape=[num_filters2]))
h_conv2_cutoff = tf.nn.relu(h_conv2 + b_conv2)

# Second 2x2/stride-2 pooling halves the spatial size again (14 -> 7).
h_pool2 = tf.nn.max_pool(h_conv2_cutoff, ksize=[1, 2, 2, 1],
                         strides=[1, 2, 2, 1], padding='SAME')


# **[CNN-05]** Define the fully connected layer, the dropout layer and the
# softmax function.

# In[5]:

# Flatten the 7x7 feature maps of all num_filters2 channels into one vector
# per example.
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*num_filters2])

num_units1 = 7*7*num_filters2
num_units2 = 1024

w2 = tf.Variable(tf.truncated_normal([num_units1, num_units2]))
b2 = tf.Variable(tf.constant(0.1, shape=[num_units2]))
hidden2 = tf.nn.relu(tf.matmul(h_pool2_flat, w2) + b2)

# keep_prob is fed at run time: 0.5 while training, 1.0 while evaluating.
keep_prob = tf.placeholder(tf.float32)
hidden2_drop = tf.nn.dropout(hidden2, keep_prob)

# Output layer: 10 classes, weights and biases start at zero.
w0 = tf.Variable(tf.zeros([num_units2, 10]))
b0 = tf.Variable(tf.zeros([10]))
p = tf.nn.softmax(tf.matmul(hidden2_drop, w0) + b0)


# **[CNN-06]** Define the error function `loss`, the training algorithm
# `train_step`, and the accuracy `accuracy`.

# In[6]:

t = tf.placeholder(tf.float32, [None, 10])

# Cross-entropy summed (not averaged) over the batch.
# NOTE(review): tf.log(p) becomes -inf if a probability underflows to 0,
# which would turn the loss into NaN; left unchanged to keep the original
# training numerics.
loss = -tf.reduce_sum(t * tf.log(p))
train_step = tf.train.AdamOptimizer(0.0001).minimize(loss)

# Fraction of examples whose highest-probability class matches the label.
correct_prediction = tf.equal(tf.argmax(p, 1), tf.argmax(t, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


# **[CNN-07]** Prepare a session and initialize the Variables.

# In[7]:

sess = tf.Session()
sess.run(tf.initialize_all_variables())
saver = tf.train.Saver()


# **[CNN-08]** Repeat the parameter optimization 20000 times.
#
# In the end, an accuracy of about 99% is obtained on the test set.

# In[8]:

start_time = time.time()

# The test set is evaluated in equal slices rather than in a single run
# (presumably to bound memory use).  Floor division keeps the slice bounds
# integers on Python 3 as well as Python 2.
num_eval_chunks = 4
eval_chunk_size = len(mnist.test.labels) // num_eval_chunks

# `i` is the 1-based training step, so the final checkpoint is step 20000.
for i in range(1, 20001):
    batch_xs, batch_ts = mnist.train.next_batch(50)
    sess.run(train_step,
             feed_dict={x: batch_xs, t: batch_ts, keep_prob: 0.5})
    if i % 500 == 0:
        loss_vals, acc_vals = [], []
        for c in range(num_eval_chunks):
            start = eval_chunk_size * c
            end = eval_chunk_size * (c + 1)
            loss_val, acc_val = sess.run(
                [loss, accuracy],
                feed_dict={x: mnist.test.images[start:end],
                           t: mnist.test.labels[start:end],
                           keep_prob: 1.0})
            loss_vals.append(loss_val)
            acc_vals.append(acc_val)
        # `loss` is a sum, so per-slice losses add up; the slices are the
        # same size, so the mean of per-slice accuracies is the overall
        # accuracy over the evaluated examples.
        loss_val = np.sum(loss_vals)
        acc_val = np.mean(acc_vals)
        print('Step: %d, Loss: %f, Accuracy: %f'
              % (i, loss_val, acc_val))
        saver.save(sess, 'cnn_session', global_step=i)

end_time = time.time()
print("Run time %i" % (end_time - start_time))


# **[CNN-09]** Confirm that the files holding the saved session information
# have been generated.

# In[9]:

# get_ipython().system('ls cnn_session*')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment