Skip to content

Instantly share code, notes, and snippets.

@yungwarlock
Created July 2, 2023 01:09
Show Gist options
  • Save yungwarlock/ec4eb19cd66a578505d289aa9ea3f6ab to your computer and use it in GitHub Desktop.
Save yungwarlock/ec4eb19cd66a578505d289aa9ea3f6ab to your computer and use it in GitHub Desktop.
nearest_neighbor.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/yungwarlock/ec4eb19cd66a578505d289aa9ea3f6ab/nearest_neighbor.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9SRExWFUH_LO"
},
"source": [
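"# Nearest Neighbor in TensorFlow\n",
"\n",
"Classify MNIST handwritten digits with a 1-nearest-neighbor rule: each test image receives the label of the training image at the smallest L1 distance (sum of absolute pixel differences).\n",
"\n",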
"Credits: Forked from [TensorFlow-Examples](https://github.com/aymericdamien/TensorFlow-Examples) by Aymeric Damien\n",
"\n",
"## Setup\n",
"\n",
"Refer to the [setup instructions](http://nbviewer.ipython.org/github/donnemartin/data-science-ipython-notebooks/blob/master/deep-learning/tensor-flow-examples/Setup_TensorFlow.md)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true,
"id": "a7oz8PbvH_LR"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vvsgXRebH_LT"
},
"outputs": [],
"source": [
"# Import MNIST data.\n",
"# The legacy `input_data` helper (tensorflow.examples.tutorials.mnist) no longer\n",
"# ships with TensorFlow, so load the same dataset through Keras instead.\n",
"# Keras returns uint8 images of shape (N, 28, 28) and integer class labels 0-9.\n",
"(x_train_raw, y_train_raw), (x_test_raw, y_test_raw) = tf.keras.datasets.mnist.load_data()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"id": "dY1arItKH_LT"
},
"outputs": [],
"source": [
"# In this example, we limit mnist data\n",
"N_TRAIN = 5000  # training images used as nearest-neighbor candidates\n",
"N_TEST = 200    # test images to classify\n",
"\n",
"# Match the format the old read_data_sets(..., one_hot=True) loader produced:\n",
"# float pixels scaled to [0, 1] and one-hot encoded labels (later cells take\n",
"# np.argmax of the label rows to recover the class index).\n",
"Xtr = x_train_raw[:N_TRAIN].astype(np.float32) / 255.0\n",
"Ytr = np.eye(10, dtype=np.float32)[y_train_raw[:N_TRAIN]]\n",
"Xte = x_test_raw[:N_TEST].astype(np.float32) / 255.0\n",
"Yte = np.eye(10, dtype=np.float32)[y_test_raw[:N_TEST]]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"id": "cVLxrklJH_LU"
},
"outputs": [],
"source": [
"# Reshape images to 1D: one 784-element (28*28) row per image.\n",
"# ndarray.reshape sidesteps np.reshape's deprecated `newshape` keyword, and\n",
"# re-running the cell on already-flat arrays leaves them unchanged.\n",
"Xtr = Xtr.reshape(-1, 28 * 28)\n",
"Xte = Xte.reshape(-1, 28 * 28)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"id": "MMgEdtEOH_LU"
},
"outputs": [],
"source": [
"# tf Graph Input\n",
"# tf.placeholder was removed from the TF2 top-level API; the graph-mode\n",
"# equivalent lives in tf.compat.v1 and requires eager execution to be\n",
"# switched off before any graph ops are created.\n",
"tf.compat.v1.disable_eager_execution()\n",
"xtr = tf.compat.v1.placeholder(tf.float32, [None, 784])  # all training images, one per row\n",
"xte = tf.compat.v1.placeholder(tf.float32, [784])        # a single test image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"id": "UqcEs5TjH_LU"
},
"outputs": [],
"source": [
"# Nearest Neighbor calculation using L1 Distance\n",
"# Calculate L1 distance from the test image to every training image:\n",
"# xte (shape [784]) broadcasts against each row of xtr (shape [None, 784]).\n",
"distance = tf.reduce_sum(tf.abs(xtr - xte), axis=1)\n",
"# Predict: Get min distance index (Nearest neighbor)\n",
"pred = tf.argmin(distance, 0)\n",
"\n",
"# Overall test accuracy; filled in by the evaluation cell\n",
"accuracy = 0."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"id": "mVH_UBJZH_LV"
},
"outputs": [],
"source": [
"# Initializing the variables\n",
"# global_variables_initializer replaces the removed initialize_all_variables.\n",
"# The graph above defines only placeholders (no tf.Variable), so running this\n",
"# op is effectively a no-op kept for the standard session pattern.\n",
"init = tf.compat.v1.global_variables_initializer()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fWMFkGsZH_LV",
"outputId": "65b9d593-fc29-4964-ba6d-690f3728e418"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test 0 Prediction: 7 True Class: 7\n",
"Test 1 Prediction: 2 True Class: 2\n",
"Test 2 Prediction: 1 True Class: 1\n",
"Test 3 Prediction: 0 True Class: 0\n",
"Test 4 Prediction: 4 True Class: 4\n",
"Test 5 Prediction: 1 True Class: 1\n",
"Test 6 Prediction: 4 True Class: 4\n",
"Test 7 Prediction: 9 True Class: 9\n",
"Test 8 Prediction: 8 True Class: 5\n",
"Test 9 Prediction: 9 True Class: 9\n",
"Test 10 Prediction: 0 True Class: 0\n",
"Test 11 Prediction: 0 True Class: 6\n",
"Test 12 Prediction: 9 True Class: 9\n",
"Test 13 Prediction: 0 True Class: 0\n",
"Test 14 Prediction: 1 True Class: 1\n",
"Test 15 Prediction: 5 True Class: 5\n",
"Test 16 Prediction: 4 True Class: 9\n",
"Test 17 Prediction: 7 True Class: 7\n",
"Test 18 Prediction: 3 True Class: 3\n",
"Test 19 Prediction: 4 True Class: 4\n",
"Test 20 Prediction: 9 True Class: 9\n",
"Test 21 Prediction: 6 True Class: 6\n",
"Test 22 Prediction: 6 True Class: 6\n",
"Test 23 Prediction: 5 True Class: 5\n",
"Test 24 Prediction: 4 True Class: 4\n",
"Test 25 Prediction: 0 True Class: 0\n",
"Test 26 Prediction: 7 True Class: 7\n",
"Test 27 Prediction: 4 True Class: 4\n",
"Test 28 Prediction: 0 True Class: 0\n",
"Test 29 Prediction: 1 True Class: 1\n",
"Test 30 Prediction: 3 True Class: 3\n",
"Test 31 Prediction: 1 True Class: 1\n",
"Test 32 Prediction: 3 True Class: 3\n",
"Test 33 Prediction: 4 True Class: 4\n",
"Test 34 Prediction: 7 True Class: 7\n",
"Test 35 Prediction: 2 True Class: 2\n",
"Test 36 Prediction: 7 True Class: 7\n",
"Test 37 Prediction: 1 True Class: 1\n",
"Test 38 Prediction: 2 True Class: 2\n",
"Test 39 Prediction: 1 True Class: 1\n",
"Test 40 Prediction: 1 True Class: 1\n",
"Test 41 Prediction: 7 True Class: 7\n",
"Test 42 Prediction: 4 True Class: 4\n",
"Test 43 Prediction: 1 True Class: 2\n",
"Test 44 Prediction: 3 True Class: 3\n",
"Test 45 Prediction: 5 True Class: 5\n",
"Test 46 Prediction: 1 True Class: 1\n",
"Test 47 Prediction: 2 True Class: 2\n",
"Test 48 Prediction: 4 True Class: 4\n",
"Test 49 Prediction: 4 True Class: 4\n",
"Test 50 Prediction: 6 True Class: 6\n",
"Test 51 Prediction: 3 True Class: 3\n",
"Test 52 Prediction: 5 True Class: 5\n",
"Test 53 Prediction: 5 True Class: 5\n",
"Test 54 Prediction: 6 True Class: 6\n",
"Test 55 Prediction: 0 True Class: 0\n",
"Test 56 Prediction: 4 True Class: 4\n",
"Test 57 Prediction: 1 True Class: 1\n",
"Test 58 Prediction: 9 True Class: 9\n",
"Test 59 Prediction: 5 True Class: 5\n",
"Test 60 Prediction: 7 True Class: 7\n",
"Test 61 Prediction: 8 True Class: 8\n",
"Test 62 Prediction: 9 True Class: 9\n",
"Test 63 Prediction: 3 True Class: 3\n",
"Test 64 Prediction: 7 True Class: 7\n",
"Test 65 Prediction: 4 True Class: 4\n",
"Test 66 Prediction: 6 True Class: 6\n",
"Test 67 Prediction: 4 True Class: 4\n",
"Test 68 Prediction: 3 True Class: 3\n",
"Test 69 Prediction: 0 True Class: 0\n",
"Test 70 Prediction: 7 True Class: 7\n",
"Test 71 Prediction: 0 True Class: 0\n",
"Test 72 Prediction: 2 True Class: 2\n",
"Test 73 Prediction: 7 True Class: 9\n",
"Test 74 Prediction: 1 True Class: 1\n",
"Test 75 Prediction: 7 True Class: 7\n",
"Test 76 Prediction: 3 True Class: 3\n",
"Test 77 Prediction: 7 True Class: 2\n",
"Test 78 Prediction: 9 True Class: 9\n",
"Test 79 Prediction: 7 True Class: 7\n",
"Test 80 Prediction: 7 True Class: 7\n",
"Test 81 Prediction: 6 True Class: 6\n",
"Test 82 Prediction: 2 True Class: 2\n",
"Test 83 Prediction: 7 True Class: 7\n",
"Test 84 Prediction: 8 True Class: 8\n",
"Test 85 Prediction: 4 True Class: 4\n",
"Test 86 Prediction: 7 True Class: 7\n",
"Test 87 Prediction: 3 True Class: 3\n",
"Test 88 Prediction: 6 True Class: 6\n",
"Test 89 Prediction: 1 True Class: 1\n",
"Test 90 Prediction: 3 True Class: 3\n",
"Test 91 Prediction: 6 True Class: 6\n",
"Test 92 Prediction: 9 True Class: 9\n",
"Test 93 Prediction: 3 True Class: 3\n",
"Test 94 Prediction: 1 True Class: 1\n",
"Test 95 Prediction: 4 True Class: 4\n",
"Test 96 Prediction: 1 True Class: 1\n",
"Test 97 Prediction: 7 True Class: 7\n",
"Test 98 Prediction: 6 True Class: 6\n",
"Test 99 Prediction: 9 True Class: 9\n",
"Test 100 Prediction: 6 True Class: 6\n",
"Test 101 Prediction: 0 True Class: 0\n",
"Test 102 Prediction: 5 True Class: 5\n",
"Test 103 Prediction: 4 True Class: 4\n",
"Test 104 Prediction: 9 True Class: 9\n",
"Test 105 Prediction: 9 True Class: 9\n",
"Test 106 Prediction: 2 True Class: 2\n",
"Test 107 Prediction: 1 True Class: 1\n",
"Test 108 Prediction: 9 True Class: 9\n",
"Test 109 Prediction: 4 True Class: 4\n",
"Test 110 Prediction: 8 True Class: 8\n",
"Test 111 Prediction: 7 True Class: 7\n",
"Test 112 Prediction: 3 True Class: 3\n",
"Test 113 Prediction: 9 True Class: 9\n",
"Test 114 Prediction: 7 True Class: 7\n",
"Test 115 Prediction: 9 True Class: 4\n",
"Test 116 Prediction: 9 True Class: 4\n",
"Test 117 Prediction: 4 True Class: 4\n",
"Test 118 Prediction: 9 True Class: 9\n",
"Test 119 Prediction: 7 True Class: 2\n",
"Test 120 Prediction: 5 True Class: 5\n",
"Test 121 Prediction: 4 True Class: 4\n",
"Test 122 Prediction: 7 True Class: 7\n",
"Test 123 Prediction: 6 True Class: 6\n",
"Test 124 Prediction: 7 True Class: 7\n",
"Test 125 Prediction: 9 True Class: 9\n",
"Test 126 Prediction: 0 True Class: 0\n",
"Test 127 Prediction: 5 True Class: 5\n",
"Test 128 Prediction: 8 True Class: 8\n",
"Test 129 Prediction: 5 True Class: 5\n",
"Test 130 Prediction: 6 True Class: 6\n",
"Test 131 Prediction: 6 True Class: 6\n",
"Test 132 Prediction: 5 True Class: 5\n",
"Test 133 Prediction: 7 True Class: 7\n",
"Test 134 Prediction: 8 True Class: 8\n",
"Test 135 Prediction: 1 True Class: 1\n",
"Test 136 Prediction: 0 True Class: 0\n",
"Test 137 Prediction: 1 True Class: 1\n",
"Test 138 Prediction: 6 True Class: 6\n",
"Test 139 Prediction: 4 True Class: 4\n",
"Test 140 Prediction: 6 True Class: 6\n",
"Test 141 Prediction: 7 True Class: 7\n",
"Test 142 Prediction: 2 True Class: 3\n",
"Test 143 Prediction: 1 True Class: 1\n",
"Test 144 Prediction: 7 True Class: 7\n",
"Test 145 Prediction: 1 True Class: 1\n",
"Test 146 Prediction: 8 True Class: 8\n",
"Test 147 Prediction: 2 True Class: 2\n",
"Test 148 Prediction: 0 True Class: 0\n",
"Test 149 Prediction: 1 True Class: 2\n",
"Test 150 Prediction: 9 True Class: 9\n",
"Test 151 Prediction: 9 True Class: 9\n",
"Test 152 Prediction: 5 True Class: 5\n",
"Test 153 Prediction: 5 True Class: 5\n",
"Test 154 Prediction: 1 True Class: 1\n",
"Test 155 Prediction: 5 True Class: 5\n",
"Test 156 Prediction: 6 True Class: 6\n",
"Test 157 Prediction: 0 True Class: 0\n",
"Test 158 Prediction: 3 True Class: 3\n",
"Test 159 Prediction: 4 True Class: 4\n",
"Test 160 Prediction: 4 True Class: 4\n",
"Test 161 Prediction: 6 True Class: 6\n",
"Test 162 Prediction: 5 True Class: 5\n",
"Test 163 Prediction: 4 True Class: 4\n",
"Test 164 Prediction: 6 True Class: 6\n",
"Test 165 Prediction: 5 True Class: 5\n",
"Test 166 Prediction: 4 True Class: 4\n",
"Test 167 Prediction: 5 True Class: 5\n",
"Test 168 Prediction: 1 True Class: 1\n",
"Test 169 Prediction: 4 True Class: 4\n",
"Test 170 Prediction: 9 True Class: 4\n",
"Test 171 Prediction: 7 True Class: 7\n",
"Test 172 Prediction: 2 True Class: 2\n",
"Test 173 Prediction: 3 True Class: 3\n",
"Test 174 Prediction: 2 True Class: 2\n",
"Test 175 Prediction: 1 True Class: 7\n",
"Test 176 Prediction: 1 True Class: 1\n",
"Test 177 Prediction: 8 True Class: 8\n",
"Test 178 Prediction: 1 True Class: 1\n",
"Test 179 Prediction: 8 True Class: 8\n",
"Test 180 Prediction: 1 True Class: 1\n",
"Test 181 Prediction: 8 True Class: 8\n",
"Test 182 Prediction: 5 True Class: 5\n",
"Test 183 Prediction: 0 True Class: 0\n",
"Test 184 Prediction: 2 True Class: 8\n",
"Test 185 Prediction: 9 True Class: 9\n",
"Test 186 Prediction: 2 True Class: 2\n",
"Test 187 Prediction: 5 True Class: 5\n",
"Test 188 Prediction: 0 True Class: 0\n",
"Test 189 Prediction: 1 True Class: 1\n",
"Test 190 Prediction: 1 True Class: 1\n",
"Test 191 Prediction: 1 True Class: 1\n",
"Test 192 Prediction: 0 True Class: 0\n",
"Test 193 Prediction: 4 True Class: 9\n",
"Test 194 Prediction: 0 True Class: 0\n",
"Test 195 Prediction: 1 True Class: 3\n",
"Test 196 Prediction: 1 True Class: 1\n",
"Test 197 Prediction: 6 True Class: 6\n",
"Test 198 Prediction: 4 True Class: 4\n",
"Test 199 Prediction: 2 True Class: 2\n",
"Done!\n",
"Accuracy: 0.92\n"
]
}
],
"source": [
"# Launch the graph (tf.Session moved to tf.compat.v1 in TF2)\n",
"with tf.compat.v1.Session() as sess:\n",
"    sess.run(init)\n",
"\n",
"    # Count hits locally so re-running this cell starts from zero instead of\n",
"    # adding onto an accumulator left over from a previous run.\n",
"    n_correct = 0\n",
"    # loop over test data\n",
"    for i in range(len(Xte)):\n",
"        # Get nearest neighbor: index of the closest row in Xtr\n",
"        nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i, :]})\n",
"        # Get nearest neighbor class label and compare it to its true label\n",
"        predicted = np.argmax(Ytr[nn_index])\n",
"        actual = np.argmax(Yte[i])\n",
"        print(\"Test\", i, \"Prediction:\", predicted, \"True Class:\", actual)\n",
"        if predicted == actual:\n",
"            n_correct += 1\n",
"\n",
"    # Divide once: summing 1/len(Xte) per hit accumulates float rounding error\n",
"    # that can show up as spurious trailing digits when printed.\n",
"    accuracy = n_correct / len(Xte)\n",
"    print(\"Done!\")\n",
"    print(\"Accuracy:\", accuracy)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"id": "p65iJWWAH_LW"
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.4.3"
},
"colab": {
"provenance": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment