Skip to content

Instantly share code, notes, and snippets.

@gaphex
Last active November 12, 2020 09:39
Show Gist options
  • Select an option

  • Save gaphex/75d56a6a70c310d906bc7998619e1b1d to your computer and use it in GitHub Desktop.

Select an option

Save gaphex/75d56a6a70c310d906bc7998619e1b1d to your computer and use it in GitHub Desktop.
def build_model(module_path, seq_len=24, tune_lr=6, loss=softmax_loss,
                learning_rate=1e-5, verbose=True):
    """Build triplet training, encoding and similarity models on one shared BERT encoder.

    Three string inputs (anchor, positive, negative) are all fed through the
    *same* ``BertLayer`` instance, so every branch shares the encoder weights.

    Args:
        module_path: Passed through to ``BertLayer`` as its first argument
            (the module to load).
        seq_len: Passed through to ``BertLayer`` as its second argument.
        tune_lr: Passed to ``BertLayer`` as ``n_tune_layers``. Despite the
            name this is NOT a learning rate; presumably it is the number of
            encoder layers left trainable — confirm in ``BertLayer``.
        loss: Callable wrapped in a ``Lambda`` layer and applied to
            ``[anchor_enc, positive_enc, negative_enc]``. Its output is the
            training model's only output.
        learning_rate: Adam learning rate for the training model. Defaults
            to ``1e-5``, the value that used to be hard-coded.
        verbose: If true, print ``trn_model.summary()``. Defaults to True,
            matching the previous unconditional behavior.

    Returns:
        dict with keys:
            ``"enc_model"``: anchor string input -> anchor encoding.
            ``"sim_model"``: (anchor, positive) string inputs ->
                ``cosine_similarity`` of their encodings.
            ``"trn_model"``: (anchor, positive, negative) string inputs ->
                ``loss`` output, compiled with Adam and ``mean_loss``.
    """
    inp_anc = tf.keras.Input(shape=(1,), dtype=tf.string)
    inp_pos = tf.keras.Input(shape=(1,), dtype=tf.string)
    inp_neg = tf.keras.Input(shape=(1,), dtype=tf.string)

    # One layer instance is reused on all three inputs, so the weights are shared.
    sent_encoder = BertLayer(module_path, seq_len, n_tune_layers=tune_lr,
                             do_preprocessing=True, verbose=False,
                             pooling="mean", trainable=True,
                             tune_embeddings=False)
    anc_enc = sent_encoder(inp_anc)
    pos_enc = sent_encoder(inp_pos)
    neg_enc = sent_encoder(inp_neg)

    # Keep the loss callable (`loss`) and its output tensor under separate
    # names instead of rebinding the argument.
    loss_out = tf.keras.layers.Lambda(loss)([anc_enc, pos_enc, neg_enc])
    sim = tf.keras.layers.Lambda(cosine_similarity)([anc_enc, pos_enc])

    trn_model = tf.keras.models.Model(
        inputs=[inp_anc, inp_pos, inp_neg], outputs=[loss_out])
    enc_model = tf.keras.models.Model(inputs=inp_anc, outputs=[anc_enc])
    sim_model = tf.keras.models.Model(inputs=[inp_anc, inp_pos], outputs=[sim])

    # The training model's output is already the loss value. `mean_loss`
    # presumably ignores y_true and reduces y_pred — confirm against its
    # definition.
    trn_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=mean_loss,
        metrics=[])
    if verbose:
        trn_model.summary()

    return {
        "enc_model": enc_model,
        "sim_model": sim_model,
        "trn_model": trn_model,
    }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment