Skip to content

Instantly share code, notes, and snippets.

@julie-is-late
Created September 26, 2016 14:21
Show Gist options
  • Save julie-is-late/d1bf2571754b45d9a26765de7ac26c31 to your computer and use it in GitHub Desktop.
Save julie-is-late/d1bf2571754b45d9a26765de7ac26c31 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"['train_labels', 'train_images']"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# load data\n",
"train = np.load('../data/MNIST_train_100.npz')\n",
"train.keys()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"(100, 28, 28)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train['train_images'].shape"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAHoAAAB6CAYAAABwWUfkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAADwpJREFUeJztnVlT20gXhl/JixZLsrBsvAEBHCaZSuYiVfn/VXM/F1OV\nDAmEOAYM3mVblrxIlvVdzNc9MhhCZjAQ1E+VygmFF+lxt06fPt1wQRCA8fzhH/sDMB4GJjoiMNER\ngYmOCEx0RGCiIwITHRGY6IjAREcEJjoiMNERgYmOCEx0RGCiI0L8sT8AAHAcx+ZK/yVBEHB3+T3W\noiMCEx0RmOiIwERHBCY6IjDREYGJjghMdERgoiMCEx0RmOiI8CRy3Y9FLBZDIpFAIpGApmlIp9PQ\nNA2KokBRFAiC8EOvt1gssFgs4HkeBoMBhsMhhsMhLMuCZVmYz+drOpPvE2nR8XgcoigilUphZ2cH\nOzs72N7eRqFQQKFQgKZp157DcX/PIaxaszafz+H7PmzbxunpKWq1Gs7OzlCv1zGZTJjoh4bjOHAc\nB0mSoOs6NjY2sLe3h19//RWvXr3C7u4u9vb2kMlklp5zlauy5/M5bc2Hh4fQdR2CIMD3fQyHQywW\nC/pleGgiJzoWiyGZTCKRSGBnZweVSgWVSgXb29vY3t7G1tYWstksksnk0vOI1NtaNMdxiMVikCQJ\nxWIRQRBAlmWk02kYhoHz83M0Gg00Go0Hlx1J0YIgQJZlbG9v4927d3j//j0ymQwymQzS6TREUbzx\n/nzbMmOe58FxHGRZRqlUgq7ryOVyMAwD+Xweh4eHAIBOp8NErwOe5xGPxxGLxWjrMgwDr1+/xps3\nb/Dbb79BkiTIsvzDAVgYcktIJpNIJpPQdR2qqkIURaTTacRiMcxmM4xGI5imCcdx4DjOPZ7pzURC\nNImqNU1b6q4rlQp2d3eRSqWQSCTA8/c/2kwkEtB1HRzH0a58e3sbnz9/xqdPn/D58+d7f89VREa0\nrusoFAp4/fo13r9/j/fv30PXdSiKAlmWabe7rvdWFAW6rmNnZwfv3r3D77//Dtu2mej7RBRF5HI5\nVCoVHBwcYH9/H7u7uxAEATzPIxaL0d8NgoCOh33fh+u6cF0Xs9mMHiRy9n0fyWQSgiAgmUwiHo8j\nHo/TnwmCQG8Z5P+qqiKXy6FarSKbzUKW5aXXWxeREU1a8/7+PrLZLOLx+MpWHAQB5vM55vM5xuMx\nBoMB+v0++v0+TNNEr9fDeDzGdDrFZDLBxsYGHaKlUikoigJVVZHNZmEYBuLxfy4xicoTiQQkSYKq\nqjAMA47jYDweM9H/FTLcef36NSqVCgzDuPWe7Ps+DZqazSYuLy9Rr9dxdnaGs7Mz9Pt9jEYjjEYj\nlEollMtllMtlZDIZGIaBzc1N8DwPVVUhyzJ9XSKajOE1TUMmkwHHcfA8D9PpdG3X4NmK5nkekiRB\nFEXk83lsbm4il8tB13VIkkQjZOCf7tr3fUynUzSbTXpcXl6i0Wig2Wyi1Wqh1WrBtm2Mx2M4joPF\nYkG/FLqu01iA4zioqrrUpZP3I59tY2MDhUIBi8UC4/EYo9Fobdfj2YqOx+PQNA2GYaBcLmNzc5N2\nr8lk8lqXPZ/P4bouLMvCt2/f8Ndff6FaraLdbqPdbmMwGNDh0Gw2g+d58DwP/X4fruui3+8jlUoh\nlUqhWCxClmXk83koigJJkmgsQN6XiC4Wi5hMJjBNc73XY62v/ogQ0cVicUm0oigro2vSXRPRf/zx\nBz5+/AjTNGGaJmaz2cr3cV0Xg8EAAOj4udPpIJ/P4+XLl8jlcuB5HqIo0vflOA6iKGJjYwOlUgmm\naf6n8ftdeNbTlDzP05ZE/h3ussO4rovxeAzLsmgANhwOMZlM7hwkkVz2dDrFaDRCt9ulwdtj5LfD\nPNsWTYTeRXIQBPA8D+PxGMPhEIPBAKZpYjAYwPO8O0vyfR9BEGA6ncKyLPR6PZimCUVRmOinAmmJ\njuPAtm1YlgXbtn/oNUhQR+73k8kE0+kUnuchCIKlPDkZo08mE7iui8Vicd+ntMSzFU0KAMbjMR33\nep6H+XxOW/h9Q5IjqVQKhmFga2sLpVIJ6XSajqeJcNu20Wg08OXLFzSbzbXnvJ+t6CAIaIshol3X\npV3oOkSTDBgRXS6XUSqVkEqlrokejUZUtGVZGI/H9/55wjxb0YvFgt4rB4MBLMvCaDSCoigQRXEp\nYwUsSyLDHtu2l9Ke5FgFx3E08BMEAYqi0KwZSc4QyUEQYDabYTAYoNlswvM8uK671uvxbEX7vk9b\nSbvdpskOQRCwsbGxcrijaRry+TxevXqF+XyOXC6HdruNTqeD4XAI27Zh2/aN91Mimwgn9WgkECT3\n6CAI4Ps+5vM5ZrMZfN9n9+h/CxE9m83Q6XSobDI/rOv60u8LgkBFeZ6HVCqFfD6Pk5MTnJycgOd5\nmsG6TUq4ZV/NiIUhMYTrutcCtXXwbEUDoDNCtm2j1+vh4uICmqZBVdVrF5aUGCmKgnw+j0QiAVmW\n6eRDoVBAt9tFp9NZ2X1zHAdBECCKIorFIs2OhWOB8HuGZ8kegmctmuC6LkzTxPn5OTY2NpDP569d\nYCIkmUxC0zQqWtM0lMtlOntlmuaNoklXres69vf3kUqlHuT87kJkRPf7fZyfnyOXy2F/f3+pdZGu\nNRaL0XsrmUIsl8vwPA+O49A67ZuSH+T+LIoiDMOAoigPcn53IRKi5/M5LMtaCsra7TY0TaPdbfg+\nGp5l4nkeiUSC3nMlSbo1GOM4DvF4HIqiXIvsF4vFUiHDQ2bLIiHa8zzYtg2e55emIH3fh67rEEXx\nxueSiDmZTILneQiCcGPgFP6CJJPJpcoVAHQadDweP3hBfyREz+dz2LYNz/PQarXQaDRQr9cB/N1d\nk6DpasVJODdOxsc/SvhLMZ/PMZlMYFkWHMdZ+9g5TCREk/IgAOh2uzg6OkIQBLQypFgswjAMZLNZ\npNPpWydA/stnsG0bFxcXOD09xenpKYbD4b29/veIhGgy0bBYLNDpdBAEAXq9HsrlMra2trC1tYWD\ngwPwPI9UKkWDsvuCtGrHcXBxcYHDw0Mmel2QMatlWZhOp+h0OrTYr9vt0orORCKBeDxOH8kR7trD\njwBubf3hxEiv18P5+TmOjo5Qr9dhWdaDnX9kRBNI6waAwWCAIAhotsu2bZydndHxMFk3lU6nIcvy\nUmlvIpGgAddtPcBsNqPj75OTE1SrVdRqNbTb7QdbpQFEWLTv++j3+7BtG91uF6PRCK1WC9lslkok\nyRJS4UlqwmRZpsmQRCJBCxxWQVpyrVZbEj0ajeB53oOdd+REA6DjYFIn5jgOfN+H4zhotVq029Y0\njXbvhmFAVVVagE/SpGRdV5hwpE1qysgqym63i8FgcGMN2rqIpOirkGlDy7Iwm81o+RHZtaBery/t\nhHBwcIDFYkHrw28aV5No33EcmKaJ4XCI6XS69gmMVTDR/4cUJlwNsEjVCNkZQVEUTKdTuhpzsVis\nFEdmpMiKD1JsyEQ/MmSOOJyW5DgO8/l8aW6ZtGpJkmhxPikqIJAEjW3bqNVqOD09xdnZGS3+f6gZ\nqzBM9C1cXetcLpfx4sUL7O7uIpfLQZblpaU9RLbruuh0Ori4uMDx8TGOj49RrVbRaDSY6KcImdAg\nqyq2trbwyy+/UNFkBUYYUqvW6XRwcnKCw8NDfPnyBdVqFcPh8FG6bYCJppCxMCk+UFWV7o6QyWRQ\nKpXozkWlUgmapl1bv0XKg2zbxuXlJY6OjnB8fIxWq4XZbPZokgEmmhIuDiyVSvTY2tqiS3qIdFVV\nVy7tIbXaRPTnz59xdHT04GPmVURKdLgYMFzIx/M8TYLouo4XL16gUqlgf38fe3t72N3dRSaTgSRJ\n17prEl2TVZW2bcM0TTQaDVSrVZyenj7W6S4RCdHhYRIJrkhJrqqq0DSNLmbPZDIoFosoFArI5/PI\nZrPQNG1liTDwz95i0+mUJkW+fv2Ki4sLTCaTRzjb1URCNMleJZNJOhYmRYBhqUQs2SoylUpBFEUq\nedUOCWSOeTQaoV6v4/j4GEdHR0z0v+VqGW348XuEBZMWrGna0m4FRLRhGEv7jxDCgRRZG+15Hl2U\n1+v1cHJygk+fPuHLly9otVpr3cHgR/kpRIeL7kggRB7vUoAnyzLdU0TTNHo/TqfTdJcCkscmgm+a\ndgyCgKY0e70e7a4vLy/p1heNRgODweBBK0i+x5MXHW7Jsiwjk8nQbSrI8T3S6TQtMiAyBUG4NvdM\neonblteGF8jVajV8/foV1WoV3759o9ORJMp+zE1er/LkRZOKSlVVkc/nsb29jZ2dHbr1omEY3y35\nIQX4hUKBLni7rdUSwisoyEL5yWSCWq1G78WkLKher9OVm0+pJROevGhBELC5uYlyuYy9vT0cHBzg\n5cuX0HV9aV74e6+hqiokSaJB1V0hQyfbtmn16PHxMT59+oTDw0OYpknntX9k0fxD8+RFJ5NJ5HI5\nvHz5Em/evMHbt2/x9u1bOp69i7TwmPlHCv7CC9tJEuT4+BgfP37Ehw8f8OHDB1rEQGaxHjP7dRtP\nXnR4EzZRFCFJElKpFCRJ+s+vHV7dSISFF8+TXYgcx0Gz2UStVqNHu93GeDx+smKv8uRFE0jrus8L\nG16v7LouptMpbNumW051Op2lg/yM7G/ys0gGfhLRpIIz3D1evci3dcm3VYCEN5IbjUbo9/s4Ozuj\ntV1kyGSaJmzbppvI/Ww8edGz2QztdpsW4RGhuVyO7igQzmGvgnTJJHIm3fFoNKJ/2IT8jKzR6nQ6\n6Ha7tMZrPB7TTWd+Rn4a0Y7j0Au9WCxQqVSwt7cHTdOWyn9WsVgsMJlM4DgOer0eXWxHkh2NRoPe\nl8kQihzT6XRp/xMmek2QbRj7/T5836dlO57n0c1Tw7NQqyB7mViWhWaziXq9jnq9vvSXbMJCnyNP\nXnSY8XiMRqOBIAgwGAxwenqKP//887trpcjsElngRr44ZJXGZDJ50mPg+4B7Cl0Rx3F3+hDhIRZ5\nJMOs24IxEnCF79VXN1xfR1T/EARBcKfEwE8lmnGdu4p+1pu+Mv6BiY4ITHREYKIjAhMdEZjoiPAk\nhleM9cNadERgoiMCEx0RmOiIwERHBCY6IjDREYGJjghMdERgoiMCEx0RmOiIwERHBCY6IjDREYGJ\njghMdERgoiMCEx0RmOiIwERHBCY6IjDREYGJjghMdERgoiMCEx0RmOiIwERHhP8BFSIrmavdsoIA\nAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f8278781c18>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(1,1))\n",
"plt.axis('off')\n",
"plt.imshow(train['train_images'][3], cmap=plt.cm.gray)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(100, 784)\n",
"(100, 2)\n"
]
}
],
"source": [
"inputs = train['train_images'].reshape(100, 784)\n",
"outputs = pd.get_dummies(train['train_labels']).iloc[:,:].values\n",
"print(inputs.shape)\n",
"print(outputs.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## (a)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"x = tf.constant(inputs, dtype='float32', shape=[100, 784])\n",
"y = tf.constant(outputs, dtype='float32', shape=[100, 2])\n",
"\n",
"w = tf.Variable(tf.truncated_normal([784,2], stddev=0.1))\n",
"b = tf.Variable(tf.truncated_normal([1,2], stddev=0.1))\n",
"\n",
"MSE = tf.reduce_mean(tf.square(tf.nn.softmax(tf.matmul(x,w) + b) - y))\n",
"\n",
"optimizer = tf.train.AdamOptimizer().minimize(MSE)\n",
"\n",
"y_pred = tf.nn.softmax(tf.matmul(x,w) + b)\n",
"\n",
"# create a graph session and initialize it\n",
"init = tf.initialize_all_variables()\n",
"sess = tf.Session()\n",
"sess.run(init)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"step = 0 MSE = 0.318325 \n",
"step = 100 MSE = 0.045333 \n",
"step = 200 MSE = 0.018782 \n",
"step = 300 MSE = 0.010283 \n",
"step = 400 MSE = 0.006527 \n",
"step = 500 MSE = 0.004537 \n",
"step = 600 MSE = 0.003351 \n",
"step = 700 MSE = 0.002585 \n",
"step = 800 MSE = 0.002058 \n",
"step = 900 MSE = 0.001679 \n",
"step = 1000 MSE = 0.001398 \n",
"done !\n"
]
}
],
"source": [
"MAXSTEPS = 1000\n",
"for step in range(MAXSTEPS + 1):\n",
" (_,mse) = sess.run([optimizer,MSE])\n",
" if (step % (MAXSTEPS / 10)) == 0:\n",
" print('step = %-5d MSE = %-10f' % (step,mse))\n",
"\n",
"print('done !')\n",
"\n",
"tf_output_pred = sess.run(y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"tf_mse = ((outputs - tf_output_pred)**2).mean()\n",
"tf_std = tf_output_pred.std()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## (b)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.metrics import mean_squared_error\n",
"from sklearn.metrics import accuracy_score\n",
"LgR = LogisticRegression(solver='lbfgs', multi_class='multinomial')"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n",
" intercept_scaling=1, max_iter=100, multi_class='multinomial',\n",
" n_jobs=1, penalty='l2', random_state=None, solver='lbfgs',\n",
" tol=0.0001, verbose=0, warm_start=False)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"LgR.fit(inputs,train['train_labels'])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sk_outputs_pred = LgR.predict_proba(inputs)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# check Sklearn mean square error\n",
"sk_mse = ((outputs - sk_outputs_pred)**2).mean()\n",
"sk_std = sk_outputs_pred.std()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## (c)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf mse: 0.00139507375488\n",
"sk mse: 0.0192015802169\n",
"\n",
"tf rmse: 0.0373506861366\n",
"sk rmse: 0.138569766605\n"
]
}
],
"source": [
"print('tf mse: ', tf_mse)\n",
"print('sk mse: ', sk_mse)\n",
"print()\n",
"print('tf rmse: ', np.sqrt(tf_mse))\n",
"print('sk rmse: ', np.sqrt(sk_mse))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## (d)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensorflow error rate = 0.0\n",
"sklearn error rate = 0.0\n"
]
}
],
"source": [
"tf_acc = (tf_output_pred.argmax(axis=1) == outputs.argmax(axis=1)).sum() / outputs.shape[0]\n",
"sk_acc = (sk_outputs_pred.argmax(axis=1) == outputs.argmax(axis=1)).sum() / outputs.shape[0]\n",
"\n",
"print('tensorflow error rate = ', 1 - tf_acc)\n",
"print('sklearn error rate = ', 1 - sk_acc)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## (e)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"['test_images', 'test_labels']"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test = np.load('../data/MNIST_test_100.npz')\n",
"test.keys()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"test_in = test['test_images'].reshape(100, 784)\n",
"test_out = pd.get_dummies(test['test_labels']).iloc[:,:].values"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"tf_test_outputs_pred = sess.run(y_pred, feed_dict={x: test_in})"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sk_test_outputs_pred = LgR.predict_proba(test_in)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"tf_test_mse = ((test_out - tf_test_outputs_pred)**2).mean()\n",
"tf_test_std = tf_test_outputs_pred.std()\n",
"\n",
"sk_test_mse = ((test_out - sk_test_outputs_pred)**2).mean()\n",
"sk_test_std = sk_test_outputs_pred.std()"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf mse: 0.0831381725287\n",
"sk mse: 0.353527929652\n",
"\n",
"tf rmse: 0.288336908024\n",
"sk rmse: 0.594582147102\n"
]
}
],
"source": [
"print('tf mse: ', tf_test_mse)\n",
"print('sk mse: ', sk_test_std)\n",
"print()\n",
"print('tf rmse: ', np.sqrt(tf_test_mse))\n",
"print('sk rmse: ', np.sqrt(sk_test_std))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensorflow error rate = 0.11\n",
"sklearn error rate = 0.11\n"
]
}
],
"source": [
"tf_test_acc = (tf_test_outputs_pred.argmax(axis=1) == test_out.argmax(axis=1)).sum() / test_out.shape[0]\n",
"sk_test_acc = (sk_test_outputs_pred.argmax(axis=1) == test_out.argmax(axis=1)).sum() / test_out.shape[0]\n",
"\n",
"print('tensorflow error rate = ', 1 - tf_test_acc)\n",
"print('sklearn error rate = ', 1 - sk_test_acc)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [conda root]",
"language": "python",
"name": "conda-root-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment