Created
August 11, 2023 18:44
-
-
Save ydennisy/fec55fab84d107b72852ba2d2c2b61db to your computer and use it in GitHub Desktop.
A Siamese network for text embedding.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from keras_nlp.layers import TransformerDecoder | |
# Hyperparameters shared by the embedding layer and the model inputs below.
MAX_LEN = 128  # tokens per input sequence / size of the position table
VOCAB_SIZE = 128  # number of rows in the token-embedding table
EMBED_DIMS = 32  # width of each token / position vector
class TokenAndPositionEmbedding(tf.keras.layers.Layer):
    """Add a learned position embedding to a learned token embedding.

    Args:
        maxlen: Number of rows in the position-embedding table, i.e. the
            largest sequence length the layer can be called with.
        vocab_size: Number of rows in the token-embedding table.
        embed_dim: Width of both embedding tables.
        **kwargs: Standard ``tf.keras.layers.Layer`` arguments (``name``,
            ``dtype``, ...), forwarded to the base class.
    """

    def __init__(self, maxlen, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments so get_config() can report them.
        self.maxlen = maxlen
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.token_emb = tf.keras.layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.pos_emb = tf.keras.layers.Embedding(
            input_dim=maxlen, output_dim=embed_dim
        )

    def call(self, x):
        """Return token embeddings of ``x`` plus per-position embeddings.

        The position indices are ``0 .. tf.shape(x)[-1] - 1``; their
        embeddings are added to the token embeddings by broadcasting.
        """
        # Use the runtime length of the last axis rather than the configured
        # maximum, so the layer follows whatever length it is fed.
        seq_len = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=seq_len, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update(
            {
                "maxlen": self.maxlen,
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config
def compute_similarity_matrix(embeddings_1, embeddings_2):
    """Return the dot product of every row pair from the two inputs.

    The result is ``embeddings_1 @ transpose(embeddings_2)``: entry
    ``[i, j]`` is the dot product of row ``i`` of ``embeddings_1`` with
    row ``j`` of ``embeddings_2``.
    """
    return tf.matmul(a=embeddings_1, b=embeddings_2, transpose_b=True)
# Cross-entropy over raw (un-normalised) scores, used by the loss below.
loss_function = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)


@tf.function
def multiple_negatives_ranking_loss(y_true, similarity_scores):
    """Cross-entropy loss whose target for row ``i`` is column ``i``.

    Each row of ``similarity_scores`` is treated as logits, and the
    diagonal entry is the class to predict.

    Args:
        y_true: Ignored; accepted only so the function has the
            ``(y_true, y_pred)`` shape Keras passes to a loss.
        similarity_scores: 2-D score matrix used as logits.

    Returns:
        The scalar loss produced by ``loss_function``.
    """
    num_rows = tf.shape(similarity_scores)[0]
    diagonal_targets = tf.range(num_rows)
    return loss_function(diagonal_targets, similarity_scores)
@tf.function
def top_k_accuracy(y_true, similarity_scores, k=1):
    """Fraction of rows whose own index is among the ``k`` best columns.

    Args:
        y_true: Ignored; the expected column for row ``i`` is always ``i``.
        similarity_scores: 2-D score matrix.
        k: How many highest-scoring columns per row count as a hit.

    Returns:
        Scalar float32 tensor: the mean hit rate over the rows.
    """
    num_rows = tf.shape(similarity_scores)[0]
    # Column vector [[0], [1], ...] so it broadcasts against the (rows, k) indices.
    targets = tf.expand_dims(tf.range(num_rows), axis=1)
    best_columns = tf.math.top_k(similarity_scores, k=k).indices
    hit = tf.reduce_any(tf.equal(best_columns, targets), axis=1)
    return tf.reduce_mean(tf.cast(hit, dtype=tf.float32))
@tf.function
def mean_reciprocal_rank(y_true, similarity_scores):
    """Mean of ``1 / rank`` of the diagonal entry within each row.

    Each row is sorted by descending score, and the 1-based position of
    column ``i`` in row ``i`` is that row's rank.

    Args:
        y_true: Ignored; the expected column for row ``i`` is always ``i``.
        similarity_scores: 2-D score matrix.

    Returns:
        Scalar float32 tensor: the reciprocal ranks averaged over the rows.
    """
    num_rows = tf.shape(similarity_scores)[0]
    targets = tf.expand_dims(tf.range(num_rows), axis=1)
    ranking = tf.argsort(similarity_scores, direction="DESCENDING")
    # tf.where yields (row, position) coordinates of the matches; the second
    # coordinate is the zero-based rank of the target column in its row.
    match_coords = tf.where(tf.equal(ranking, targets))
    zero_based_rank = match_coords[:, 1]
    reciprocal = 1 / (tf.cast(zero_based_rank, tf.float32) + 1)
    return tf.reduce_mean(reciprocal)
# --- Shared tower: token ids -> one pooled vector per sequence -------------
decoder = TransformerDecoder(intermediate_dim=8, num_heads=2)
embedding = TokenAndPositionEmbedding(
    maxlen=MAX_LEN, vocab_size=VOCAB_SIZE, embed_dim=EMBED_DIMS
)

inputs = tf.keras.Input(shape=(MAX_LEN,))
x = embedding(inputs)
x = decoder(x)
# Average over the sequence axis to get a single vector per input.
outputs = tf.keras.layers.GlobalAveragePooling1D()(x)
embedding_model = tf.keras.Model(
    inputs=inputs, outputs=outputs, name="embedding_model"
)

# --- Siamese wrapper: both inputs go through the same embedding_model ------
inputs_1 = tf.keras.Input(shape=(MAX_LEN,), name="query_input")
inputs_2 = tf.keras.Input(shape=(MAX_LEN,), name="text_input")
tower_1 = embedding_model(inputs_1)
tower_2 = embedding_model(inputs_2)

# The model's output is the query-by-text score matrix itself; the loss and
# metrics below read their targets from its diagonal.
similarity_matrix = compute_similarity_matrix(tower_1, tower_2)
model = tf.keras.Model(inputs=[inputs_1, inputs_2], outputs=similarity_matrix)

model.compile(
    loss=multiple_negatives_ranking_loss,
    metrics=[top_k_accuracy, mean_reciprocal_rank],
)
model.summary()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment