import tensorflow as tf

# VOCAB_SIZE is assumed to be defined earlier (the number of words in the vocabulary)

# inputs: one integer word index each for the content word and the context word
content_input = tf.keras.layers.Input(shape=(1, ), dtype=tf.int32, name='content_word')
context_input = tf.keras.layers.Input(shape=(1, ), dtype=tf.int32, name='context_word')
# layers: a single embedding table shared by both inputs, and a dot product layer
embeddings = tf.keras.layers.Embedding(input_dim=VOCAB_SIZE, output_dim=2, name='embeddings')
# normalize=True L2-normalises the vectors first, so the dot product is a cosine similarity
dot_prod = tf.keras.layers.Dot(axes=2, normalize=True, name='dot_product')
# graph: embed both words, take their cosine similarity, then map it to a probability
content_embedding = embeddings(content_input)
context_embedding = embeddings(context_input)
cosine_sim = tf.keras.layers.Flatten(name='flatten')(dot_prod([content_embedding, context_embedding]))
dense_out = tf.keras.layers.Dense(1, activation='sigmoid', name='sigmoid_out')(cosine_sim)
# model
model = tf.keras.models.Model(inputs=[content_input, context_input], outputs=[dense_out])

# SGD with a small per-update learning rate decay
DECAY_RATE = 5e-6
LR = 0.1
optimiser = tf.keras.optimizers.SGD(learning_rate=LR, decay=DECAY_RATE)
model.compile(loss='binary_crossentropy', optimizer=optimiser, metrics=['accuracy'])
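
A minimal sketch of how the compiled model might be fitted, assuming skip-gram style training pairs have already been generated elsewhere; the arrays and sizes below are illustrative placeholders, not part of the original gist.

import numpy as np

# hypothetical toy data: integer word indices for each input, plus binary labels
# (1 = word pair observed together, 0 = negative sample)
n_examples = 1000
content_ids = np.random.randint(0, VOCAB_SIZE, size=(n_examples, 1))
context_ids = np.random.randint(0, VOCAB_SIZE, size=(n_examples, 1))
labels = np.random.randint(0, 2, size=(n_examples, 1))

# the two arrays are matched to the model's two inputs by position
model.fit([content_ids, context_ids], labels, batch_size=32, epochs=5)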