@gaphex
Last active December 10, 2019 11:38
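import tensorflow as tf

# BertLayer is a custom Keras layer from the author's accompanying code (not
# included in this snippet); it loads the BERT module found at bert_path.
inp = tf.keras.Input(shape=(1,), dtype=tf.string)
encoder = BertLayer(bert_path="./bert-module/", seq_len=48,
                    tune_embeddings=False, do_preprocessing=True,
                    pooling='cls', n_tune_layers=3, verbose=False)
# Binary classifier head on top of the pooled BERT output.
pred = tf.keras.layers.Dense(1, activation='sigmoid')(encoder(inp))
model = tf.keras.models.Model(inputs=[inp], outputs=[pred])
model.summary()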
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss="binary_crossentropy",
    metrics=["accuracy"])
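# Save the model to bert_tuned.hdf5 at the end of every epoch.
saver = tf.keras.callbacks.ModelCheckpoint("bert_tuned.hdf5")
# trX/trY and tsX/tsY are assumed to be the training and test texts with their 0/1 labels.
model.fit(trX, trY, validation_data=(tsX, tsY), batch_size=128, epochs=5, callbacks=[saver])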