@negedng
Created October 30, 2019 11:34
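import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras.models import Model

max_seq_length = 128  # Your choice here (BERT-Base supports at most 512).
# Three int32 inputs of shape (batch, max_seq_length) expected by the BERT module.
input_word_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,
                                       name="input_word_ids")
input_mask = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,
                                   name="input_mask")
segment_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,
                                    name="segment_ids")
# BERT-Base, uncased: 12 layers, hidden size 768, 12 heads; trainable=True allows fine-tuning.
bert_layer = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1",
                            trainable=True)
# pooled_output: (batch, 768); sequence_output: (batch, max_seq_length, 768).
pooled_output, sequence_output = bert_layer([input_word_ids, input_mask, segment_ids])
model = Model(inputs=[input_word_ids, input_mask, segment_ids],
              outputs=[pooled_output, sequence_output])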
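The model above expects three aligned int32 sequences per example. As a minimal sketch of how they could be filled in, the helper below assumes the text has already been WordPiece-tokenized into ids (the tokenizer is not shown, and the example ids are hypothetical). It also assumes the special-token ids of the standard BERT uncased vocabulary ([PAD]=0, [CLS]=101, [SEP]=102); check these against the vocab file shipped with the Hub module. The function name `encode` is illustrative, not part of any library.

```python
from typing import List, Optional, Tuple

# Assumed special-token ids from the standard BERT uncased vocab.txt;
# verify against the vocabulary bundled with the TF Hub module.
PAD_ID, CLS_ID, SEP_ID = 0, 101, 102


def encode(tokens_a: List[int], tokens_b: Optional[List[int]] = None,
           max_seq_length: int = 128) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: build (input_word_ids, input_mask, segment_ids).

    Layout: [CLS] A [SEP] (B [SEP]), then zero padding to max_seq_length.
    """
    a = list(tokens_a)
    b = list(tokens_b) if tokens_b else []
    special = 3 if b else 2  # slots reserved for [CLS] and [SEP] token(s)
    if max_seq_length < special:
        raise ValueError("max_seq_length too small for special tokens")
    # Truncate the longer sequence one token at a time until everything fits.
    while len(a) + len(b) > max_seq_length - special:
        if len(a) >= len(b):
            a.pop()
        else:
            b.pop()
    word_ids = [CLS_ID] + a + [SEP_ID]
    segment_ids = [0] * len(word_ids)  # segment 0: [CLS] A [SEP]
    if b:
        word_ids += b + [SEP_ID]
        segment_ids += [1] * (len(b) + 1)  # segment 1: B [SEP]
    input_mask = [1] * len(word_ids)  # 1 marks real tokens
    pad = max_seq_length - len(word_ids)
    word_ids += [PAD_ID] * pad
    input_mask += [0] * pad  # 0 marks padding
    segment_ids += [0] * pad
    return word_ids, input_mask, segment_ids


ids, mask, seg = encode([7, 8], [9], max_seq_length=8)
print(ids)   # [101, 7, 8, 102, 9, 102, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 1, 0, 0]
print(seg)   # [0, 0, 0, 0, 1, 1, 0, 0]
```

For a batch, stacking these lists and casting to int32 would give three arrays of shape (batch, max_seq_length), one for each `Input` layer.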