Created
July 2, 2023 01:09
-
-
Save yungwarlock/ec4eb19cd66a578505d289aa9ea3f6ab to your computer and use it in GitHub Desktop.
nearest_neighbor.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/yungwarlock/ec4eb19cd66a578505d289aa9ea3f6ab/nearest_neighbor.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "9SRExWFUH_LO" | |
}, | |
"source": [ | |
"# Nearest Neighbor in TensorFlow\n", | |
"\n", | |
"Credits: Forked from [TensorFlow-Examples](https://github.com/aymericdamien/TensorFlow-Examples) by Aymeric Damien\n", | |
"\n", | |
"## Setup\n", | |
"\n", | |
"Refer to the [setup instructions](http://nbviewer.ipython.org/github/donnemartin/data-science-ipython-notebooks/blob/master/deep-learning/tensor-flow-examples/Setup_TensorFlow.md)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true, | |
"id": "a7oz8PbvH_LR" | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import tensorflow as tf" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 339 | |
}, | |
"id": "vvsgXRebH_LT", | |
"outputId": "c5248036-8210-4636-e392-2fd0117b5424" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "error", | |
"ename": "ModuleNotFoundError", | |
"evalue": "ignored", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-2-98e119d70156>\u001b[0m in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Import MINST data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0minput_data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mmnist\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput_data\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_data_sets\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"/tmp/data/\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mone_hot\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'input_data'", | |
"", | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0;32m\nNOTE: If your import is failing due to a missing package, you can\nmanually install dependencies using either !pip or !apt.\n\nTo view examples of installing some common dependencies, click the\n\"Open Examples\" button below.\n\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n" | |
], | |
"errorDetails": { | |
"actions": [ | |
{ | |
"action": "open_url", | |
"actionText": "Open Examples", | |
"url": "/notebooks/snippets/importing_libraries.ipynb" | |
} | |
] | |
} | |
} | |
], | |
"source": [ | |
"# Import MNIST data\n", | |
"import input_data\n", | |
"mnist = input_data.read_data_sets(\"/tmp/data/\", one_hot=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true, | |
"id": "dY1arItKH_LT" | |
}, | |
"outputs": [], | |
"source": [ | |
"# In this example, we limit mnist data\n", | |
"Xtr, Ytr = mnist.train.next_batch(5000) #5000 for training (nn candidates)\n", | |
"Xte, Yte = mnist.test.next_batch(200) #200 for testing" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true, | |
"id": "cVLxrklJH_LU" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Reshape images to 1D\n", | |
"Xtr = np.reshape(Xtr, newshape=(-1, 28*28))\n", | |
"Xte = np.reshape(Xte, newshape=(-1, 28*28))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true, | |
"id": "MMgEdtEOH_LU" | |
}, | |
"outputs": [], | |
"source": [ | |
"# tf Graph Input\n", | |
"xtr = tf.placeholder(\"float\", [None, 784])\n", | |
"xte = tf.placeholder(\"float\", [784])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true, | |
"id": "UqcEs5TjH_LU" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Nearest Neighbor calculation using L1 Distance\n", | |
"# Calculate L1 Distance\n", | |
"distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.neg(xte))), reduction_indices=1)\n", | |
"# Predict: Get min distance index (Nearest neighbor)\n", | |
"pred = tf.arg_min(distance, 0)\n", | |
"\n", | |
"accuracy = 0." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true, | |
"id": "mVH_UBJZH_LV" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Initializing the variables\n", | |
"init = tf.initialize_all_variables()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "fWMFkGsZH_LV", | |
"outputId": "65b9d593-fc29-4964-ba6d-690f3728e418" | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Test 0 Prediction: 7 True Class: 7\n", | |
"Test 1 Prediction: 2 True Class: 2\n", | |
"Test 2 Prediction: 1 True Class: 1\n", | |
"Test 3 Prediction: 0 True Class: 0\n", | |
"Test 4 Prediction: 4 True Class: 4\n", | |
"Test 5 Prediction: 1 True Class: 1\n", | |
"Test 6 Prediction: 4 True Class: 4\n", | |
"Test 7 Prediction: 9 True Class: 9\n", | |
"Test 8 Prediction: 8 True Class: 5\n", | |
"Test 9 Prediction: 9 True Class: 9\n", | |
"Test 10 Prediction: 0 True Class: 0\n", | |
"Test 11 Prediction: 0 True Class: 6\n", | |
"Test 12 Prediction: 9 True Class: 9\n", | |
"Test 13 Prediction: 0 True Class: 0\n", | |
"Test 14 Prediction: 1 True Class: 1\n", | |
"Test 15 Prediction: 5 True Class: 5\n", | |
"Test 16 Prediction: 4 True Class: 9\n", | |
"Test 17 Prediction: 7 True Class: 7\n", | |
"Test 18 Prediction: 3 True Class: 3\n", | |
"Test 19 Prediction: 4 True Class: 4\n", | |
"Test 20 Prediction: 9 True Class: 9\n", | |
"Test 21 Prediction: 6 True Class: 6\n", | |
"Test 22 Prediction: 6 True Class: 6\n", | |
"Test 23 Prediction: 5 True Class: 5\n", | |
"Test 24 Prediction: 4 True Class: 4\n", | |
"Test 25 Prediction: 0 True Class: 0\n", | |
"Test 26 Prediction: 7 True Class: 7\n", | |
"Test 27 Prediction: 4 True Class: 4\n", | |
"Test 28 Prediction: 0 True Class: 0\n", | |
"Test 29 Prediction: 1 True Class: 1\n", | |
"Test 30 Prediction: 3 True Class: 3\n", | |
"Test 31 Prediction: 1 True Class: 1\n", | |
"Test 32 Prediction: 3 True Class: 3\n", | |
"Test 33 Prediction: 4 True Class: 4\n", | |
"Test 34 Prediction: 7 True Class: 7\n", | |
"Test 35 Prediction: 2 True Class: 2\n", | |
"Test 36 Prediction: 7 True Class: 7\n", | |
"Test 37 Prediction: 1 True Class: 1\n", | |
"Test 38 Prediction: 2 True Class: 2\n", | |
"Test 39 Prediction: 1 True Class: 1\n", | |
"Test 40 Prediction: 1 True Class: 1\n", | |
"Test 41 Prediction: 7 True Class: 7\n", | |
"Test 42 Prediction: 4 True Class: 4\n", | |
"Test 43 Prediction: 1 True Class: 2\n", | |
"Test 44 Prediction: 3 True Class: 3\n", | |
"Test 45 Prediction: 5 True Class: 5\n", | |
"Test 46 Prediction: 1 True Class: 1\n", | |
"Test 47 Prediction: 2 True Class: 2\n", | |
"Test 48 Prediction: 4 True Class: 4\n", | |
"Test 49 Prediction: 4 True Class: 4\n", | |
"Test 50 Prediction: 6 True Class: 6\n", | |
"Test 51 Prediction: 3 True Class: 3\n", | |
"Test 52 Prediction: 5 True Class: 5\n", | |
"Test 53 Prediction: 5 True Class: 5\n", | |
"Test 54 Prediction: 6 True Class: 6\n", | |
"Test 55 Prediction: 0 True Class: 0\n", | |
"Test 56 Prediction: 4 True Class: 4\n", | |
"Test 57 Prediction: 1 True Class: 1\n", | |
"Test 58 Prediction: 9 True Class: 9\n", | |
"Test 59 Prediction: 5 True Class: 5\n", | |
"Test 60 Prediction: 7 True Class: 7\n", | |
"Test 61 Prediction: 8 True Class: 8\n", | |
"Test 62 Prediction: 9 True Class: 9\n", | |
"Test 63 Prediction: 3 True Class: 3\n", | |
"Test 64 Prediction: 7 True Class: 7\n", | |
"Test 65 Prediction: 4 True Class: 4\n", | |
"Test 66 Prediction: 6 True Class: 6\n", | |
"Test 67 Prediction: 4 True Class: 4\n", | |
"Test 68 Prediction: 3 True Class: 3\n", | |
"Test 69 Prediction: 0 True Class: 0\n", | |
"Test 70 Prediction: 7 True Class: 7\n", | |
"Test 71 Prediction: 0 True Class: 0\n", | |
"Test 72 Prediction: 2 True Class: 2\n", | |
"Test 73 Prediction: 7 True Class: 9\n", | |
"Test 74 Prediction: 1 True Class: 1\n", | |
"Test 75 Prediction: 7 True Class: 7\n", | |
"Test 76 Prediction: 3 True Class: 3\n", | |
"Test 77 Prediction: 7 True Class: 2\n", | |
"Test 78 Prediction: 9 True Class: 9\n", | |
"Test 79 Prediction: 7 True Class: 7\n", | |
"Test 80 Prediction: 7 True Class: 7\n", | |
"Test 81 Prediction: 6 True Class: 6\n", | |
"Test 82 Prediction: 2 True Class: 2\n", | |
"Test 83 Prediction: 7 True Class: 7\n", | |
"Test 84 Prediction: 8 True Class: 8\n", | |
"Test 85 Prediction: 4 True Class: 4\n", | |
"Test 86 Prediction: 7 True Class: 7\n", | |
"Test 87 Prediction: 3 True Class: 3\n", | |
"Test 88 Prediction: 6 True Class: 6\n", | |
"Test 89 Prediction: 1 True Class: 1\n", | |
"Test 90 Prediction: 3 True Class: 3\n", | |
"Test 91 Prediction: 6 True Class: 6\n", | |
"Test 92 Prediction: 9 True Class: 9\n", | |
"Test 93 Prediction: 3 True Class: 3\n", | |
"Test 94 Prediction: 1 True Class: 1\n", | |
"Test 95 Prediction: 4 True Class: 4\n", | |
"Test 96 Prediction: 1 True Class: 1\n", | |
"Test 97 Prediction: 7 True Class: 7\n", | |
"Test 98 Prediction: 6 True Class: 6\n", | |
"Test 99 Prediction: 9 True Class: 9\n", | |
"Test 100 Prediction: 6 True Class: 6\n", | |
"Test 101 Prediction: 0 True Class: 0\n", | |
"Test 102 Prediction: 5 True Class: 5\n", | |
"Test 103 Prediction: 4 True Class: 4\n", | |
"Test 104 Prediction: 9 True Class: 9\n", | |
"Test 105 Prediction: 9 True Class: 9\n", | |
"Test 106 Prediction: 2 True Class: 2\n", | |
"Test 107 Prediction: 1 True Class: 1\n", | |
"Test 108 Prediction: 9 True Class: 9\n", | |
"Test 109 Prediction: 4 True Class: 4\n", | |
"Test 110 Prediction: 8 True Class: 8\n", | |
"Test 111 Prediction: 7 True Class: 7\n", | |
"Test 112 Prediction: 3 True Class: 3\n", | |
"Test 113 Prediction: 9 True Class: 9\n", | |
"Test 114 Prediction: 7 True Class: 7\n", | |
"Test 115 Prediction: 9 True Class: 4\n", | |
"Test 116 Prediction: 9 True Class: 4\n", | |
"Test 117 Prediction: 4 True Class: 4\n", | |
"Test 118 Prediction: 9 True Class: 9\n", | |
"Test 119 Prediction: 7 True Class: 2\n", | |
"Test 120 Prediction: 5 True Class: 5\n", | |
"Test 121 Prediction: 4 True Class: 4\n", | |
"Test 122 Prediction: 7 True Class: 7\n", | |
"Test 123 Prediction: 6 True Class: 6\n", | |
"Test 124 Prediction: 7 True Class: 7\n", | |
"Test 125 Prediction: 9 True Class: 9\n", | |
"Test 126 Prediction: 0 True Class: 0\n", | |
"Test 127 Prediction: 5 True Class: 5\n", | |
"Test 128 Prediction: 8 True Class: 8\n", | |
"Test 129 Prediction: 5 True Class: 5\n", | |
"Test 130 Prediction: 6 True Class: 6\n", | |
"Test 131 Prediction: 6 True Class: 6\n", | |
"Test 132 Prediction: 5 True Class: 5\n", | |
"Test 133 Prediction: 7 True Class: 7\n", | |
"Test 134 Prediction: 8 True Class: 8\n", | |
"Test 135 Prediction: 1 True Class: 1\n", | |
"Test 136 Prediction: 0 True Class: 0\n", | |
"Test 137 Prediction: 1 True Class: 1\n", | |
"Test 138 Prediction: 6 True Class: 6\n", | |
"Test 139 Prediction: 4 True Class: 4\n", | |
"Test 140 Prediction: 6 True Class: 6\n", | |
"Test 141 Prediction: 7 True Class: 7\n", | |
"Test 142 Prediction: 2 True Class: 3\n", | |
"Test 143 Prediction: 1 True Class: 1\n", | |
"Test 144 Prediction: 7 True Class: 7\n", | |
"Test 145 Prediction: 1 True Class: 1\n", | |
"Test 146 Prediction: 8 True Class: 8\n", | |
"Test 147 Prediction: 2 True Class: 2\n", | |
"Test 148 Prediction: 0 True Class: 0\n", | |
"Test 149 Prediction: 1 True Class: 2\n", | |
"Test 150 Prediction: 9 True Class: 9\n", | |
"Test 151 Prediction: 9 True Class: 9\n", | |
"Test 152 Prediction: 5 True Class: 5\n", | |
"Test 153 Prediction: 5 True Class: 5\n", | |
"Test 154 Prediction: 1 True Class: 1\n", | |
"Test 155 Prediction: 5 True Class: 5\n", | |
"Test 156 Prediction: 6 True Class: 6\n", | |
"Test 157 Prediction: 0 True Class: 0\n", | |
"Test 158 Prediction: 3 True Class: 3\n", | |
"Test 159 Prediction: 4 True Class: 4\n", | |
"Test 160 Prediction: 4 True Class: 4\n", | |
"Test 161 Prediction: 6 True Class: 6\n", | |
"Test 162 Prediction: 5 True Class: 5\n", | |
"Test 163 Prediction: 4 True Class: 4\n", | |
"Test 164 Prediction: 6 True Class: 6\n", | |
"Test 165 Prediction: 5 True Class: 5\n", | |
"Test 166 Prediction: 4 True Class: 4\n", | |
"Test 167 Prediction: 5 True Class: 5\n", | |
"Test 168 Prediction: 1 True Class: 1\n", | |
"Test 169 Prediction: 4 True Class: 4\n", | |
"Test 170 Prediction: 9 True Class: 4\n", | |
"Test 171 Prediction: 7 True Class: 7\n", | |
"Test 172 Prediction: 2 True Class: 2\n", | |
"Test 173 Prediction: 3 True Class: 3\n", | |
"Test 174 Prediction: 2 True Class: 2\n", | |
"Test 175 Prediction: 1 True Class: 7\n", | |
"Test 176 Prediction: 1 True Class: 1\n", | |
"Test 177 Prediction: 8 True Class: 8\n", | |
"Test 178 Prediction: 1 True Class: 1\n", | |
"Test 179 Prediction: 8 True Class: 8\n", | |
"Test 180 Prediction: 1 True Class: 1\n", | |
"Test 181 Prediction: 8 True Class: 8\n", | |
"Test 182 Prediction: 5 True Class: 5\n", | |
"Test 183 Prediction: 0 True Class: 0\n", | |
"Test 184 Prediction: 2 True Class: 8\n", | |
"Test 185 Prediction: 9 True Class: 9\n", | |
"Test 186 Prediction: 2 True Class: 2\n", | |
"Test 187 Prediction: 5 True Class: 5\n", | |
"Test 188 Prediction: 0 True Class: 0\n", | |
"Test 189 Prediction: 1 True Class: 1\n", | |
"Test 190 Prediction: 1 True Class: 1\n", | |
"Test 191 Prediction: 1 True Class: 1\n", | |
"Test 192 Prediction: 0 True Class: 0\n", | |
"Test 193 Prediction: 4 True Class: 9\n", | |
"Test 194 Prediction: 0 True Class: 0\n", | |
"Test 195 Prediction: 1 True Class: 3\n", | |
"Test 196 Prediction: 1 True Class: 1\n", | |
"Test 197 Prediction: 6 True Class: 6\n", | |
"Test 198 Prediction: 4 True Class: 4\n", | |
"Test 199 Prediction: 2 True Class: 2\n", | |
"Done!\n", | |
"Accuracy: 0.92\n" | |
] | |
} | |
], | |
"source": [ | |
"# Launch the graph\n", | |
"with tf.Session() as sess:\n", | |
" sess.run(init)\n", | |
"\n", | |
" # loop over test data\n", | |
" for i in range(len(Xte)):\n", | |
" # Get nearest neighbor\n", | |
" nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i,:]})\n", | |
" # Get nearest neighbor class label and compare it to its true label\n", | |
" print \"Test\", i, \"Prediction:\", np.argmax(Ytr[nn_index]), \\\n", | |
" \"True Class:\", np.argmax(Yte[i])\n", | |
" # Calculate accuracy\n", | |
" if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]):\n", | |
" accuracy += 1./len(Xte)\n", | |
" print \"Done!\"\n", | |
" print \"Accuracy:\", accuracy" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true, | |
"id": "p65iJWWAH_LW" | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.4.3" | |
}, | |
"colab": { | |
"provenance": [], | |
"include_colab_link": true | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment