This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Size of the simulated control group.
NUM_CONTROL_USERS = 6  # total number of users in control group

# One float slot per control-group user, all initialised to 0.0
# (np.zeros defaults to float64). Requires numpy imported as `np` beforehand.
control_users = np.zeros(NUM_CONTROL_USERS)
#> array([0., 0., 0., 0., 0., 0.])
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Fixed-seed NumPy Generator so that every random draw in the simulation is
# reproducible from run to run.
rng = np.random.default_rng(1337)  # random number generator using the one true seed
# Apply seaborn's 'white' figure style to all subsequent plots.
sns.set_style('white')
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import numpy as np | |
from statsmodels.stats.proportion import score_test_proportions_2indep | |
from scipy import stats | |
from numba import njit, prange | |
import matplotlib.pyplot as plt | |
import matplotlib.ticker as ticker | |
import seaborn as sns |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
https://upload.wikimedia.org/wikipedia/commons/c/c5/%21%21%21_Mdina_buildings_05.jpg
https://upload.wikimedia.org/wikipedia/commons/d/d5/%21-2011-debowa-leka-palac-abri.jpg
https://upload.wikimedia.org/wikipedia/commons/c/c2/%21-20100523-szlichtyngowa-wiatrak-abri.jpg
https://upload.wikimedia.org/wikipedia/commons/c/ca/%21Illegal_street_ads_on_a_bicycle.jpg
https://upload.wikimedia.org/wikipedia/commons/f/fc/%21-2010-hetmanice-kosciol-abri.jpg
https://upload.wikimedia.org/wikipedia/commons/2/27/%21-2011-debowa-leka-kosciol-abri.jpg
https://upload.wikimedia.org/wikipedia/commons/0/0b/%22%2Barya%2B%22_penjual_jamu_tradisional_-_%EA%A6%A2%EA%A6%BA%EA%A6%B4%EA%A6%AD%EA%A7%80%EA%A6%A2%EA%A6%BA%EA%A6%B4%EA%A6%AD%EA%A6%A4%EA%A7%80_%EA%A6%97%EA%A6%A9%EA%A6%B8_%EA%A6%A0%EA%A6%BF%EA%A6%A2%EA%A6%B6%EA%A6%B1%EA%A6%B6%EA%A6%AA%EA%A6%BA%EA%A6%B4%EA%A6%A4%EA%A6%AD%EA%A7%80_Pilangsari_2019.jpg
https://upload.wikimedia.org/wikipedia/commons/c/cd/%22..._Although_a_dress_uniform_is_not_a_part_of_the_regular_equipment%2C_most_of
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Fit the model on two inputs (query embeddings and docs_averaged_embeddings)
# against relevance_grades_prob_dist as the target, for 50 epochs.
# Keras returns a History object, kept in `hist` for later inspection.
hist = model.fit(
    [query_embeddings, docs_averaged_embeddings],
    relevance_grades_prob_dist,
    epochs=50,
    # Keras documents `verbose` as 0, 1, 2 or "auto"; 0 means silent
    # (the same effect the previous `False` had, since False == 0).
    verbose=0,
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Model inputs. Keras `Input(shape=...)` excludes the batch dimension, so each
# example carries one (1, EMBEDDING_DIMS) query block and a
# (NUM_DOCS_PER_QUERY, EMBEDDING_DIMS) block of documents.
# Requires EMBEDDING_DIMS, NUM_DOCS_PER_QUERY and ExpandBatchLayer to be
# defined before this point.
query_input = tf.keras.layers.Input(
    shape=(1, EMBEDDING_DIMS), dtype=tf.float32, name='query'
)
docs_input = tf.keras.layers.Input(
    shape=(NUM_DOCS_PER_QUERY, EMBEDDING_DIMS), dtype=tf.float32, name='docs'
)

# Layer instances for the scoring network. They are only constructed here;
# nothing in this block connects them to the inputs yet.
expand_batch = ExpandBatchLayer(name='expand_batch')
dense_1 = tf.keras.layers.Dense(units=3, activation='linear', name='dense_1')
dense_out = tf.keras.layers.Dense(units=1, activation='linear', name='scores')
# Softmax over NUM_DOCS_PER_QUERY units, so its output sums to 1 per example.
scores_prob_dist = tf.keras.layers.Dense(
    units=NUM_DOCS_PER_QUERY, activation='softmax', name='scores_prob_dist'
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ExpandBatchLayer(tf.keras.layers.Layer):
    """Concatenate a copy of the query embedding onto every document embedding.

    ``call`` receives a ``(queries, docs)`` pair. The slice at index 0 of
    ``queries`` along axis 1 is repeated once per document (the size of axis 1
    of ``docs``). It is then concatenated with ``docs`` along the last axis.

    Assumes ``queries`` is (batch, 1, q_dims) and ``docs`` is
    (batch, num_docs, d_dims) — TODO confirm against callers; under that
    assumption the output is (batch, num_docs, q_dims + d_dims).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, input):
        # NOTE: the parameter keeps the name `input` (shadowing the builtin only
        # inside this method) so keyword callers of `call(input=...)` still work.
        queries, docs = input
        # Only the document count is needed to size the repetition.
        num_docs = tf.shape(docs)[1]
        # Gathering index 0 `num_docs` times along axis 1 repeats the first
        # query slice so each document position gets its own copy.
        expanded_queries = tf.gather(queries, tf.zeros([num_docs], tf.int32), axis=1)
        return tf.concat([expanded_queries, docs], axis=-1)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Number of documents per query.
NUM_DOCS_PER_QUERY = 5
# Length of each embedding vector.
EMBEDDING_DIMS = 2
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Reduce the per-example losses to one scalar: their mean over all elements.
# Requires `per_example_loss` to be computed before this point.
batch_loss = tf.reduce_mean(per_example_loss)
print(batch_loss)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Per-example KL divergence KL(P || Q) = sum_i p_i * log(p_i / q_i), summed over
# the last axis, with P = relevance_grades_prob_dist and Q = scores_prob_dist.
#
# tf.math.xlogy(x, y) returns 0 wherever x == 0 and x * log(y) otherwise. Entries
# with zero target probability therefore contribute 0 (the 0 * log 0 = 0
# convention) instead of the NaN that `p * tf.math.log(p / q)` produces there.
# All other entries are unchanged.
per_example_loss = tf.reduce_sum(
    tf.math.xlogy(
        relevance_grades_prob_dist,
        relevance_grades_prob_dist / scores_prob_dist,
    ),
    axis=-1,
)
print(per_example_loss)