# Linear scoring head: produces one unnormalized relevance score per document.
scores = tf.keras.layers.Dense(units=1, activation='linear')
scores_out = scores(dense_1_out)
print(scores_out)
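These cells reference `dense_1_out` and `relevance_grades`, which come from earlier steps of the walkthrough that are not shown in this section. A minimal, hypothetical stand-in (the shapes and values here are assumptions for illustration, not the original data) that makes the cells runnable on their own could look like this:

import tensorflow as tf

# Hypothetical stand-ins for tensors produced earlier in the walkthrough:
# a single query with 5 candidate documents.
relevance_grades = tf.constant([[3.0, 2.0, 1.0, 1.0, 0.0]])  # graded relevance labels, shape [1, 5]
dense_1_out = tf.random.normal([1, 5, 3])                    # hidden activations, shape [1, 5, 3]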
# Drop the trailing size-1 axis so each query has a flat vector of scores,
# then turn those scores into a probability distribution over the documents.
scores_for_softmax = tf.squeeze(scores_out, axis=-1)
scores_prob_dist = tf.nn.softmax(scores_for_softmax, axis=-1)
print(scores_prob_dist)
# Turn the graded relevance labels into the target probability distribution.
relevance_grades_prob_dist = tf.nn.softmax(relevance_grades, axis=-1)
print(relevance_grades_prob_dist)
# KL divergence between the target distribution (y_true) and the predicted
# score distribution (y_pred); the default reduction averages over the batch.
loss = tf.keras.losses.KLDivergence()
batch_loss = loss(relevance_grades_prob_dist, scores_prob_dist)
print(batch_loss)
# The same loss computed by hand: per-query KL divergence,
# sum_i p_true(i) * log(p_true(i) / p_pred(i)).
per_example_loss = tf.reduce_sum(
    relevance_grades_prob_dist * tf.math.log(relevance_grades_prob_dist / scores_prob_dist),
    axis=-1
)
print(per_example_loss)
# Averaging the per-query losses reproduces the value returned by KLDivergence.
batch_loss = tf.reduce_mean(per_example_loss)
print(batch_loss)
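To convince yourself that the manual computation and the built-in loss agree, a quick check (assuming the tensors defined above) is:

import numpy as np

# Both paths should produce the same scalar, up to floating-point tolerance.
kl_builtin = tf.keras.losses.KLDivergence()(relevance_grades_prob_dist, scores_prob_dist)
kl_manual = tf.reduce_mean(per_example_loss)
print(np.isclose(kl_builtin.numpy(), kl_manual.numpy()))  # expected: True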
NUM_DOCS_PER_QUERY = 5
EMBEDDING_DIMS = 2
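The model below is trained on query embeddings, per-document averaged embeddings, and the target relevance distribution. Those arrays are built earlier in the walkthrough; a purely illustrative stand-in with matching shapes (random values, assumed batch size of 4) could be:

BATCH_SIZE = 4  # assumed for illustration

query_embeddings = tf.random.normal([BATCH_SIZE, 1, EMBEDDING_DIMS])
docs_averaged_embeddings = tf.random.normal([BATCH_SIZE, NUM_DOCS_PER_QUERY, EMBEDDING_DIMS])
relevance_grades = tf.random.uniform([BATCH_SIZE, NUM_DOCS_PER_QUERY], maxval=4, dtype=tf.float32)
relevance_grades_prob_dist = tf.nn.softmax(relevance_grades, axis=-1)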
class ExpandBatchLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(ExpandBatchLayer, self).__init__(**kwargs)

    def call(self, input):
        queries, docs = input
        # docs: [batch, num_docs, embedding_dims]
        batch, num_docs, embedding_dims = tf.unstack(tf.shape(docs))
        # Repeat the single query vector once per document ...
        expanded_queries = tf.gather(queries, tf.zeros([num_docs], tf.int32), axis=1)
        # ... and concatenate it with each document embedding.
        return tf.concat([expanded_queries, docs], axis=-1)
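A tiny usage example (with made-up numbers) shows what the layer does: the single query embedding is tiled across the document axis and concatenated onto each document embedding.

queries = tf.constant([[[1.0, 2.0]]])                       # [batch=1, 1, 2]
docs = tf.constant([[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]])  # [batch=1, 3, 2]
combined = ExpandBatchLayer()([queries, docs])
print(combined.shape)  # (1, 3, 4): query features prepended to every document
print(combined)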
# Inputs: one query embedding and a fixed-size list of document embeddings per example.
query_input = tf.keras.layers.Input(shape=(1, EMBEDDING_DIMS, ), dtype=tf.float32, name='query')
docs_input = tf.keras.layers.Input(shape=(NUM_DOCS_PER_QUERY, EMBEDDING_DIMS, ), dtype=tf.float32,
                                   name='docs')
expand_batch = ExpandBatchLayer(name='expand_batch')
dense_1 = tf.keras.layers.Dense(units=3, activation='linear', name='dense_1')
dense_out = tf.keras.layers.Dense(units=1, activation='linear', name='scores')
scores_prob_dist = tf.keras.layers.Dense(units=NUM_DOCS_PER_QUERY, activation='softmax',
                                         name='scores_prob_dist')
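The cell above only defines the layers; the wiring and compilation of the `model` used by `model.fit` below are not shown in this section. A minimal sketch of how those layers might be connected with the Keras functional API (the exact wiring here is an assumption, including the `Flatten` used to collapse the per-document score axis) is:

expand_batch_out = expand_batch([query_input, docs_input])
dense_1_out = dense_1(expand_batch_out)
scores_out = dense_out(dense_1_out)                  # [batch, NUM_DOCS_PER_QUERY, 1]
flat_scores = tf.keras.layers.Flatten()(scores_out)  # [batch, NUM_DOCS_PER_QUERY]
model_out = scores_prob_dist(flat_scores)            # softmax distribution over documents

model = tf.keras.Model(inputs=[query_input, docs_input], outputs=model_out)
model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.KLDivergence())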
hist = model.fit(
    [query_embeddings, docs_averaged_embeddings],
    relevance_grades_prob_dist,
    epochs=50,
    verbose=False
)
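After training, the model can rank documents by predicting the probability distribution for a query and sorting its documents by those probabilities. A short, hypothetical usage sketch (assuming the stand-in data above):

predicted_dist = model.predict([query_embeddings, docs_averaged_embeddings])
ranked_doc_indices = tf.argsort(predicted_dist, axis=-1, direction='DESCENDING')
print(ranked_doc_indices)  # per query, document indices ordered from most to least relevant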