Tabular Attention
import typing as tp

import tensorflow as tf
import tensorflow_addons as tfa


def get_model(params) -> tf.keras.Model:
    x0 = tf.keras.Input(shape=(1,), name="x0")
    x1 = tf.keras.Input(shape=(1,), name="x1")
    inputs = [x0, x1]

    # per-column embeddings: project each scalar feature to a 10-dim vector
    x0 = tf.keras.layers.Dense(10, activation="relu")(x0)
    x0 = x0[:, None, :]

    x1 = tf.keras.layers.Dense(10, activation="relu")(x1)
    x1 = x1[:, None, :]

    # stack the column embeddings into a length-2 "sequence" of feature tokens
    x = tf.concat([x0, x1], axis=1)

    x = AddPositionalEmbeddings()(x)
    x = SelfAttentionBlock(10, head_size=10, num_heads=8)(x)
    x = SelfAttentionBlock(10, head_size=10, num_heads=8)(x)
    x = AttentionPooling(10, n_queries=1, head_size=5, num_heads=8)(x)
    x = x[:, 0]
    # x = tf.keras.layers.GlobalAveragePooling1D()(x)

    x = tf.keras.layers.Dense(1, activation="sigmoid", name="y")(x)

    model = tf.keras.Model(inputs=inputs, outputs=x, name="tabular_attention")

    return model
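To make the wiring concrete, here is a minimal usage sketch that is not part of the original gist: it assumes TensorFlow 2.x with tensorflow_addons installed, and the dummy data, optimizer, and loss are illustrative choices (`params` is passed as None because get_model does not use it).

import numpy as np

model = get_model(params=None)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# two scalar feature columns, fed by Keras input name; values are random placeholders
data = {
    "x0": np.random.randn(256, 1).astype("float32"),
    "x1": np.random.randn(256, 1).astype("float32"),
}
labels = np.random.randint(0, 2, size=(256, 1)).astype("float32")

model.fit(data, labels, epochs=1, batch_size=32)
probs = model.predict(data)  # shape (256, 1), sigmoid probabilities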
class AddPositionalEmbeddings(tf.keras.layers.Layer):
    """Adds a learned positional embedding to every position of the input sequence."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.embeddings: tp.Optional[tf.Variable] = None

    def build(self, input_shape):
        input_shape = list(input_shape)
        # one embedding per position, shape [1, sequence_length, features],
        # broadcast over the batch dimension in call()
        self.embeddings = self.add_weight(
            name="key_kernel", shape=[1] + input_shape[1:]
        )
        super().build(input_shape)

    def call(self, inputs):
        return inputs + self.embeddings
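A quick shape check (an assumed example, not from the gist): the layer builds its weight lazily in build() and the addition broadcasts it over the batch.

x = tf.zeros([4, 2, 10])            # batch of 4, two feature tokens of size 10
out = AddPositionalEmbeddings()(x)  # same shape as the input
assert out.shape == (4, 2, 10)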
class SelfAttentionBlock(tf.keras.layers.Layer):
    """Multi-head self-attention followed by a position-wise dense projection."""

    def __init__(
        self,
        output_size: int,
        head_size: int = 16,
        num_heads: int = 3,
        dropout: float = 0.0,
        activation: tp.Union[str, tp.Callable] = "relu",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.mha = tfa.layers.MultiHeadAttention(
            head_size=head_size, num_heads=num_heads, dropout=dropout
        )
        self.dense = tf.keras.layers.Dense(output_size, activation=activation)

    def call(self, inputs):
        # query and key/value are the same tensor, i.e. self-attention
        x = self.mha([inputs, inputs])
        x = self.dense(x)
        return x
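Unlike a standard Transformer encoder block, this block applies the dense projection directly to the attention output, with no residual connection or layer normalization. A shape sketch under the same assumptions as above (not from the gist):

x = tf.zeros([4, 2, 10])
out = SelfAttentionBlock(10, head_size=10, num_heads=8)(x)
# each of the 2 feature tokens attends to both tokens; the final Dense projects
# every position to output_size, so out.shape == (4, 2, 10)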
class AttentionPooling(tf.keras.layers.Layer):
    """Pools a sequence into `n_queries` vectors via attention with learned queries."""

    def __init__(
        self,
        output_size: int,
        n_queries: int,
        head_size: int = 16,
        num_heads: int = 3,
        dropout: float = 0.0,
        activation: tp.Union[str, tp.Callable] = "relu",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.n_queries = n_queries
        self.mha = tfa.layers.MultiHeadAttention(
            head_size=head_size, num_heads=num_heads, dropout=dropout
        )
        self.dense = tf.keras.layers.Dense(output_size, activation=activation)
        self.query: tp.Optional[tf.Variable] = None

    def build(self, input_shape):
        num_features = input_shape[-1]
        # learned query vectors, shape [1, n_queries, features], shared across the batch
        self.query = self.add_weight(
            name="key_kernel", shape=[1, self.n_queries, num_features]
        )
        super().build(input_shape)

    def call(self, inputs):
        # tile the learned queries over the batch dimension
        query = tf.tile(
            self.query, [tf.shape(inputs)[0]] + [1] * (len(inputs.shape) - 1)
        )
        # attend from the queries over the full input sequence
        x = self.mha([query, inputs])
        x = self.dense(x)
        return x
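A pooling sketch under the same assumptions (not from the gist): the learned queries attend over the whole sequence, so the output always has n_queries positions regardless of the input length. With n_queries=1, taking x[:, 0] afterwards, as get_model does, yields one fixed-size vector per example.

x = tf.zeros([4, 2, 10])
pooled = AttentionPooling(10, n_queries=1, head_size=5, num_heads=8)(x)  # (4, 1, 10)
vector = pooled[:, 0]  # (4, 10), fed to the final sigmoid Dense in get_model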