Created
August 1, 2019 13:12
-
-
Save Chiang97912/49bfe2085fe3e2b5625258a523fe2af2 to your computer and use it in GitHub Desktop.
attention
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def attention_3d_block(inputs):
    """Apply a learned softmax attention over the time axis of a 3-D tensor.

    For every feature, a dense layer with softmax activation produces one
    weight per time step; the weights are then multiplied element-wise with
    ``inputs``.

    This block relies on names defined outside of it: the constants
    ``TIME_STEPS`` and ``SINGLE_ATTENTION_VECTOR``, the Keras layers
    ``Permute``, ``Reshape``, ``Dense``, ``Lambda`` and ``RepeatVector``,
    the ``merge`` function and the backend alias ``K``.

    Args:
        inputs: Tensor of shape ``(batch_size, time_steps, input_dim)``.
            ``time_steps`` has to match ``TIME_STEPS``, otherwise the
            ``Reshape`` below cannot be satisfied.

    Returns:
        The element-wise product of ``inputs`` and the attention weights
        (the layer named ``'attention_mul'``).
    """
    # inputs.shape = (batch_size, time_steps, input_dim)
    input_dim = int(inputs.shape[2])

    # Move the time axis last so the dense layer below acts across time:
    # (batch_size, input_dim, time_steps).
    a = Permute((2, 1))(inputs)
    # This line is not useful. It's just to know which dimension is what.
    a = Reshape((input_dim, TIME_STEPS))(a)
    # One softmax distribution over the time steps for each feature row.
    a = Dense(TIME_STEPS, activation='softmax')(a)

    if SINGLE_ATTENTION_VECTOR:
        # Average the per-feature weights into one vector over time, then
        # copy it back to every feature so the shape stays
        # (batch_size, input_dim, TIME_STEPS). RepeatVector needs the 2-D
        # output of the mean, which is why it belongs inside this branch.
        a = Lambda(lambda x: K.mean(x, axis=1), name='dim_reduction')(a)
        a = RepeatVector(input_dim)(a)

    # Back to (batch_size, time_steps, input_dim) to line up with `inputs`.
    a_probs = Permute((2, 1), name='attention_vec')(a)

    # NOTE(review): `merge(..., mode='mul')` looks like the Keras 1.x
    # functional merge API; newer Keras releases provide this operation as
    # `Multiply()` / `multiply` instead -- confirm against the Keras version
    # this file targets before upgrading.
    output_attention_mul = merge([inputs, a_probs], name='attention_mul', mode='mul')
    return output_attention_mul
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment