Tabular Attention

A small Keras + TensorFlow Addons model that embeds each tabular feature as a token, runs self-attention over the feature tokens, and pools them with learned attention queries. A usage sketch follows the layer definitions.
import typing as tp

import tensorflow as tf
import tensorflow_addons as tfa


def get_model(params) -> tf.keras.Model:
    # `params` is currently unused
    x0 = tf.keras.Input(shape=(1,), name="x0")
    x1 = tf.keras.Input(shape=(1,), name="x1")

    inputs = [x0, x1]

    # per-feature embeddings: project each scalar input to a 10-d token
    x0 = tf.keras.layers.Dense(10, activation="relu")(x0)
    x0 = x0[:, None, :]  # (batch, 1, 10)

    x1 = tf.keras.layers.Dense(10, activation="relu")(x1)
    x1 = x1[:, None, :]  # (batch, 1, 10)

    # stack the two feature tokens into a length-2 sequence
    x = tf.concat([x0, x1], axis=1)  # (batch, 2, 10)

    x = AddPositionalEmbeddings()(x)
    x = SelfAttentionBlock(10, head_size=10, num_heads=8)(x)
    x = SelfAttentionBlock(10, head_size=10, num_heads=8)(x)
    x = AttentionPooling(10, n_queries=1, head_size=5, num_heads=8)(x)
    x = x[:, 0]  # drop the query axis: (batch, 10)
    # x = tf.keras.layers.GlobalAveragePooling1D()(x)

    x = tf.keras.layers.Dense(1, activation="sigmoid", name="y")(x)

    model = tf.keras.Model(inputs=inputs, outputs=x, name="tabular_attention")

    return model
class AddPositionalEmbeddings(tf.keras.layers.Layer):
    """Adds a learned positional embedding to each sequence position."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.embeddings: tp.Optional[tf.Variable] = None

    def build(self, input_shape):
        input_shape = list(input_shape)
        # one embedding vector per position, broadcast over the batch axis
        self.embeddings = self.add_weight(
            name="positional_embeddings", shape=[1] + input_shape[1:]
        )
        super().build(input_shape)

    def call(self, inputs):
        return inputs + self.embeddings
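
# A hypothetical shape check, not part of the original gist: the layer stores one
# trainable vector per position and broadcasts it over the batch. The
# (batch=4, positions=2, features=10) shape below is an illustrative assumption.
pos = AddPositionalEmbeddings()
out = pos(tf.zeros([4, 2, 10]))
assert out.shape == (4, 2, 10)  # same shape, one embedding added per position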
class SelfAttentionBlock(tf.keras.layers.Layer):
    """Multi-head self-attention followed by a position-wise dense layer."""

    def __init__(
        self,
        output_size: int,
        head_size: int = 16,
        num_heads: int = 3,
        dropout: float = 0.0,
        activation: tp.Union[str, tp.Callable] = "relu",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.mha = tfa.layers.MultiHeadAttention(
            head_size=head_size, num_heads=num_heads, dropout=dropout
        )
        self.dense = tf.keras.layers.Dense(output_size, activation=activation)

    def call(self, inputs):
        # self-attention: the sequence attends to itself ([query, value])
        x = self.mha([inputs, inputs])
        x = self.dense(x)
        return x
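
# Sketch of the expected shapes, assuming tfa's default output_size=None, where
# MultiHeadAttention projects back to the query's feature size. The input shape
# is illustrative, matching the length-2 token sequence built in get_model.
block = SelfAttentionBlock(output_size=10, head_size=10, num_heads=8)
out = block(tf.zeros([4, 2, 10]))
assert out.shape == (4, 2, 10)  # dense(output_size) sets the last dimension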
class AttentionPooling(tf.keras.layers.Layer):
    """Pools a sequence into `n_queries` vectors via learned attention queries."""

    def __init__(
        self,
        output_size: int,
        n_queries: int,
        head_size: int = 16,
        num_heads: int = 3,
        dropout: float = 0.0,
        activation: tp.Union[str, tp.Callable] = "relu",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.n_queries = n_queries
        self.mha = tfa.layers.MultiHeadAttention(
            head_size=head_size, num_heads=num_heads, dropout=dropout
        )
        self.dense = tf.keras.layers.Dense(output_size, activation=activation)
        self.query: tp.Optional[tf.Variable] = None

    def build(self, input_shape):
        num_features = input_shape[-1]
        self.query = self.add_weight(
            name="query", shape=[1, self.n_queries, num_features]
        )
        super().build(input_shape)

    def call(self, inputs):
        # replicate the learned queries across the batch dimension
        query = tf.tile(
            self.query, [tf.shape(inputs)[0]] + [1] * (len(inputs.shape) - 1)
        )
        # cross-attention: the learned queries attend over the input sequence
        x = self.mha([query, inputs])
        x = self.dense(x)
        return x
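
# A minimal end-to-end sketch, not part of the original gist. `params` is unused
# by get_model, so None is passed as a placeholder; the random batch is purely
# illustrative, keyed by the Input layer names.
model = get_model(None)
model.compile(optimizer="adam", loss="binary_crossentropy")

batch = {
    "x0": tf.random.normal([32, 1]),
    "x1": tf.random.normal([32, 1]),
}
y = model(batch)
assert y.shape == (32, 1)  # sigmoid probabilities from the "y" head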