Skip to content

Instantly share code, notes, and snippets.

Forked from thomasdullien/
Created October 25, 2017 22:33
Show Gist options
  • Save pbamotra/a3b8b2229aeefefe8293650dfe6ddd50 to your computer and use it in GitHub Desktop.
Save pbamotra/a3b8b2229aeefefe8293650dfe6ddd50 to your computer and use it in GitHub Desktop.
Inception for feature extraction, Annoy for nearest-neighbor search.
Simple, hacked-up image similarity search using TensorFlow with the Inception
CNN as the feature extractor and Annoy for approximate nearest-neighbor search.
Requires TensorFlow and Annoy.
Based on earlier gist code (the link to the original is not preserved here).
from annoy import AnnoyIndex
import os
import sys
import tensorflow as tf
import tensorflow.python.platform
from tensorflow.python.platform import gfile
import numpy as np
def create_graph(model_path):
    """Load the Inception model into memory.

    Should be called before calling extract_features, which looks up the
    model's tensors by name in the session graph.

    Args:
        model_path: path to the Inception model in protobuf form.
    """
    graph_def = tf.GraphDef()
    with gfile.FastGFile(model_path, 'rb') as f:
        # Deserialize the stored model; without this step an empty GraphDef
        # would be imported and none of the model's tensors would exist.
        graph_def.ParseFromString(f.read())
    # name='' imports the nodes without a name prefix, so extract_features
    # can address them by their plain names (e.g. 'pool_3:0').
    tf.import_graph_def(graph_def, name='')
def extract_features(image_paths, verbose=False):
    """Compute the Inception bottleneck feature for a list of images.

    The model graph must already be loaded (see create_graph), because the
    tensors are looked up by name in the session graph.

    Args:
        image_paths: sequence of image file paths.
        verbose: when true, print each path before it is processed.

    Returns:
        2-d array in the shape of (len(image_paths), 2048).
    """
    feature_dimension = 2048
    features = np.empty((len(image_paths), feature_dimension))

    with tf.Session() as sess:
        flattened_tensor = sess.graph.get_tensor_by_name('pool_3:0')

        for i, image_path in enumerate(image_paths):
            if verbose:
                print('Processing %s...' % (image_path))

            if not gfile.Exists(image_path):
                # Report the path being processed (the previous code
                # referenced an undefined name here). The read below is
                # still attempted, exactly as before.
                tf.logging.fatal('File does not exist %s', image_path)

            # Context manager so the file handle is released per image
            # instead of lingering until garbage collection.
            with gfile.FastGFile(image_path, 'rb') as image_file:
                image_data = image_file.read()

            # Feed the raw encoded image bytes and fetch the pooled feature.
            feature = sess.run(flattened_tensor, {
                'DecodeJpeg/contents:0': image_data
            })
            # Drop singleton dimensions so the feature fits one row.
            features[i, :] = np.squeeze(feature)

    return features
# NOTE(review): the model path and the number of trees are not visible in the
# surviving source (the lines carrying them were lost); the values below are
# assumptions -- confirm against the original script before relying on them.
MODEL_PATH = "./tensorflow_inception_graph.pb"
NUM_TREES = 80
# Number of nearest neighbours requested per query image.
NUM_NEIGHBORS = 10


def _list_images(input_path):
    """Return the paths of all entries inside the directory input_path."""
    # os.path.join works whether or not input_path ends with a separator;
    # plain string concatenation silently required a trailing slash.
    return [os.path.join(input_path, name) for name in os.listdir(input_path)]


def _load_model_and_extract(input_path):
    """Load the CNN, then return (image paths, feature matrix) for a directory."""
    print("[!] Loading the inception CNN")
    create_graph(MODEL_PATH)
    print("[!] Done.")
    images = _list_images(input_path)
    if not images:
        # Nothing to index or query; stop before touching the index.
        sys.exit("[!] No images found in '%s'" % input_path)
    return images, extract_features(images, True)


def _build_index(input_path, index_path):
    """Index every image in input_path and save the index plus its filelist."""
    print("[!] Creating a new image similarity search index.")
    images, results = _load_model_and_extract(input_path)
    print("[!] Done extracting features, building search index")
    # Angular is Annoy's historical default metric; newer releases expect it
    # to be passed explicitly.
    ann_index = AnnoyIndex(results.shape[1], metric="angular")
    for item_id, feature in enumerate(results):
        ann_index.add_item(item_id, feature)
    print("[!] Constructing trees")
    ann_index.build(NUM_TREES)
    print("[!] Saving the index to '%s'" % index_path)
    ann_index.save(index_path)
    filelist_path = index_path + ".filelist"
    print("[!] Saving the filelist to '%s'" % filelist_path)
    # One path per line, in item-id order, so that search mode can map the
    # ids returned by the index back to file names.
    with open(filelist_path, "wt") as filelist:
        filelist.write("\n".join(images))


def _search_index(input_path, index_path):
    """Print the nearest indexed images for every image in input_path."""
    print("[!] Searching for similar images.")
    images, results = _load_model_and_extract(input_path)
    ann_index = AnnoyIndex(results.shape[1], metric="angular")
    ann_index.load(index_path)
    with open(index_path + ".filelist", "rt") as filelist_file:
        filelist = [line.rstrip() for line in filelist_file]
    for image_path, feature in zip(images, results):
        print("[!] Searching for similar images to '%s'" % image_path)
        # include_distances=True yields a pair: (item ids, distances).
        neighbor_ids, distances = ann_index.get_nns_by_vector(
            feature, NUM_NEIGHBORS, include_distances=True)
        for neighbor_id, distance in zip(neighbor_ids, distances):
            print("%f -> %d (%s)" % (distance, neighbor_id,
                                     filelist[neighbor_id]))


def main(argv):
    """Dispatch on the mode given on the command line.

    Args:
        argv: full argument vector: program name, mode ("index" or
            "search"), image directory, index file path.
    """
    modes = {"index": _build_index, "search": _search_index}
    if len(argv) < 4 or argv[1] not in modes:
        sys.exit("usage: %s (index|search) <image_dir> <index_file>" % argv[0])
    modes[argv[1]](argv[2], argv[3])


if __name__ == "__main__":
    main(sys.argv)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment