@JanSchm
Created June 25, 2020 13:24
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Layer


class SingleAttention(Layer):
    """Single-head scaled dot-product attention."""

    def __init__(self, d_k, d_v):
        super(SingleAttention, self).__init__()
        self.d_k = d_k
        self.d_v = d_v

    def build(self, input_shape):
        # Linear projections for queries, keys, and values
        self.query = Dense(self.d_k, kernel_initializer='glorot_uniform', bias_initializer='glorot_uniform')
        self.key = Dense(self.d_k, kernel_initializer='glorot_uniform', bias_initializer='glorot_uniform')
        self.value = Dense(self.d_v, kernel_initializer='glorot_uniform', bias_initializer='glorot_uniform')

    def call(self, inputs):  # inputs = (in_seq, in_seq, in_seq)
        q = self.query(inputs[0])
        k = self.key(inputs[1])

        # Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V
        attn_weights = tf.matmul(q, k, transpose_b=True)
        attn_weights = attn_weights / np.sqrt(self.d_k)
        attn_weights = tf.nn.softmax(attn_weights, axis=-1)

        v = self.value(inputs[2])
        attn_out = tf.matmul(attn_weights, v)
        return attn_out
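
For context, here is a minimal usage sketch of the layer above. The batch size, sequence length, feature count, and d_k/d_v values are illustrative assumptions, not part of the original gist.

# Minimal usage sketch (illustrative shapes, not from the original gist)
batch_size, seq_len, n_features = 32, 128, 5       # hypothetical dimensions
in_seq = tf.random.normal((batch_size, seq_len, n_features))

attn = SingleAttention(d_k=64, d_v=64)              # illustrative d_k / d_v
out = attn((in_seq, in_seq, in_seq))                # self-attention: Q, K, V from the same sequence
print(out.shape)                                    # (32, 128, 64): last dim equals d_v

Passing the same tensor three times makes this self-attention; feeding different tensors for queries, keys, and values would turn it into cross-attention without changing the layer.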