Skip to content

Instantly share code, notes, and snippets.

@Munawwar
Last active September 26, 2018 17:04
Show Gist options
  • Save Munawwar/f7817984ba6027f4427168b01cc46e2f to your computer and use it in GitHub Desktop.
Save Munawwar/f7817984ba6027f4427168b01cc46e2f to your computer and use it in GitHub Desktop.
Ounass image matching
import os
import tensorflow as tf
import tensorflow.python.platform
from tensorflow.python.platform import gfile
import numpy as np
from shutil import copyfile
#from sklearn import cross_validation, grid_search
#from sklearn.metrics import confusion_matrix, classification_report
from sklearn.svm import SVC
#from sklearn.externals import joblib
def create_graph(model_path):
    """
    Load a serialized Inception GraphDef into TensorFlow's default graph.

    Call this before extract_features, which looks up tensors of the imported
    graph by name.

    model_path: path to the Inception model stored in protobuf form.
    """
    with gfile.FastGFile(model_path, 'rb') as model_file:
        serialized_graph = model_file.read()
    graph_definition = tf.GraphDef()
    graph_definition.ParseFromString(serialized_graph)
    # name='' imports the nodes without a name prefix, so tensors keep their
    # original names (e.g. 'pool_3:0').
    tf.import_graph_def(graph_definition, name='')
def extract_features(image_paths, verbose=False):
    """
    Compute the Inception bottleneck feature for each image in a list.

    The graph loaded by create_graph must already be in the default graph:
    raw file bytes are fed to 'DecodeJpeg/contents:0' and the 'pool_3:0'
    tensor is fetched for each image.

    image_paths: list of image file paths (contents are fed to DecodeJpeg).
    verbose: when True, print each path as it is processed.
    return: 2-d numpy array of shape (len(image_paths), 2048); row i is the
        feature of image_paths[i].
    """
    feature_dimension = 2048
    features = np.empty((len(image_paths), feature_dimension))
    with tf.Session() as sess:
        # Looked up once and reused for every image.
        flattened_tensor = sess.graph.get_tensor_by_name('pool_3:0')
        for i, image_path in enumerate(image_paths):
            if verbose:
                print('Processing %s...' % (image_path))
            if not gfile.Exists(image_path):
                # Report the missing path; the read below is expected to fail for it.
                tf.logging.fatal('File does not exist %s', image_path)
            # Read inside a context manager so the handle is closed promptly.
            with gfile.FastGFile(image_path, 'rb') as image_file:
                image_data = image_file.read()
            feature = sess.run(flattened_tensor, {
                'DecodeJpeg/contents:0': image_data
            })
            # squeeze drops singleton dims so the result fits one row
            # (assumes pool_3 yields feature_dimension values — TODO confirm).
            features[i, :] = np.squeeze(feature)
    return features
def trainAndTest(features, labels, query_images, trainingPath, top_k=5):
    """
    Fit an SVC on training features and rank the training labels for each query.

    features: 2-d array of training features, one row per entry of labels.
    labels: label of each training row; in this script each label is a
        training image file name relative to trainingPath.
    query_images: list of query image paths; their features are computed with
        extract_features.
    trainingPath: directory holding the training images named by labels.
    top_k: number of best-scoring labels to return and copy per query.
    return: list with, for each query image, its top_k predicted labels,
        best first.

    Side effect: for each query, creates ./results/<query file name>/ holding
    the query as query-image.jpg and its matches as 1.jpg .. <top_k>.jpg
    (1.jpg is the best match).
    """
    classifier = SVC(probability=True, decision_function_shape='ovr')
    classifier.fit(features, labels)
    query_features = extract_features(query_images, True)
    prediction_scores = classifier.decision_function(query_features)
    print(prediction_scores.shape)
    # decision_function columns follow classifier.classes_ (sorted unique
    # labels), not the order of `labels`, so rank against classes_.
    classes = [str(label) for label in classifier.classes_]
    predictions = []
    for query_path, score_per_class in zip(query_images, prediction_scores):
        # Sort (score, class) pairs by descending score.
        ranked = sorted(zip(score_per_class, classes), reverse=True)
        top_matches = [label for _, label in ranked[:top_k]]
        print('%s -> %s' % (query_path, top_matches))
        predictions.append(top_matches)

        out_directory = os.path.join('.', 'results', os.path.basename(query_path))
        if not os.path.exists(out_directory):
            os.makedirs(out_directory)
        copyfile(query_path, os.path.join(out_directory, 'query-image.jpg'))
        # Copy every returned match, including the best one, as 1.jpg, 2.jpg, ...
        for rank, match in enumerate(top_matches, 1):
            copyfile(os.path.join(trainingPath, match),
                     os.path.join(out_directory, '%d.jpg' % rank))
    return predictions
### Start ####
# Load the pre-trained TensorFlow Inception graph before extracting features.
create_graph('./models/tensorflow_inception_graph.pb')

# Training images are the regular files directly under basepath, in sorted
# order; each label is the path with the basepath prefix removed.
basepath = './images/ounass-training/'
image_paths = sorted(
    candidate
    for candidate in (os.path.join(basepath, entry) for entry in os.listdir(basepath))
    if os.path.isfile(candidate)
)
labels = [path[len(basepath):] for path in image_paths]

features = extract_features(image_paths, True)
print(features.shape)

# Query images 001.jpg .. 010.jpg.
trainAndTest(
    features,
    labels,
    ['./images/ounass-query/%03d.jpg' % number for number in range(1, 11)],
    basepath,
)
@Munawwar
Copy link
Author

Munawwar commented Mar 29, 2017

Good matches:
Query image 003.jpg:
query-image

Results:
1 2 3 4

Query image 006.jpg:
query-image

Results:
1 2 3 4

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment