Skip to content

Instantly share code, notes, and snippets.

@danFromTelAviv
Forked from iskandr/keras2-highway-network.py
Last active April 24, 2018 13:46
Show Gist options
  • Save danFromTelAviv/ac121b479528a0270ec0faf85fe9a082 to your computer and use it in GitHub Desktop.
Save danFromTelAviv/ac121b479528a0270ec0faf85fe9a082 to your computer and use it in GitHub Desktop.
Since Keras 2.0 removed the Highway Network layer, here is my attempt at implementing something equivalent using the functional API.
import keras.backend as K
from keras.layers import Dense, Activation, Multiply, Add, Lambda
import keras.initializers
def highway_layers(value, n_layers, activation="tanh", gate_bias=-3):
    """Stack ``n_layers`` dense highway layers on top of ``value``.

    Each layer computes ``y = T(x) * H(x) + (1 - T(x)) * x`` where
    ``T(x) = sigmoid(Dense(x))`` is the transform gate and
    ``H(x) = activation(Dense(x))`` is the candidate transform. Both Dense
    layers keep the size of the last axis, so the output has the same
    feature dimension as the input.

    Args:
        value: Keras tensor to transform; its last axis must have a known
            static size.
        n_layers: Number of highway layers to apply in sequence. With
            ``0`` the input tensor is returned unchanged.
        activation: Activation applied to the candidate transform ``H``.
        gate_bias: Constant used to initialise the gate's bias. A negative
            value makes the sigmoid gate start close to 0, so each layer
            initially passes its input through mostly unchanged.

    Returns:
        The Keras tensor produced by the last highway layer.
    """
    dim = K.int_shape(value)[-1]
    gate_bias_initializer = keras.initializers.Constant(gate_bias)
    for _ in range(n_layers):
        gate = Dense(units=dim, bias_initializer=gate_bias_initializer)(value)
        gate = Activation("sigmoid")(gate)
        # Take the per-sample shape from the gate itself rather than assuming
        # (dim,): Dense acts on the last axis, so inputs of rank > 2 keep
        # their intermediate axes and a hard-coded (dim,) would be wrong.
        negated_gate = Lambda(
            lambda x: 1.0 - x,
            output_shape=K.int_shape(gate)[1:])(gate)
        transformed = Dense(units=dim)(value)
        transformed = Activation(activation)(transformed)
        transformed_gated = Multiply()([gate, transformed])
        identity_gated = Multiply()([negated_gate, value])
        value = Add()([transformed_gated, identity_gated])
    return value
def conv_highway_layer(input, activation="tanh", kernel_size=3, gate_bias=-3):
    """Apply a single 2D-convolutional highway layer to ``input``.

    Computes ``y = H(x, W_h) * T(x, W_t) + x * (1 - T(x, W_t))`` with

    * ``T(x, W_t) = sigmoid(conv(x) + b_t)`` -- the transform gate,
    * ``H(x, W_h) = activation(conv(x))`` -- the candidate transform.

    Both convolutions use ``padding='same'`` and as many filters as the
    size of the last axis of ``input``, so the output can be combined
    element-wise with the input.

    Args:
        input: Keras tensor to transform. The filter count is read from its
            last axis (assumes a channels-last layout -- TODO confirm
            against the configured image data format).
        activation: Activation applied to the candidate transform ``H``.
        kernel_size: Side length of the square convolution kernel.
        gate_bias: Constant used to initialise the gate's bias. A negative
            value makes the sigmoid gate start close to 0, so the layer
            initially passes its input through mostly unchanged.

    Returns:
        The gated combination of the transformed and the original input.
    """
    # Imported here because the module-level imports do not bring any
    # convolution layer into scope; without this the function would raise
    # NameError on first use.
    from keras.layers import Conv2D

    dim = K.int_shape(input)[-1]
    kernel = (kernel_size, kernel_size)

    # T(x, W_t): transform gate.
    gate = Conv2D(
        dim, kernel, padding='same', activation='sigmoid',
        bias_initializer=keras.initializers.Constant(gate_bias))(input)
    # H(x, W_h): candidate transform.
    H = Conv2D(dim, kernel, padding='same', activation=activation)(input)
    # 1 - T(x, W_t): carry gate.
    negated_gate = Lambda(lambda x: 1.0 - x)(gate)

    transformed_gated = Multiply()([H, gate])           # H * T
    identity_gated = Multiply()([input, negated_gate])  # x * (1 - T)
    output = Add()([transformed_gated, identity_gated])
    return output
@danFromTelAviv
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment