SingleAttention — a single scaled dot-product attention head as a Keras layer.
import tensorflow as tf
from tensorflow.keras.layers import Dense, Layer

class SingleAttention(Layer):
    """Single scaled dot-product attention head."""
    def __init__(self, d_k, d_v):
        super(SingleAttention, self).__init__()
        self.d_k = d_k  # dimension of query/key projections
        self.d_v = d_v  # dimension of value projection

    def build(self, input_shape):
        # Linear projections for queries, keys, and values
        self.query = Dense(self.d_k, kernel_initializer='glorot_uniform', bias_initializer='glorot_uniform')
        self.key = Dense(self.d_k, kernel_initializer='glorot_uniform', bias_initializer='glorot_uniform')
        self.value = Dense(self.d_v, kernel_initializer='glorot_uniform', bias_initializer='glorot_uniform')

    def call(self, inputs):  # inputs = (query_seq, key_seq, value_seq)
        q = self.query(inputs[0])
        k = self.key(inputs[1])

        # Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V
        attn_weights = tf.matmul(q, k, transpose_b=True)
        attn_weights = attn_weights / tf.math.sqrt(tf.cast(self.d_k, tf.float32))
        attn_weights = tf.nn.softmax(attn_weights, axis=-1)

        v = self.value(inputs[2])
        attn_out = tf.matmul(attn_weights, v)
        return attn_out
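
For reference, a minimal usage sketch (assuming TensorFlow 2.x; the batch size, sequence length, and feature dimension below are illustrative, not from the original gist):

    # Hypothetical shapes: batch of 32 sequences, 128 steps, 64 features each.
    seq = tf.random.normal((32, 128, 64))
    attn = SingleAttention(d_k=64, d_v=64)
    out = attn((seq, seq, seq))  # self-attention: query = key = value
    print(out.shape)             # (32, 128, 64) -> (batch, seq_len, d_v)

Passing the same sequence three times gives self-attention; cross-attention would pass a different sequence as the query input.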