Created
September 26, 2016 14:21
-
-
Save julie-is-late/d1bf2571754b45d9a26765de7ac26c31 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"import tensorflow as tf\n", | |
"import matplotlib.pyplot as plt\n", | |
"%matplotlib inline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"['train_labels', 'train_images']" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# load data; copy the arrays out of the .npz archive so its file handle is closed right away\n", | |
"with np.load('../data/MNIST_train_100.npz') as npz:\n", | |
"    train = {key: npz[key] for key in npz.files}\n", | |
"    train_keys = list(npz.files)\n", | |
"train_keys" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(100, 28, 28)" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# (n_images, height, width)\n", | |
"np.shape(train['train_images'])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAHoAAAB6CAYAAABwWUfkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAADwpJREFUeJztnVlT20gXhl/JixZLsrBsvAEBHCaZSuYiVfn/VXM/F1OV\nDAmEOAYM3mVblrxIlvVdzNc9MhhCZjAQ1E+VygmFF+lxt06fPt1wQRCA8fzhH/sDMB4GJjoiMNER\ngYmOCEx0RGCiIwITHRGY6IjAREcEJjoiMNERgYmOCEx0RGCiI0L8sT8AAHAcx+ZK/yVBEHB3+T3W\noiMCEx0RmOiIwERHBCY6IjDREYGJjghMdERgoiMCEx0RmOiI8CRy3Y9FLBZDIpFAIpGApmlIp9PQ\nNA2KokBRFAiC8EOvt1gssFgs4HkeBoMBhsMhhsMhLMuCZVmYz+drOpPvE2nR8XgcoigilUphZ2cH\nOzs72N7eRqFQQKFQgKZp157DcX/PIaxaszafz+H7PmzbxunpKWq1Gs7OzlCv1zGZTJjoh4bjOHAc\nB0mSoOs6NjY2sLe3h19//RWvXr3C7u4u9vb2kMlklp5zlauy5/M5bc2Hh4fQdR2CIMD3fQyHQywW\nC/pleGgiJzoWiyGZTCKRSGBnZweVSgWVSgXb29vY3t7G1tYWstksksnk0vOI1NtaNMdxiMVikCQJ\nxWIRQRBAlmWk02kYhoHz83M0Gg00Go0Hlx1J0YIgQJZlbG9v4927d3j//j0ymQwymQzS6TREUbzx\n/nzbMmOe58FxHGRZRqlUgq7ryOVyMAwD+Xweh4eHAIBOp8NErwOe5xGPxxGLxWjrMgwDr1+/xps3\nb/Dbb79BkiTIsvzDAVgYcktIJpNIJpPQdR2qqkIURaTTacRiMcxmM4xGI5imCcdx4DjOPZ7pzURC\nNImqNU1b6q4rlQp2d3eRSqWQSCTA8/c/2kwkEtB1HRzH0a58e3sbnz9/xqdPn/D58+d7f89VREa0\nrusoFAp4/fo13r9/j/fv30PXdSiKAlmWabe7rvdWFAW6rmNnZwfv3r3D77//Dtu2mej7RBRF5HI5\nVCoVHBwcYH9/H7u7uxAEATzPIxaL0d8NgoCOh33fh+u6cF0Xs9mMHiRy9n0fyWQSgiAgmUwiHo8j\nHo/TnwmCQG8Z5P+qqiKXy6FarSKbzUKW5aXXWxeREU1a8/7+PrLZLOLx+MpWHAQB5vM55vM5xuMx\nBoMB+v0++v0+TNNEr9fDeDzGdDrFZDLBxsYGHaKlUikoigJVVZHNZmEYBuLxfy4xicoTiQQkSYKq\nqjAMA47jYDweM9H/FTLcef36NSqVCgzDuPWe7Ps+DZqazSYuLy9Rr9dxdnaGs7Mz9Pt9jEYjjEYj\nlEollMtllMtlZDIZGIaBzc1N8DwPVVUhyzJ9XSKajOE1TUMmkwHHcfA8D9PpdG3X4NmK5nkekiRB\nFEXk83lsbm4il8tB13VIkkQjZOCf7tr3fUynUzSbTXpcXl6i0Wig2Wyi1Wqh1WrBtm2Mx2M4joPF\nYkG/FLqu01iA4zioqrrUpZP3I59tY2MDhUIBi8UC4/EYo9Fobdfj2YqOx+PQNA2GYaBcLmNzc5N2\nr8lk8lqXPZ/P4bouLMvCt2/f8Ndff6FaraLdbqPdbmMwGNDh0Gw2g+d58DwP/X4fruui3+8jlUoh\nlUqhWCxClmXk83koigJJkmgsQN6XiC4Wi5hMJjBNc73XY62v/ogQ0cVicUm0oigro2vSXRPRf/zx\nBz5+/AjTNGGaJmaz2cr3cV0Xg8EAAOj4udPpIJ/P4+XLl8jlcuB5HqIo0vflOA6iKGJjYwOlUgmm\naf6n8ftdeNbTlDzP05ZE/h3ussO4rovxeAzLsmgANhwOMZlM7hwkkVz2dDrFaDRCt9ulwdtj5LfD\nPNsWTYTeRXIQBPA8D+PxGMPhEIPBAKZpYjAY
wPO8O0vyfR9BEGA6ncKyLPR6PZimCUVRmOinAmmJ\njuPAtm1YlgXbtn/oNUhQR+73k8kE0+kUnuchCIKlPDkZo08mE7iui8Vicd+ntMSzFU0KAMbjMR33\nep6H+XxOW/h9Q5IjqVQKhmFga2sLpVIJ6XSajqeJcNu20Wg08OXLFzSbzbXnvJ+t6CAIaIshol3X\npV3oOkSTDBgRXS6XUSqVkEqlrokejUZUtGVZGI/H9/55wjxb0YvFgt4rB4MBLMvCaDSCoigQRXEp\nYwUsSyLDHtu2l9Ke5FgFx3E08BMEAYqi0KwZSc4QyUEQYDabYTAYoNlswvM8uK671uvxbEX7vk9b\nSbvdpskOQRCwsbGxcrijaRry+TxevXqF+XyOXC6HdruNTqeD4XAI27Zh2/aN91Mimwgn9WgkECT3\n6CAI4Ps+5vM5ZrMZfN9n9+h/CxE9m83Q6XSobDI/rOv60u8LgkBFeZ6HVCqFfD6Pk5MTnJycgOd5\nmsG6TUq4ZV/NiIUhMYTrutcCtXXwbEUDoDNCtm2j1+vh4uICmqZBVdVrF5aUGCmKgnw+j0QiAVmW\n6eRDoVBAt9tFp9NZ2X1zHAdBECCKIorFIs2OhWOB8HuGZ8kegmctmuC6LkzTxPn5OTY2NpDP569d\nYCIkmUxC0zQqWtM0lMtlOntlmuaNoklXres69vf3kUqlHuT87kJkRPf7fZyfnyOXy2F/f3+pdZGu\nNRaL0XsrmUIsl8vwPA+O49A67ZuSH+T+LIoiDMOAoigPcn53IRKi5/M5LMtaCsra7TY0TaPdbfg+\nGp5l4nkeiUSC3nMlSbo1GOM4DvF4HIqiXIvsF4vFUiHDQ2bLIiHa8zzYtg2e55emIH3fh67rEEXx\nxueSiDmZTILneQiCcGPgFP6CJJPJpcoVAHQadDweP3hBfyREz+dz2LYNz/PQarXQaDRQr9cB/N1d\nk6DpasVJODdOxsc/SvhLMZ/PMZlMYFkWHMdZ+9g5TCREk/IgAOh2uzg6OkIQBLQypFgswjAMZLNZ\npNPpWydA/stnsG0bFxcXOD09xenpKYbD4b29/veIhGgy0bBYLNDpdBAEAXq9HsrlMra2trC1tYWD\ngwPwPI9UKkWDsvuCtGrHcXBxcYHDw0Mmel2QMatlWZhOp+h0OrTYr9vt0orORCKBeDxOH8kR7trD\njwBubf3hxEiv18P5+TmOjo5Qr9dhWdaDnX9kRBNI6waAwWCAIAhotsu2bZydndHxMFk3lU6nIcvy\nUmlvIpGgAddtPcBsNqPj75OTE1SrVdRqNbTb7QdbpQFEWLTv++j3+7BtG91uF6PRCK1WC9lslkok\nyRJS4UlqwmRZpsmQRCJBCxxWQVpyrVZbEj0ajeB53oOdd+REA6DjYFIn5jgOfN+H4zhotVq029Y0\njXbvhmFAVVVagE/SpGRdV5hwpE1qysgqym63i8FgcGMN2rqIpOirkGlDy7Iwm81o+RHZtaBery/t\nhHBwcIDFYkHrw28aV5No33EcmKaJ4XCI6XS69gmMVTDR/4cUJlwNsEjVCNkZQVEUTKdTuhpzsVis\nFEdmpMiKD1JsyEQ/MmSOOJyW5DgO8/l8aW6ZtGpJkmhxPikqIJAEjW3bqNVqOD09xdnZGS3+f6gZ\nqzBM9C1cXetcLpfx4sUL7O7uIpfLQZblpaU9RLbruuh0Ori4uMDx8TGOj49RrVbRaDSY6KcImdAg\nqyq2trbwyy+/UNFkBUYYUqvW6XRwcnKCw8NDfPnyBdVqFcPh8FG6bYCJppCxMCk+UFWV7o6QyWRQ\nKpXozkWlUgmapl1bv0XKg2zbxuXlJY6OjnB8fIxWq4XZbPZokgEmmhIuDiyVSvTY2tqiS3qIdFVV\nVy7tIbXaRPTnz59xdHT04GPmVURKdLgYMFzIx/M8TYLouo4XL16gUqlgf38fe3t72N3dRSaTgSRJ\n17prEl2T
VZW2bcM0TTQaDVSrVZyenj7W6S4RCdHhYRIJrkhJrqqq0DSNLmbPZDIoFosoFArI5/PI\nZrPQNG1liTDwz95i0+mUJkW+fv2Ki4sLTCaTRzjb1URCNMleJZNJOhYmRYBhqUQs2SoylUpBFEUq\nedUOCWSOeTQaoV6v4/j4GEdHR0z0v+VqGW348XuEBZMWrGna0m4FRLRhGEv7jxDCgRRZG+15Hl2U\n1+v1cHJygk+fPuHLly9otVpr3cHgR/kpRIeL7kggRB7vUoAnyzLdU0TTNHo/TqfTdJcCkscmgm+a\ndgyCgKY0e70e7a4vLy/p1heNRgODweBBK0i+x5MXHW7Jsiwjk8nQbSrI8T3S6TQtMiAyBUG4NvdM\neonblteGF8jVajV8/foV1WoV3759o9ORJMp+zE1er/LkRZOKSlVVkc/nsb29jZ2dHbr1omEY3y35\nIQX4hUKBLni7rdUSwisoyEL5yWSCWq1G78WkLKher9OVm0+pJROevGhBELC5uYlyuYy9vT0cHBzg\n5cuX0HV9aV74e6+hqiokSaJB1V0hQyfbtmn16PHxMT59+oTDw0OYpknntX9k0fxD8+RFJ5NJ5HI5\nvHz5Em/evMHbt2/x9u1bOp69i7TwmPlHCv7CC9tJEuT4+BgfP37Ehw8f8OHDB1rEQGaxHjP7dRtP\nXnR4EzZRFCFJElKpFCRJ+s+vHV7dSISFF8+TXYgcx0Gz2UStVqNHu93GeDx+smKv8uRFE0jrus8L\nG16v7LouptMpbNumW051Op2lg/yM7G/ys0gGfhLRpIIz3D1evci3dcm3VYCEN5IbjUbo9/s4Ozuj\ntV1kyGSaJmzbppvI/Ww8edGz2QztdpsW4RGhuVyO7igQzmGvgnTJJHIm3fFoNKJ/2IT8jKzR6nQ6\n6Ha7tMZrPB7TTWd+Rn4a0Y7j0Au9WCxQqVSwt7cHTdOWyn9WsVgsMJlM4DgOer0eXWxHkh2NRoPe\nl8kQihzT6XRp/xMmek2QbRj7/T5836dlO57n0c1Tw7NQqyB7mViWhWaziXq9jnq9vvSXbMJCnyNP\nXnSY8XiMRqOBIAgwGAxwenqKP//887trpcjsElngRr44ZJXGZDJ50mPg+4B7Cl0Rx3F3+hDhIRZ5\nJMOs24IxEnCF79VXN1xfR1T/EARBcKfEwE8lmnGdu4p+1pu+Mv6BiY4ITHREYKIjAhMdEZjoiPAk\nhleM9cNadERgoiMCEx0RmOiIwERHBCY6IjDREYGJjghMdERgoiMCEx0RmOiIwERHBCY6IjDREYGJ\njghMdERgoiMCEx0RmOiIwERHBCY6IjDREYGJjghMdERgoiMCEx0RmOiIwERHhP8BFSIrmavdsoIA\nAAAASUVORK5CYII=\n", | |
"text/plain": [ | |
"<matplotlib.figure.Figure at 0x7f8278781c18>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# preview a single training digit as a small greyscale thumbnail\n", | |
"fig, ax = plt.subplots(figsize=(1, 1))\n", | |
"ax.axis('off')\n", | |
"ax.imshow(train['train_images'][3], cmap=plt.cm.gray)\n", | |
"plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": false, | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(100, 784)\n", | |
"(100, 2)\n" | |
] | |
} | |
], | |
"source": [ | |
"train_images = train['train_images']\n", | |
"# flatten each image to one row without hard-coding the number of samples\n", | |
"inputs = train_images.reshape(len(train_images), -1)\n", | |
"# one-hot labels as float32; columns follow the sorted label values\n", | |
"outputs = pd.get_dummies(train['train_labels']).values.astype(np.float32)\n", | |
"print(inputs.shape)\n", | |
"print(outputs.shape)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## (a) Softmax regression in TensorFlow (MSE loss, Adam optimizer)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Softmax regression trained with a mean-squared-error loss and Adam.\n", | |
"tf.reset_default_graph()  # clean graph, so re-running this cell does not stack duplicate ops\n", | |
"tf.set_random_seed(0)  # graph-level seed -> reproducible weight initialisation\n", | |
"\n", | |
"n_features = inputs.shape[1]\n", | |
"n_classes = outputs.shape[1]\n", | |
"\n", | |
"# x defaults to the training images but accepts any number of rows via feed_dict (e.g. the test set)\n", | |
"x = tf.placeholder_with_default(tf.constant(inputs, dtype=tf.float32), shape=[None, n_features])\n", | |
"y = tf.constant(outputs, dtype=tf.float32)\n", | |
"\n", | |
"w = tf.Variable(tf.truncated_normal([n_features, n_classes], stddev=0.1))\n", | |
"b = tf.Variable(tf.truncated_normal([1, n_classes], stddev=0.1))\n", | |
"\n", | |
"# build the softmax prediction once and reuse it for the loss\n", | |
"y_pred = tf.nn.softmax(tf.matmul(x, w) + b)\n", | |
"MSE = tf.reduce_mean(tf.square(y_pred - y))\n", | |
"\n", | |
"optimizer = tf.train.AdamOptimizer().minimize(MSE)\n", | |
"\n", | |
"# create a graph session and initialize it\n", | |
"# (later TF releases rename this to tf.global_variables_initializer)\n", | |
"init = tf.initialize_all_variables()\n", | |
"sess = tf.Session()\n", | |
"sess.run(init)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"step = 0 MSE = 0.318325 \n", | |
"step = 100 MSE = 0.045333 \n", | |
"step = 200 MSE = 0.018782 \n", | |
"step = 300 MSE = 0.010283 \n", | |
"step = 400 MSE = 0.006527 \n", | |
"step = 500 MSE = 0.004537 \n", | |
"step = 600 MSE = 0.003351 \n", | |
"step = 700 MSE = 0.002585 \n", | |
"step = 800 MSE = 0.002058 \n", | |
"step = 900 MSE = 0.001679 \n", | |
"step = 1000 MSE = 0.001398 \n", | |
"done !\n" | |
] | |
} | |
], | |
"source": [ | |
"MAXSTEPS = 1000\n", | |
"# integer logging interval (~10 reports); floor of 1 avoids modulo-by-zero when MAXSTEPS < 10\n", | |
"LOG_EVERY = max(1, MAXSTEPS // 10)\n", | |
"for step in range(MAXSTEPS + 1):\n", | |
"    _, mse = sess.run([optimizer, MSE])\n", | |
"    if step % LOG_EVERY == 0:\n", | |
"        print('step = %-5d MSE = %-10f' % (step, mse))\n", | |
"\n", | |
"print('done !')\n", | |
"\n", | |
"# predictions on the training inputs with the fitted weights\n", | |
"tf_output_pred = sess.run(y_pred)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# training-set MSE of the TF predictions, plus the spread of the predicted probabilities\n", | |
"tf_residuals = outputs - tf_output_pred\n", | |
"tf_mse = np.mean(np.square(tf_residuals))\n", | |
"tf_std = np.std(tf_output_pred)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## (b) Multinomial logistic regression in scikit-learn" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.linear_model import LogisticRegression\n", | |
"from sklearn.metrics import accuracy_score, mean_squared_error\n", | |
"\n", | |
"# softmax (multinomial) logistic regression; sklearn fits it with its default L2-penalised log-loss\n", | |
"LgR = LogisticRegression(multi_class='multinomial', solver='lbfgs')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n", | |
" intercept_scaling=1, max_iter=100, multi_class='multinomial',\n", | |
" n_jobs=1, penalty='l2', random_state=None, solver='lbfgs',\n", | |
" tol=0.0001, verbose=0, warm_start=False)" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# fit on the raw (not one-hot) training labels\n", | |
"LgR.fit(X=inputs, y=train['train_labels'])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# class probabilities; columns follow LgR.classes_ (sorted labels), the same order as the get_dummies columns\n", | |
"sk_outputs_pred = LgR.predict_proba(X=inputs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# training-set MSE of the sklearn probabilities, plus their spread\n", | |
"sk_residuals = outputs - sk_outputs_pred\n", | |
"sk_mse = np.mean(np.square(sk_residuals))\n", | |
"sk_std = np.std(sk_outputs_pred)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## (c) Training-set MSE and RMSE" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"tf mse: 0.00139507375488\n", | |
"sk mse: 0.0192015802169\n", | |
"\n", | |
"tf rmse: 0.0373506861366\n", | |
"sk rmse: 0.138569766605\n" | |
] | |
} | |
], | |
"source": [ | |
"for label, value in [('tf mse: ', tf_mse), ('sk mse: ', sk_mse)]:\n", | |
"    print(label, value)\n", | |
"print()\n", | |
"for label, value in [('tf rmse: ', tf_mse), ('sk rmse: ', sk_mse)]:\n", | |
"    print(label, np.sqrt(value))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## (d) Training-set error rate" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"tensorflow error rate = 0.0\n", | |
"sklearn error rate = 0.0\n" | |
] | |
} | |
], | |
"source": [ | |
"train_true = outputs.argmax(axis=1)\n", | |
"tf_acc = np.mean(tf_output_pred.argmax(axis=1) == train_true)\n", | |
"sk_acc = np.mean(sk_outputs_pred.argmax(axis=1) == train_true)\n", | |
"\n", | |
"print('tensorflow error rate = ', 1 - tf_acc)\n", | |
"print('sklearn error rate = ', 1 - sk_acc)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## (e) Test-set evaluation (MSE, RMSE and error rate)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"['test_images', 'test_labels']" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# copy the arrays out of the .npz archive so its file handle is closed right away\n", | |
"with np.load('../data/MNIST_test_100.npz') as npz:\n", | |
"    test = {key: npz[key] for key in npz.files}\n", | |
"    test_keys = list(npz.files)\n", | |
"test_keys" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"test_images = test['test_images']\n", | |
"test_in = test_images.reshape(len(test_images), -1)\n", | |
"# One-hot encode against the *training* label set so the columns line up with the\n", | |
"# model's output columns even if a class is absent from the test split.\n", | |
"train_classes = np.unique(train['train_labels'])\n", | |
"test_out = pd.get_dummies(pd.Categorical(np.ravel(test['test_labels']), categories=train_classes)).values.astype(np.float32)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# evaluate the trained TF model on the test images by overriding x\n", | |
"tf_test_outputs_pred = y_pred.eval(session=sess, feed_dict={x: test_in})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# sklearn class probabilities on the test images\n", | |
"sk_test_outputs_pred = LgR.predict_proba(X=test_in)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"tf_test_residuals = test_out - tf_test_outputs_pred\n", | |
"tf_test_mse = np.mean(np.square(tf_test_residuals))\n", | |
"tf_test_std = np.std(tf_test_outputs_pred)\n", | |
"\n", | |
"sk_test_residuals = test_out - sk_test_outputs_pred\n", | |
"sk_test_mse = np.mean(np.square(sk_test_residuals))\n", | |
"sk_test_std = np.std(sk_test_outputs_pred)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Test-set MSE and RMSE for both models (RMSE = sqrt(MSE)).\n", | |
"print('tf mse: ', tf_test_mse)\n", | |
"print('sk mse: ', sk_test_mse)\n", | |
"print()\n", | |
"print('tf rmse: ', np.sqrt(tf_test_mse))\n", | |
"print('sk rmse: ', np.sqrt(sk_test_mse))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": { | |
"collapsed": false, | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"tensorflow error rate = 0.11\n", | |
"sklearn error rate = 0.11\n" | |
] | |
} | |
], | |
"source": [ | |
"test_true = test_out.argmax(axis=1)\n", | |
"tf_test_acc = np.mean(tf_test_outputs_pred.argmax(axis=1) == test_true)\n", | |
"sk_test_acc = np.mean(sk_test_outputs_pred.argmax(axis=1) == test_true)\n", | |
"\n", | |
"print('tensorflow error rate = ', 1 - tf_test_acc)\n", | |
"print('sklearn error rate = ', 1 - sk_test_acc)" | |
] | |
} | |
], | |
"metadata": { | |
"anaconda-cloud": {}, | |
"kernelspec": { | |
"display_name": "Python [conda root]", | |
"language": "python", | |
"name": "conda-root-py" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 1 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment