Last active: May 26, 2017 09:35
-
-
Save erogol/da4a7b0f729d6c309ba21d7f872d5708 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
from __future__ import print_function | |
import numpy as np | |
from keras.models import Sequential, Model | |
from keras.layers import Dense, Dropout, Lambda, merge, BatchNormalization, Activation, Input, Merge | |
from keras import backend as K | |
def euclidean_distance(vects):
    '''Row-wise Euclidean distance between two batches of embeddings.

    # Arguments
        vects: a pair ``(x, y)`` of tensors with identical shape. Presumably
            2-D ``(batch, features)``, since the reduction is over axis 1
            and the companion shape function reports ``(batch, 1)`` --
            TODO confirm for other ranks.

    # Returns
        Tensor with axis 1 reduced but kept (``keepdims=True``), i.e. one
        distance per row. Identical rows yield ``sqrt(K.epsilon())`` rather
        than exactly 0 (see the clamp below).
    '''
    x, y = vects
    sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
    # d(sqrt(s))/ds = 1 / (2 * sqrt(s)) is infinite at s == 0, so an
    # identical pair would back-propagate NaNs. Clamping s to a tiny positive
    # floor keeps the gradient finite.
    return K.sqrt(K.maximum(sum_square, K.epsilon()))
def eucl_dist_output_shape(shapes):
    '''Output-shape function for the ``euclidean_distance`` Lambda layer.

    Receives the pair of input shapes and returns ``(batch, 1)``. The batch
    dimension is copied from the first shape; the second is only unpacked.
    '''
    first_shape, _ = shapes
    batch = first_shape[0]
    return (batch, 1)
def cosine_distance(vests):
    '''Row-wise cosine distance between two batches of embeddings.

    Computes ``1 - cos(x, y)`` along the last axis:
    - identical directions give 0,
    - orthogonal vectors give 1,
    - opposite vectors give 2.

    Small values therefore mean "similar", which is what ``contrastive_loss``
    expects from ``y_pred`` (it penalises ``y_pred ** 2`` for similar pairs).

    # Arguments
        vests: a pair ``(x, y)`` of tensors with identical shape.

    # Returns
        Tensor with the last axis reduced but kept (one distance per row).
    '''
    x, y = vests
    x = K.l2_normalize(x, axis=-1)
    y = K.l2_normalize(y, axis=-1)
    # After L2 normalisation, cosine similarity is the *sum* of element-wise
    # products. A mean would divide it by the feature size and squash every
    # distance towards zero.
    cos_sim = K.sum(x * y, axis=-1, keepdims=True)
    return 1 - cos_sim
def cos_dist_output_shape(shapes):
    '''Output-shape function for the ``cosine_distance`` Lambda layer.

    Returns ``(batch, 1)`` with the batch size taken from the first of the
    two input shapes; the second shape is unpacked but not inspected.
    '''
    head, _tail = shapes
    return head[0], 1
def contrastive_loss(y_true, y_pred, margin=1):
    '''Contrastive loss from Hadsell-et-al.'06
    http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    Label convention: ``y_true == 1`` marks a similar pair and
    ``y_true == 0`` a dissimilar pair. This is the opposite of the paper's
    Y encoding. The paper's 1/2 factor on each term is also omitted.

    # Arguments
        y_true: pair labels (1 = similar, 0 = dissimilar).
        y_pred: predicted distances; smaller means more similar.
        margin: dissimilar pairs stop contributing once their distance
            reaches this value. Defaults to 1, the previously hard-coded
            value, so existing ``loss=contrastive_loss`` usage is unchanged.

    # Returns
        Scalar tensor: the per-element loss averaged with ``K.mean``.
    '''
    # Similar pairs are pulled together: penalise the squared distance.
    similar_term = y_true * K.square(y_pred)
    # Dissimilar pairs are pushed apart only until they clear the margin.
    dissimilar_term = (1 - y_true) * K.square(K.maximum(margin - y_pred, 0))
    return K.mean(similar_term + dissimilar_term)
def create_base_network(input_dim, hidden_dim=128):
    '''
    Base network for feature extraction.

    The network has three stages, each Dense(hidden_dim) ->
    BatchNormalization(mode=2) -> ReLU. Stages 2 and 3 add the previous
    stage's ReLU output to their batch-normalised activations before the
    ReLU (a residual sum). The three ReLU outputs are concatenated and
    batch-normalised again, giving an embedding of ``3 * hidden_dim``
    features.

    # Arguments
        input_dim: size of the flat input feature vector.
        hidden_dim: width of every Dense layer. Defaults to 128, the
            previously hard-coded value.

    # Returns
        A Keras ``Model`` mapping ``(input_dim,)`` inputs to the embedding.
    '''
    # NOTE(review): BatchNormalization(mode=...), ``merge`` and
    # ``Model(input=, output=)`` are Keras 1.x APIs; this will not build
    # unchanged on Keras 2.
    inputs = Input(shape=(input_dim, ))  # not ``input``: avoid shadowing the builtin

    dense1 = Dense(hidden_dim)(inputs)
    bn1 = BatchNormalization(mode=2)(dense1)
    relu1 = Activation('relu')(bn1)

    dense2 = Dense(hidden_dim)(relu1)
    bn2 = BatchNormalization(mode=2)(dense2)
    res2 = merge([relu1, bn2], mode='sum')
    relu2 = Activation('relu')(res2)

    dense3 = Dense(hidden_dim)(relu2)
    bn3 = BatchNormalization(mode=2)(dense3)
    # Same functional ``merge`` helper as the other two merges, rather than
    # instantiating the ``Merge`` layer class directly.
    res3 = merge([relu2, bn3], mode='sum')
    relu3 = Activation('relu')(res3)

    feats = merge([relu3, relu2, relu1], mode='concat')
    bn4 = BatchNormalization(mode=2)(feats)
    return Model(input=inputs, output=bn4)
def compute_accuracy(predictions, labels, threshold=0.5):
    '''
    Compute classification accuracy with a fixed threshold on distances.

    A pair is predicted "similar" (label 1) when its distance is strictly
    below ``threshold``. Accuracy is the fraction of those predictions that
    equal ``labels``.

    # Arguments
        predictions: array-like of predicted distances, any shape (it is
            flattened).
        labels: array-like of 0/1 (or boolean) labels with the same number
            of elements as ``predictions``, any shape (it is flattened).
        threshold: distance cut-off; defaults to 0.5.

    # Returns
        Accuracy in [0, 1].
    '''
    predicted_similar = np.ravel(predictions) < threshold
    # Flatten labels as well. An (n, 1) label column compared against the
    # flattened (n,) predictions would otherwise broadcast to an (n, n)
    # matrix and produce a meaningless mean.
    return np.mean(np.equal(predicted_similar, np.ravel(labels)))
def create_network(input_dim):
    '''Build the siamese model: two inputs, one shared tower, one distance.

    Both inputs pass through the *same* ``create_base_network`` instance, so
    the two branches share weights. The model output is
    ``euclidean_distance`` applied to the pair of embeddings.

    # Arguments
        input_dim: size of each flat input vector.

    # Returns
        A Keras ``Model`` taking ``[input_a, input_b]`` and returning the
        distance tensor.
    '''
    shared_tower = create_base_network(input_dim)

    left = Input(shape=(input_dim,))
    right = Input(shape=(input_dim,))
    pair_inputs = [left, right]

    # Calling the single model object on both inputs is what ties the
    # branch weights together.
    embeddings = [shared_tower(tensor) for tensor in pair_inputs]

    distance_layer = Lambda(euclidean_distance,
                            output_shape=eucl_dist_output_shape)
    distance = distance_layer(embeddings)
    return Model(input=pair_inputs, output=distance)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.