Last active
September 26, 2018 17:04
-
-
Save Munawwar/f7817984ba6027f4427168b01cc46e2f to your computer and use it in GitHub Desktop.
Ounass image matching
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import tensorflow as tf | |
import tensorflow.python.platform | |
from tensorflow.python.platform import gfile | |
import numpy as np | |
from shutil import copyfile | |
#from sklearn import cross_validation, grid_search | |
#from sklearn.metrics import confusion_matrix, classification_report | |
from sklearn.svm import SVC | |
#from sklearn.externals import joblib | |
def create_graph(model_path):
    """
    Load a frozen Inception GraphDef from disk into the default TensorFlow graph.

    Call this once before extract_features, which looks tensors up by name
    ('pool_3:0', 'DecodeJpeg/contents:0') in the default graph.

    model_path: path to the inception model serialized as a protobuf (.pb) file.
    """
    with gfile.FastGFile(model_path, 'rb') as model_file:
        serialized_graph = model_file.read()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized_graph)
    # name='' imports without a prefix so tensor names stay e.g. 'pool_3:0'.
    tf.import_graph_def(graph_def, name='')
def extract_features(image_paths, verbose=False):
    """
    extract_features computes the Inception bottleneck feature for a list of images.

    create_graph() must have been called first so that the 'pool_3:0' and
    'DecodeJpeg/contents:0' tensors exist in the default graph.

    image_paths: list of image file paths. Their raw bytes are fed to the
        'DecodeJpeg/contents:0' input, so presumably JPEG files (TODO confirm).
    verbose: if True, print each path as it is processed.
    return: 2-d array in the shape of (len(image_paths), 2048)
    raises: IOError if one of image_paths does not exist.
    """
    # Must match the size of the squeezed 'pool_3:0' output assigned below.
    feature_dimension = 2048
    features = np.empty((len(image_paths), feature_dimension))
    with tf.Session() as sess:
        # Look the output tensor up once; it is the same for every image.
        flattened_tensor = sess.graph.get_tensor_by_name('pool_3:0')
        for i, image_path in enumerate(image_paths):
            if verbose:
                print('Processing %s...' % (image_path))
            if not gfile.Exists(image_path):
                # tf.logging.fatal only logs; raise so we do not fall through
                # to a read that fails with a less helpful error.
                tf.logging.fatal('File does not exist %s', image_path)
                raise IOError('File does not exist %s' % image_path)
            # Close the handle after reading instead of leaking it.
            with gfile.FastGFile(image_path, 'rb') as image_file:
                image_data = image_file.read()
            feature = sess.run(flattened_tensor, {
                'DecodeJpeg/contents:0': image_data
            })
            features[i, :] = np.squeeze(feature)
    return features
def trainAndTest(features, labels, query_images, trainingPath, top_k=5):
    """
    Fit an SVM on the training features and rank the classes for each query image.

    In this script the labels are the training image file names, so each ranked
    label is also the file name of a matching training image inside trainingPath.

    features: 2-d array of training features, one row per entry of labels.
    labels: class label for each row of features; each label is used as a file
        name inside trainingPath when copying matches.
    query_images: list of query image paths to match.
    trainingPath: directory containing the training images named by labels.
    top_k: number of best-scoring classes to report and copy per query.

    Side effects: prints the score matrix shape and each query's matches; for each
    query creates ./results/<query file name>/ and copies the query there as
    query-image.jpg and its matches (best first) as 1.jpg .. <top_k>.jpg.

    return: list with one entry per query, each the top_k labels ordered from the
    highest to the lowest decision score.
    """
    # Only decision_function() is used, so probability=True would just add the
    # cost of its internal Platt-scaling cross-validation.
    classifier = SVC(decision_function_shape='ovr')
    classifier.fit(features, labels)
    # decision_function() columns follow classifier.classes_ (sorted, unique
    # labels), not the order of `labels`; tolist() yields plain Python values.
    classes = classifier.classes_.tolist()

    query_features = extract_features(query_images, True)
    prediction_scores = classifier.decision_function(query_features)
    if prediction_scores.ndim == 1:
        # With exactly two classes SVC returns one score per sample, where a
        # positive value favours classes_[1]; expand to one column per class.
        prediction_scores = np.column_stack((-prediction_scores, prediction_scores))
    print(prediction_scores.shape)

    predictions = []
    for query_path, score_per_class in zip(query_images, prediction_scores):
        # Reverse-sort (score, label) pairs: highest score first, ties broken
        # by label in descending order.
        ranked = [cls for _, cls in sorted(zip(score_per_class, classes), reverse=True)]
        top_matches = ranked[:top_k]
        print('%s -> %s' % (query_path, top_matches))
        predictions.append(top_matches)

        out_directory = os.path.join('.', 'results', os.path.basename(query_path))
        if not os.path.exists(out_directory):
            os.makedirs(out_directory)
        copyfile(query_path, os.path.join(out_directory, 'query-image.jpg'))
        # Copy every reported match, best first, as 1.jpg, 2.jpg, ...
        for rank, match in enumerate(top_matches, 1):
            copyfile(os.path.join(trainingPath, match),
                     os.path.join(out_directory, '%d.jpg' % rank))
    return predictions
### Start ####
# Load the pre-trained TensorFlow Inception model into the default graph.
create_graph('./models/tensorflow_inception_graph.pb')

# Collect the training images (regular files only, sorted by path). Each
# image's file name, relative to basepath, doubles as its class label.
basepath = './images/ounass-training/'
image_paths = sorted(
    path
    for path in (os.path.join(basepath, fname) for fname in os.listdir(basepath))
    if os.path.isfile(path)
)
labels = [path[len(basepath):] for path in image_paths]

features = extract_features(image_paths, True)
# Parenthesised single-argument print behaves the same on Python 2 and 3.
print(features.shape)

# Query images ./images/ounass-query/001.jpg .. 010.jpg
query_images = ['./images/ounass-query/%03d.jpg' % n for n in range(1, 11)]
trainAndTest(features, labels, query_images, basepath)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Good matches:

Query image 003.jpg:
Results:

Query image 006.jpg:

Results:
