NUM_CONTROL_USERS = 6  # total number of users in the control group
control_users = np.zeros(NUM_CONTROL_USERS)
#> array([0., 0., 0., 0., 0., 0.])
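A minimal sketch of how such an array can be used. It assumes each slot is later set to 1 when that user converts; the snippet above only initializes the array, and the converting indices here are hypothetical. Under that assumption the control-group conversion rate is just the array's mean.

```python
import numpy as np

NUM_CONTROL_USERS = 6
control_users = np.zeros(NUM_CONTROL_USERS)  # 0.0 = user has not converted

# Hypothetical outcome, for illustration only: users 1 and 4 converted.
converted_idx = [1, 4]
control_users[converted_idx] = 1.0

# Share of control users who converted (2 out of 6 here).
conversion_rate = control_users.mean()
print(control_users)  # [0. 1. 0. 0. 1. 0.]
print(conversion_rate)
```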
rng = np.random.default_rng(1337)  # random number generator using the one true seed
sns.set_style('white')
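Seeding the generator is what makes simulated experiments repeatable. The sketch below draws one Bernoulli outcome per user from two generators built with the same seed and checks that they agree; the 0.3 conversion rate is a hypothetical value, not one taken from the experiment.

```python
import numpy as np

SEED = 1337
P_CONVERT = 0.3  # hypothetical true conversion rate, for illustration only

rng_a = np.random.default_rng(SEED)
rng_b = np.random.default_rng(SEED)

# One Bernoulli(P_CONVERT) draw per user: 1.0 if the user converted, else 0.0.
outcomes_a = (rng_a.random(6) < P_CONVERT).astype(float)
outcomes_b = (rng_b.random(6) < P_CONVERT).astype(float)

# Same seed -> identical stream of draws, so the simulation is reproducible.
print(np.array_equal(outcomes_a, outcomes_b))  # True
```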
import random
import numpy as np
from statsmodels.stats.proportion import score_test_proportions_2indep
from scipy import stats
from numba import njit, prange
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import seaborn as sns
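The imports bring in statsmodels' `score_test_proportions_2indep` for comparing two conversion rates. To show the idea without statsmodels, here is a sketch of a plain pooled two-proportion z-test. This is a swapped-in, simpler relative of the score test, so its p-values can differ slightly from what statsmodels returns; the counts in the usage line are hypothetical.

```python
import math


def two_proportion_ztest(successes_a, n_a, successes_b, n_b):
    """Pooled two-sided z-test for H0: p_a == p_b. Returns (z, p_value)."""
    p_a = successes_a / n_a
    p_b = successes_b / n_b
    # Under H0 both groups share one rate, estimated from the pooled counts.
    p_pool = (successes_a + successes_b) / (n_a + n_b)
    se = math.sqrt(p_pool * (1 - p_pool) * (1 / n_a + 1 / n_b))
    if se == 0:
        raise ValueError("pooled rate is 0 or 1; the test is undefined")
    z = (p_b - p_a) / se
    # Two-sided p-value: 2 * (1 - Phi(|z|)) == erfc(|z| / sqrt(2)).
    p_value = math.erfc(abs(z) / math.sqrt(2))
    return z, p_value


# Hypothetical counts: 30/200 conversions in control, 45/200 in treatment.
z, p = two_proportion_ztest(30, 200, 45, 200)
print(z, p)
```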
https://upload.wikimedia.org/wikipedia/commons/c/c5/%21%21%21_Mdina_buildings_05.jpg
https://upload.wikimedia.org/wikipedia/commons/d/d5/%21-2011-debowa-leka-palac-abri.jpg
https://upload.wikimedia.org/wikipedia/commons/c/c2/%21-20100523-szlichtyngowa-wiatrak-abri.jpg
https://upload.wikimedia.org/wikipedia/commons/c/ca/%21Illegal_street_ads_on_a_bicycle.jpg
https://upload.wikimedia.org/wikipedia/commons/f/fc/%21-2010-hetmanice-kosciol-abri.jpg
https://upload.wikimedia.org/wikipedia/commons/2/27/%21-2011-debowa-leka-kosciol-abri.jpg
https://upload.wikimedia.org/wikipedia/commons/0/0b/%22%2Barya%2B%22_penjual_jamu_tradisional_-_%EA%A6%A2%EA%A6%BA%EA%A6%B4%EA%A6%AD%EA%A7%80%EA%A6%A2%EA%A6%BA%EA%A6%B4%EA%A6%AD%EA%A6%A4%EA%A7%80_%EA%A6%97%EA%A6%A9%EA%A6%B8_%EA%A6%A0%EA%A6%BF%EA%A6%A2%EA%A6%B6%EA%A6%B1%EA%A6%B6%EA%A6%AA%EA%A6%BA%EA%A6%B4%EA%A6%A4%EA%A6%AD%EA%A7%80_Pilangsari_2019.jpg
https://upload.wikimedia.org/wikipedia/commons/c/cd/%22..._Although_a_dress_uniform_is_not_a_part_of_the_regular_equipment%2C_most_of
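The file names in these URLs are percent-encoded (`%21` is `!`, `%22` is `"`, and multi-byte sequences such as `%EA%A6%A2` are UTF-8 characters). A small standard-library sketch for recovering the readable name from one of the URLs above; `filename_from_url` is a hypothetical helper, not part of the original code.

```python
import posixpath
from urllib.parse import unquote, urlsplit


def filename_from_url(url: str) -> str:
    """Return the decoded file name at the end of a URL's path."""
    path = urlsplit(url).path  # drop scheme, host, query and fragment
    return unquote(posixpath.basename(path))  # '%21' -> '!', UTF-8 decoded


url = "https://upload.wikimedia.org/wikipedia/commons/c/c5/%21%21%21_Mdina_buildings_05.jpg"
print(filename_from_url(url))  # !!!_Mdina_buildings_05.jpg
```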
hist = model.fit(
    [query_embeddings, docs_averaged_embeddings],
    relevance_grades_prob_dist,
    epochs=50,
    verbose=False
)
# Inputs: one query embedding and NUM_DOCS_PER_QUERY document embeddings per example
query_input = tf.keras.layers.Input(shape=(1, EMBEDDING_DIMS), dtype=tf.float32, name='query')
docs_input = tf.keras.layers.Input(shape=(NUM_DOCS_PER_QUERY, EMBEDDING_DIMS), dtype=tf.float32,
                                   name='docs')
# Layers: pair the query with each document, score each pair,
# then map the scores to a probability distribution over the documents
expand_batch = ExpandBatchLayer(name='expand_batch')
dense_1 = tf.keras.layers.Dense(units=3, activation='linear', name='dense_1')
dense_out = tf.keras.layers.Dense(units=1, activation='linear', name='scores')
scores_prob_dist = tf.keras.layers.Dense(units=NUM_DOCS_PER_QUERY, activation='softmax',
                                         name='scores_prob_dist')
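The layers above are only declared; the code that wires them into a model is not shown here. Assuming the straightforward wiring (expand, `dense_1`, `scores`, flatten the per-document scores, then `scores_prob_dist`), the NumPy sketch below traces the tensor shapes through that pipeline. Random weights stand in for trained ones, and `dense` and `BATCH` are hypothetical names used only for illustration.

```python
import numpy as np

NUM_DOCS_PER_QUERY = 5
EMBEDDING_DIMS = 2
BATCH = 4  # hypothetical batch size

rng = np.random.default_rng(0)


def dense(x, units, activation=None):
    """Stand-in for a Dense layer with freshly drawn random weights."""
    w = rng.normal(size=(x.shape[-1], units))
    y = x @ w  # matmul over the last axis; leading axes are kept
    if activation == "softmax":
        y = np.exp(y - y.max(axis=-1, keepdims=True))
        y = y / y.sum(axis=-1, keepdims=True)
    return y


query = rng.normal(size=(BATCH, 1, EMBEDDING_DIMS))
docs = rng.normal(size=(BATCH, NUM_DOCS_PER_QUERY, EMBEDDING_DIMS))

# expand_batch: put the query next to every document -> (B, N, 2E)
expanded = np.concatenate([np.repeat(query, NUM_DOCS_PER_QUERY, axis=1), docs], axis=-1)
hidden = dense(expanded, 3)     # dense_1 -> (B, N, 3)
scores = dense(hidden, 1)       # scores  -> (B, N, 1)
flat = scores.squeeze(-1)       # assumed flatten -> (B, N)
probs = dense(flat, NUM_DOCS_PER_QUERY, activation="softmax")  # -> (B, N)

print(expanded.shape, hidden.shape, scores.shape, probs.shape)
# (4, 5, 4) (4, 5, 3) (4, 5, 1) (4, 5)
```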
import tensorflow as tf

class ExpandBatchLayer(tf.keras.layers.Layer):
    """Concatenates the single query embedding onto every document embedding."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        queries, docs = inputs  # (batch, 1, E) and (batch, num_docs, E)
        num_docs = tf.shape(docs)[1]
        # Gather index 0 of axis 1 num_docs times: (batch, 1, E) -> (batch, num_docs, E)
        expanded_queries = tf.gather(queries, tf.zeros([num_docs], tf.int32), axis=1)
        # Result: (batch, num_docs, 2 * E)
        return tf.concat([expanded_queries, docs], axis=-1)
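To check what `ExpandBatchLayer` does without TensorFlow, here is a NumPy analogue. It is a sketch in which `np.take` stands in for `tf.gather`; indexing axis 1 with a vector of zeros copies the single query once per document before the concatenation.

```python
import numpy as np


def expand_batch(queries, docs):
    """NumPy analogue of ExpandBatchLayer.call.

    queries: (batch, 1, E); docs: (batch, num_docs, E) -> (batch, num_docs, 2E)
    """
    num_docs = docs.shape[1]
    # Taking index 0 of axis 1 num_docs times repeats the query per document.
    expanded_queries = np.take(queries, np.zeros(num_docs, dtype=int), axis=1)
    return np.concatenate([expanded_queries, docs], axis=-1)


queries = np.array([[[1.0, 2.0]]])                  # (1, 1, 2)
docs = np.arange(10, dtype=float).reshape(1, 5, 2)  # (1, 5, 2)
out = expand_batch(queries, docs)
print(out.shape)  # (1, 5, 4)
print(out[0, 0])  # [1. 2. 0. 1.]
```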
NUM_DOCS_PER_QUERY = 5
EMBEDDING_DIMS = 2
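The training target `relevance_grades_prob_dist` holds one probability per document for each query, but how it is built from raw relevance grades is not shown in these snippets. One option, assumed here, is a ListNet-style softmax over each query's grades. Unlike dividing by the sum, the softmax keeps every entry strictly positive, which matters once the target sits inside a logarithm in the loss. The grades below are hypothetical.

```python
import numpy as np

NUM_DOCS_PER_QUERY = 5

# Hypothetical graded relevance (0 = irrelevant ... 3 = highly relevant), batch of 1.
relevance_grades = np.array([[3.0, 2.0, 0.0, 0.0, 1.0]])

# Assumed conversion: softmax over the documents of each query (last axis).
exp = np.exp(relevance_grades - relevance_grades.max(axis=-1, keepdims=True))
relevance_grades_prob_dist = exp / exp.sum(axis=-1, keepdims=True)

print(relevance_grades_prob_dist.round(3))
```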
# Average the per-example losses over the batch
batch_loss = tf.reduce_mean(per_example_loss)
print(batch_loss)
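# KL divergence D_KL(P || Q) per example, summed over the documents (last axis).
# A zero entry in relevance_grades_prob_dist would evaluate 0 * log(0) = 0 * -inf, i.e. NaN.
per_example_loss = tf.reduce_sum(
    relevance_grades_prob_dist * tf.math.log(relevance_grades_prob_dist / scores_prob_dist),
    axis=-1
)
print(per_example_loss)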
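The per-example loss is the KL divergence from the predicted distribution to the target, and `batch_loss` averages it over the batch. A NumPy sketch of both steps follows, with hypothetical target and predicted distributions. It adds the usual convention that terms with a zero target probability contribute 0, and clips predictions to a small `eps` (an assumed value), so it stays finite where the raw TensorFlow expression would return NaN. Keras' built-in `tf.keras.losses.KLDivergence` instead clips both inputs.

```python
import numpy as np


def kl_per_example(p, q, eps=1e-12):
    """Row-wise KL(p || q) = sum_i p_i * log(p_i / q_i).

    Terms with p_i == 0 contribute 0 by convention; q is clipped to eps
    so that a zero prediction cannot produce log(0).
    """
    p = np.asarray(p, dtype=float)
    q = np.clip(np.asarray(q, dtype=float), eps, 1.0)
    safe_p = np.where(p > 0, p, 1.0)  # placeholder where p == 0; masked out below
    terms = np.where(p > 0, p * np.log(safe_p / q), 0.0)
    return terms.sum(axis=-1)


# Hypothetical batch of two queries with three documents each.
target = np.array([[0.5, 0.3, 0.2], [0.6, 0.4, 0.0]])
predicted = np.array([[0.5, 0.3, 0.2], [0.3, 0.3, 0.4]])

per_example_loss = kl_per_example(target, predicted)  # shape (2,)
batch_loss = per_example_loss.mean()  # same reduction as tf.reduce_mean
print(per_example_loss, batch_loss)
```

Identical distributions give a loss of 0, so the first row contributes nothing and only the mismatched second row drives the batch loss.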