-
-
Save gabber12/96723a5bf0eb4192350ed0c6ea297cbc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.models import Model, Sequential | |
from keras.layers import Input, Convolution2D, ZeroPadding2D, MaxPooling2D, Flatten, Dense, Dropout, Activation | |
from PIL import Image,UnidentifiedImageError | |
import numpy as np | |
from keras.preprocessing.image import load_img, save_img, img_to_array | |
from keras.applications.imagenet_utils import preprocess_input | |
from keras.preprocessing import image | |
from keras.models import model_from_json | |
import tensorflow as tf | |
from numpy import dot | |
from numpy.linalg import norm | |
import cv2 | |
def cos_sim(arr1, arr2):
    """Return the cosine similarity of two vectors.

    Args:
        arr1: First vector (anything ``numpy.dot`` / ``numpy.linalg.norm`` accept).
        arr2: Second vector, same length as ``arr1``.

    Returns:
        ``dot(arr1, arr2) / (norm(arr1) * norm(arr2))``. When either vector
        has zero norm the similarity is undefined; ``0.0`` is returned in
        that case instead of NaN, so callers that convert the result to an
        int do not crash.
    """
    denominator = norm(arr1) * norm(arr2)
    if denominator == 0:
        # Avoid 0/0 -> NaN (and the accompanying NumPy RuntimeWarning).
        return 0.0
    return dot(arr1, arr2) / denominator
# Frontal-face Haar cascade, loaded once at import time from the directory
# exposed as cv2.data.haarcascades; crop_face() uses this detector.
_FRONTAL_FACE_CASCADE_FILE = cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
faceCascade = cv2.CascadeClassifier(_FRONTAL_FACE_CASCADE_FILE)
def crop_face(image):
    """Detect a face in a BGR image and return it as a 224x224 crop.

    The image is first resized to 224x224, converted to grayscale, and run
    through the module-level ``faceCascade`` detector. When several faces
    are reported, the one with the largest bounding-box area is used so the
    result is deterministic.

    Args:
        image: Image array as returned by ``cv2.imread`` (BGR channel order,
            since it is converted with ``cv2.COLOR_BGR2GRAY``).

    Returns:
        The selected face region, resized to 224x224 with ``cv2.INTER_AREA``.

    Raises:
        ValueError: If ``image`` is ``None`` or no face is detected.
    """
    if image is None:
        # cv2.imread yields None for unreadable files; fail with a clear message
        # instead of an opaque OpenCV assertion inside cv2.resize.
        raise ValueError("crop_face() received no image data")

    image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_AREA)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    faces = faceCascade.detectMultiScale(
        gray,
        scaleFactor=1.3,
        minNeighbors=3,
        minSize=(30, 30),
    )

    # detectMultiScale may hand back an empty tuple or an ndarray, so test the
    # length explicitly rather than relying on truthiness.
    if len(faces) == 0:
        raise ValueError("no face detected in image")

    # Each detection is (x, y, w, h); keep the one covering the largest area.
    x, y, w, h = max(faces, key=lambda box: box[2] * box[3])
    cropped_face = image[y:y + h, x:x + w]
    return cv2.resize(cropped_face, (224, 224), interpolation=cv2.INTER_AREA)
def preprocess_image(cropped_face):
    """Turn a single face crop into a preprocessed one-item batch.

    A leading axis is added with ``np.expand_dims`` and the result is passed
    through Keras' ``preprocess_input``.

    Args:
        cropped_face: Face image array (presumably the 224x224 crop produced
            by ``crop_face`` -- confirm against callers).

    Returns:
        Whatever ``preprocess_input`` returns for the batched array.
    """
    batch = np.expand_dims(cropped_face, axis=0)
    return preprocess_input(batch)
class ErrorCode(object):
    """Plain record describing one API error.

    Attributes are set directly from the constructor arguments:
    ``status_code`` (HTTP status), ``code`` (short identifier) and
    ``message`` (human-readable text).
    """

    def __init__(self, status_code, code, message):
        self.status_code, self.code, self.message = status_code, code, message
# Catalogue of the error replies this service can send, keyed by an internal
# name. Each entry carries the HTTP status, the short code placed in the JSON
# body, and the message shown to the client.
error_codes = {
    "BAD_REQUEST": ErrorCode(400, "BAD_REQUEST", "Check the request and try again"),
    "BAD_IMAGE_PATH": ErrorCode(400, "BAD_REQUEST", "Invalid image paths, check path and try again"),
    "UNSUPPORTED_IMAGE_FORMAT": ErrorCode(400, "BAD_REQUEST", "Unsupported image format"),
    "UNSUPPORTED_OPERATION": ErrorCode(500, "UNSUPPORTED", "Unsupported Request"),
    "UNKNOWN_EXCEPTION": ErrorCode(500, "UNKNOWN_EXCEPTION", "Something went wrong"),
    # Message typo fixed: "doesnot" -> "does not".
    "NOT_FOUND": ErrorCode(404, "NOT_FOUND", "Resource does not exist"),
}
class ErrorResponse(object):
    """Wraps an error-code object and renders it as a JSON error reply."""

    def __init__(self, code):
        # `code` must expose .code, .message and .status_code (read in to_response).
        self.code = code

    def to_response(self):
        """Return ``(json_body, http_status)`` built from the wrapped code.

        The body has the keys ``success`` (always ``False``), ``code`` and
        ``message``.
        """
        error = self.code
        payload = {"success": False, "code": error.code, "message": error.message}
        return jsonify(payload), error.status_code
from flask import Flask, escape, request, jsonify | |
# Flask application object; the route and error handlers in this module
# register themselves on it through its decorators.
app = Flask(import_name=__name__)
class ServerException(Exception):
    """Exception that carries the error response to send back to the client.

    The wrapped object is available as ``error_response``; the request
    handler calls its ``to_response()`` when it catches this exception.
    """

    def __init__(self, error_response):
        super().__init__(error_response)
        self.error_response = error_response
@app.errorhandler(404)
def page_not_found(e):
    """Answer HTTP 404 errors with the service's NOT_FOUND JSON body."""
    not_found = ErrorResponse(error_codes['NOT_FOUND'])
    return not_found.to_response()
@app.errorhandler(400)
def page_not_found(e):
    """Answer HTTP 400 errors with the service's BAD_REQUEST JSON body."""
    # NOTE(review): this function reuses the name of the 404 handler defined
    # just before it. Both stay registered because registration happens in the
    # decorator, but the module-level name now points at this 400 handler;
    # consider renaming it (e.g. bad_request) once no caller relies on the name.
    bad_request = ErrorResponse(error_codes['BAD_REQUEST'])
    return bad_request.to_response()
# def preprocess_image(image_path): | |
# try: | |
# print("IMAGE_PATH", image_path) | |
# img = load_img(image_path, target_size=(224, 224)) | |
# img = img_to_array(img) | |
# img = np.expand_dims(img, axis=0) | |
# img = preprocess_input(img) | |
# return img | |
# except FileNotFoundError as e: | |
# raise ServerException(ErrorResponse(error_codes['BAD_IMAGE_PATH'])) | |
# except UnidentifiedImageError as e: | |
# raise ServerException(ErrorResponse(error_codes['UNSUPPORTED_IMAGE_FORMAT'])) | |
class ImageMatcher(object):
    """Decide whether two face images match by comparing network descriptors.

    ``init_model`` builds a VGG-style convolutional network, loads its
    weights from disk and exposes the output of the second-to-last layer as
    ``vgg_face_descriptor``. Two images match when the cosine similarity of
    their descriptors exceeds ``similarity_threshold``.
    """

    # (filters, number of zero-padded 3x3 conv layers) for each conv stage;
    # every stage is followed by a 2x2 max-pool with stride 2.
    _CONV_STAGES = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))

    def __init__(self, similarity_threshold = 0.5):
        """Store the similarity above which two images count as a match."""
        self.similarity_threshold = similarity_threshold

    def init_model(self, weights_path='vgg_face_weights.h5'):
        """Build the network, load its weights and create the descriptor model.

        Args:
            weights_path: Path of the Keras weights file to load. Defaults to
                ``'vgg_face_weights.h5'`` in the current working directory.

        The layer sequence must stay exactly as laid out here so that it
        lines up with the layers stored in the weights file.
        """
        self.model = model = Sequential()

        first_layer = True
        for filters, conv_count in self._CONV_STAGES:
            for _ in range(conv_count):
                if first_layer:
                    # Only the very first layer declares the input shape.
                    model.add(ZeroPadding2D((1, 1), input_shape=(224, 224, 3)))
                    first_layer = False
                else:
                    model.add(ZeroPadding2D((1, 1)))
                model.add(Convolution2D(filters, (3, 3), activation='relu'))
            model.add(MaxPooling2D((2, 2), strides=(2, 2)))

        # Classifier head expressed as convolutions.
        model.add(Convolution2D(4096, (7, 7), activation='relu'))
        model.add(Dropout(0.5))
        model.add(Convolution2D(4096, (1, 1), activation='relu'))
        model.add(Dropout(0.5))
        model.add(Convolution2D(2622, (1, 1)))
        model.add(Flatten())
        model.add(Activation('softmax'))

        model.load_weights(weights_path)

        # Descriptor = output of the layer just before the final softmax.
        self.vgg_face_descriptor = Model(
            inputs=model.layers[0].input,
            outputs=model.layers[-2].output,
        )

    def feature_extraction(self, img_path):
        """Return the descriptor vector of the face found in ``img_path``.

        Args:
            img_path: Filesystem path of the image to read with ``cv2.imread``.

        Returns:
            The first row of the descriptor model's prediction.

        Raises:
            ServerException: ``BAD_IMAGE_PATH`` when no file exists at
                ``img_path``; ``UNSUPPORTED_IMAGE_FORMAT`` when the file
                exists but OpenCV cannot decode it.
        """
        import os

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising;
            # translate that into the matching client-facing (HTTP 400) error.
            if not os.path.isfile(img_path):
                raise ServerException(ErrorResponse(error_codes['BAD_IMAGE_PATH']))
            raise ServerException(ErrorResponse(error_codes['UNSUPPORTED_IMAGE_FORMAT']))

        cropped_face = crop_face(img)
        cropped_face_processed = preprocess_image(cropped_face)
        return self.vgg_face_descriptor.predict(cropped_face_processed)[0, :]

    def match_image(self, img_path1, img_path2):
        """Compare the faces in two image files.

        Returns:
            A ``(is_match, match_percentage)`` tuple: ``is_match`` is a plain
            ``bool`` that is true when the cosine similarity exceeds
            ``similarity_threshold``; ``match_percentage`` is the similarity
            times 100, truncated to two decimals.
        """
        print(img_path1, img_path2)
        similarity = cos_sim(
            self.feature_extraction(img_path1),
            self.feature_extraction(img_path2),
        )
        # bool() turns a NumPy boolean into a built-in bool for JSON encoding.
        is_match = bool(similarity > self.similarity_threshold)
        return is_match, (int(similarity*10000))/100.0
def local_image_path(path):
    """Validate one image descriptor and return it if it names a local file.

    Args:
        path: Descriptor taken from the request body. It must be a dict with
            a ``type`` of ``'LOCAL'`` and a non-empty string ``path``.

    Returns:
        The same ``path`` dict, unchanged.

    Raises:
        ServerException: ``UNSUPPORTED_OPERATION`` when ``type`` is
            ``'REMOTE'``; ``BAD_IMAGE_PATH`` for anything else that is not a
            well-formed local descriptor (not a dict, missing or unknown
            ``type``, or missing/empty/non-string ``path``).
    """
    # Request data is untrusted: a non-dict entry would otherwise raise
    # AttributeError on .get() and be reported as an internal error.
    if not isinstance(path, dict):
        raise ServerException(ErrorResponse(error_codes['BAD_IMAGE_PATH']))

    image_type = path.get('type')
    if image_type == 'REMOTE':
        raise ServerException(ErrorResponse(error_codes["UNSUPPORTED_OPERATION"]))

    location = path.get('path')
    if image_type == 'LOCAL' and isinstance(location, str) and location:
        return path

    raise ServerException(ErrorResponse(error_codes['BAD_IMAGE_PATH']))
def get_local_image_paths(image_paths):
    """Validate the request's pair of image descriptors.

    Args:
        image_paths: Value of the request's ``image_paths`` field; must be a
            list (or tuple) of exactly two descriptors.

    Returns:
        A list with the two descriptors, each checked by ``local_image_path``.

    Raises:
        ServerException: ``BAD_IMAGE_PATH`` when ``image_paths`` is missing,
            is not a list/tuple, or does not hold exactly two entries; also
            whatever ``local_image_path`` raises for an individual entry.
    """
    if not isinstance(image_paths, (list, tuple)) or len(image_paths) != 2:
        # The catalogue key is "BAD_IMAGE_PATH"; the former "BAD_IMAGE_PATHS"
        # lookup raised KeyError and turned this 400 into a 500.
        raise ServerException(ErrorResponse(error_codes["BAD_IMAGE_PATH"]))
    return [local_image_path(path) for path in image_paths]
import traceback | |
# Capture TensorFlow's default graph at import time; the request handler
# re-enters it with `graph.as_default()` before running the model.
# (A `global` statement at module scope is a no-op, so it was dropped.)
# NOTE(review): tf.get_default_graph is the TF 1.x spelling; under TF 2.x it
# lives at tf.compat.v1.get_default_graph -- confirm the pinned TF version.
graph = tf.get_default_graph()

# One shared matcher for the whole process: the network is built and its
# weights are loaded once, when this module is imported.
matcher = ImageMatcher()
matcher.init_model()
@app.route('/image_matcher', methods=['POST'])
def image_matcher():
    """Compare the two images named in the JSON request body.

    The body must be a JSON object whose ``image_paths`` field holds two
    descriptors (validated by ``get_local_image_paths``); each descriptor's
    ``path`` is passed to the matcher and its optional ``id`` is echoed back.

    Returns:
        On success, a 200 reply whose ``data`` list has one entry with
        ``id1``, ``id2``, ``result`` and ``match_percentage``. Otherwise the
        JSON error reply for BAD_REQUEST, for the raised ``ServerException``,
        or for UNKNOWN_EXCEPTION on any other failure.
    """
    data = request.get_json()
    # Reject a missing body and also non-object JSON (list, string, number),
    # which would otherwise crash on data.get() and be reported as a 500.
    if not isinstance(data, dict):
        return ErrorResponse(error_codes["BAD_REQUEST"]).to_response()
    try:
        with graph.as_default():
            local_image_paths = get_local_image_paths(data.get('image_paths'))
            first, second = local_image_paths
            is_match, match_percentage = matcher.match_image(first['path'], second['path'])
            body = {
                "success": True,
                "data": [{
                    "id1": first.get('id'),
                    "id2": second.get('id'),
                    "result": is_match,
                    "match_percentage": match_percentage,
                }],
            }
            return jsonify(body), 200
    except ServerException as e:
        return e.error_response.to_response()
    except Exception:
        # Last-resort boundary: log the traceback, hide details from the client.
        traceback.print_exc()
        return ErrorResponse(error_codes['UNKNOWN_EXCEPTION']).to_response()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment