@braingineer
Last active April 14, 2016 16:24
Minimum working example of how re-using an Embedding layer can mess things up
from __future__ import print_function
from keras.layers import Layer, Input, Embedding, RepeatVector, Flatten, Dense
from keras.layers import TimeDistributed as Distribute
from keras.activations import softmax
from keras.engine import merge, InputSpec
import keras.backend as K


class ProbabilityTensor(Layer):
    """Layer that turns a 3D tensor (batch, n, features) into a 2D probability matrix (batch, n)."""
    def __init__(self, *args, **kwargs):
        self.input_spec = [InputSpec(ndim=3)]
        # a single Dense unit scores each of the n feature vectors
        self.p_func = Dense(1)
        super(ProbabilityTensor, self).__init__(*args, **kwargs)

    def get_output_shape_for(self, input_shape):
        # b,n,f -> b,n
        # s.t. \sum_n p_n = 1
        return (input_shape[0], input_shape[1])

    def call(self, x, mask=None):
        # score every step, drop the trailing singleton dim, then softmax over n
        energy = K.squeeze(Distribute(self.p_func)(x), 2)
        return softmax(energy)

    def get_config(self):
        config = {}
        base_config = super(ProbabilityTensor, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class SoftAttention(ProbabilityTensor):
    def get_output_shape_for(self, input_shape):
        # b,n,f -> b,f where f is weighted features summed across n
        return (input_shape[0], input_shape[2])

    def call(self, x, mask=None):
        # b,n,f -> b,f via b,n broadcasted
        p_vectors = K.expand_dims(super(SoftAttention, self).call(x, mask), 2)
        expanded_p = K.repeat_elements(p_vectors, K.shape(x)[2], axis=2)
        return K.sum(expanded_p * x, axis=1)
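

# Not part of the original gist: a minimal numpy sketch of what ProbabilityTensor and
# SoftAttention compute, with a random vector `_w` standing in for the learned Dense(1)
# kernel.  For x of shape (b, n, f): energy = x.w -> (b, n), p = softmax over n,
# attended = sum_n p * x -> (b, f).
import numpy as np
_b, _n, _f = 2, 4, 3
_x = np.random.rand(_b, _n, _f)
_w = np.random.rand(_f)                          # hypothetical stand-in for Dense(1)'s kernel
_energy = _x.dot(_w)                             # (b, n)
_p = np.exp(_energy)
_p /= _p.sum(axis=1, keepdims=True)              # each row sums to 1, as in ProbabilityTensor
_attended = (_p[:, :, None] * _x).sum(axis=1)    # (b, f), as in SoftAttention.call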

# toy dimensions
batch = 64
child_size = 20
vocab_size = 1000
embedding_size = 50
broadcast_n = 30

child_in_shape = (batch, child_size)
parent_in_shape = (batch, 1)
child_in = Input(batch_shape=child_in_shape, name='child_input', dtype='int32')
parent_in = Input(batch_shape=parent_in_shape, name='parent_input', dtype='int32')

# shared layers; F_embed is the Embedding that gets re-used on both inputs below
F_embed = Embedding(input_dim=vocab_size, output_dim=embedding_size)
F_attend = SoftAttention()
F_repeat = RepeatVector(broadcast_n)
F_flatten = Flatten()

### children
print("Child shape before embedding: ", child_in._keras_shape)
embedded_child = F_embed(child_in)
print("Child shape after embedding & before attending: ", embedded_child._keras_shape)
attended_child = F_attend(embedded_child)
print("Child shape after attending & before repeating: ", attended_child._keras_shape)
repeated_child = F_repeat(attended_child)
print("Final child shape after repeating: ", repeated_child._keras_shape)

### parent
print("Parent shape before embedding: ", parent_in._keras_shape)
embedded_parent = F_embed(parent_in)
# desired behavior: the shapes below should track the (batch, 1) parent input
print("Parent shape after embedding & before flattening: ", embedded_parent._keras_shape)
flat_parent = F_flatten(embedded_parent)
print("Parent shape after flattening & before repeating: ", flat_parent._keras_shape)
repeated_parent = F_repeat(flat_parent)
print("Final parent shape after repeating: ", repeated_parent._keras_shape)