CNN for my project
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import tensorflow as tf"
]
},
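{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# The training cell below refers to trX, trY, teX, teY, mnist, mnist_test_images and batch_size,\n",
"# which are not defined anywhere in this gist. This cell is a minimal sketch of how they could be\n",
"# prepared, assuming the standard MNIST loader from tensorflow.examples.tutorials.mnist; the\n",
"# project's actual data pipeline may differ.\n",
"from tensorflow.examples.tutorials.mnist import input_data\n",
"\n",
"mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)\n",
"\n",
"# Reshape the flat 784-pixel vectors into the 28x28x1 images expected by the conv layers.\n",
"trX = mnist.train.images.reshape(-1, 28, 28, 1)\n",
"trY = mnist.train.labels\n",
"teX = mnist.validation.images.reshape(-1, 28, 28, 1)\n",
"teY = mnist.validation.labels\n",
"mnist_test_images = mnist.test.images.reshape(-1, 28, 28, 1)\n",
"\n",
"batch_size = 128"
]
},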
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def model(X, w, w3, w4, w_o, p_keep_conv, p_keep_hidden):\n",
" l1a = tf.nn.relu(tf.nn.conv2d(X, w, strides=[1, 1, 1, 1], padding='SAME') + b1)\n",
" l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')\n",
" l1 = tf.nn.dropout(l1, p_keep_conv)\n",
"\n",
" l3a = tf.nn.relu(tf.nn.conv2d(l1, w3, strides=[1, 1, 1, 1], padding='SAME') + b3)\n",
" l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')\n",
" l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]])\n",
" l3 = tf.nn.dropout(l3, p_keep_conv)\n",
"\n",
" l4 = tf.nn.relu(tf.matmul(l3, w4) + b4)\n",
" l4 = tf.nn.dropout(l4, p_keep_hidden)\n",
"\n",
" pyx = tf.matmul(l4, w_o) + b5\n",
" return pyx"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"tf.reset_default_graph()\n",
"\n",
"init_op = tf.global_variables_initializer()\n",
"\n",
"X = tf.placeholder(\"float\", [None, 28, 28, 1])\n",
"Y = tf.placeholder(\"float\", [None, 10])\n",
"\n",
"w = tf.get_variable(\"w\", shape=[4, 4, 1, 16], initializer=tf.contrib.layers.xavier_initializer())\n",
"b1 = tf.get_variable(name=\"b1\", shape=[16], initializer=tf.zeros_initializer())\n",
"w3 = tf.get_variable(\"w3\", shape=[4, 4, 16, 32], initializer=tf.contrib.layers.xavier_initializer())\n",
"b3 = tf.get_variable(name=\"b3\", shape=[32], initializer=tf.zeros_initializer())\n",
"w4 = tf.get_variable(\"w4\", shape=[32 * 7 * 7, 625], initializer=tf.contrib.layers.xavier_initializer())\n",
"b4 = tf.get_variable(name=\"b4\", shape=[625], initializer=tf.zeros_initializer())\n",
"w_o = tf.get_variable(\"w_o\", shape=[625, 10], initializer=tf.contrib.layers.xavier_initializer())\n",
"b5 = tf.get_variable(name=\"b5\", shape=[10], initializer=tf.zeros_initializer())\n",
"\n",
"p_keep_conv = tf.placeholder(\"float\")\n",
"p_keep_hidden = tf.placeholder(\"float\")\n",
"py_x = model(X, w, w3, w4, w_o, p_keep_conv, p_keep_hidden)\n",
"\n",
"reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)\n",
"reg_constant = 0.01\n",
"\n",
"cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y) + reg_constant * sum(reg_losses))\n",
"\n",
"train_op = tf.train.RMSPropOptimizer(0.0001, 0.9).minimize(cost)\n",
"predict_op = tf.argmax(py_x, 1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"#Training\n",
"train_acc = []\n",
"val_acc = []\n",
"test_acc = []\n",
"train_loss = []\n",
"val_loss = []\n",
"test_loss = []\n",
"with tf.Session() as sess:\n",
" tf.global_variables_initializer().run()\n",
"\n",
" for i in range(256):\n",
" training_batch = zip(range(0, len(trX), batch_size),\n",
" range(batch_size, len(trX)+1, batch_size))\n",
" for start, end in training_batch:\n",
" sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],\n",
" p_keep_conv: 0.8, p_keep_hidden: 0.5})\n",
"\n",
" train_acc = np.mean(np.argmax(trY, axis=1) ==\n",
" sess.run(predict_op, feed_dict={X: trX,\n",
" Y: trY,\n",
" p_keep_conv: 1.0,\n",
" p_keep_hidden: 1.0}))\n",
" train_acc.append(train_acc)\n",
" \n",
" val_acc = np.mean(np.argmax(teY, axis=1) ==\n",
" sess.run(predict_op, feed_dict={X: teX,\n",
" Y: teY,\n",
" p_keep_conv: 1.0,\n",
" p_keep_hidden: 1.0}))\n",
" val_acc.append(val_acc)\n",
" test_acc = np.mean(np.argmax(mnist.test.labels, axis=1) ==\n",
" sess.run(predict_op, feed_dict={X: mnist_test_images,\n",
" Y: mnist.test.labels,\n",
" p_keep_conv: 1.0,\n",
" p_keep_hidden: 1.0}))\n",
" test_acc.append(test_acc)\n",
" print('Step {0}. Train accuracy: {3}. Validation accuracy: {1}. Test accuracy: {2}.'.format(i, val_acc, test_acc, train_acc))\n",
" \n",
" _, loss_train = sess.run([predict_op, cost],\n",
" feed_dict={X: trX,\n",
" Y: trY,\n",
" p_keep_conv: 1.0,\n",
" p_keep_hidden: 1.0})\n",
" train_loss.append(loss_train)\n",
" _, loss_val = sess.run([predict_op, cost],\n",
" feed_dict={X: teX,\n",
" Y: teY,\n",
" p_keep_conv: 1.0,\n",
" p_keep_hidden: 1.0})\n",
" val_loss.append(loss_val)\n",
" _, loss_test = sess.run([predict_op, cost],\n",
" feed_dict={X: mnist_test_images,\n",
" Y: mnist.test.labels,\n",
" p_keep_conv: 1.0,\n",
" p_keep_hidden: 1.0})\n",
" test_loss.append(loss_test)\n",
" print('Train loss: {0}. Validation loss: {1}. Test loss: {2}.'.format(loss_train, loss_val, loss_test))\n",
" \n",
" all_saver = tf.train.Saver() \n",
" all_saver.save(sess, '/resources/data.chkp')"
]
},
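{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# The loss and accuracy lists filled during training are never used elsewhere in this gist.\n",
"# This is a minimal sketch of how the loss curves could be visualised, assuming matplotlib is\n",
"# available; it is not part of the original notebook.\n",
"import matplotlib.pyplot as plt\n",
"\n",
"plt.figure(figsize=(10, 4))\n",
"plt.plot(train_loss, label='train')\n",
"plt.plot(val_loss, label='validation')\n",
"plt.plot(test_loss, label='test')\n",
"plt.xlabel('epoch')\n",
"plt.ylabel('loss')\n",
"plt.legend()\n",
"plt.show()"
]
},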
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"#Predicting\n",
"with tf.Session() as sess:\n",
" saver = tf.train.Saver()\n",
" saver.restore(sess, \"./data.chkp\")\n",
" pr = sess.run(predict_op, feed_dict={X: mnist_test_images,\n",
" Y: mnist.test.labels,\n",
" p_keep_conv: 1.0,\n",
" p_keep_hidden: 1.0})\n",
" print(np.mean(np.argmax(mnist.test.labels, axis=1) ==\n",
" sess.run(predict_op, feed_dict={X: mnist_test_images,\n",
" Y: mnist.test.labels,\n",
" p_keep_conv: 1.0,\n",
" p_keep_hidden: 1.0})))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [35_for_DL]",
"language": "python",
"name": "Python [35_for_DL]"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}