Skip to content

Instantly share code, notes, and snippets.

@gaphex
Last active January 11, 2021 13:43
Show Gist options
  • Save gaphex/85906734b5574a96c6e7dba218c1d18f to your computer and use it in GitHub Desktop.
Save gaphex/85906734b5574a96c6e7dba218c1d18f to your computer and use it in GitHub Desktop.
Spec function for BERT token embedding module
def build_module_fn(config_path, vocab_path, do_lower_case=True):
    """Return a TF-Hub module spec function for a BERT token embedding module.

    Nothing is built when this factory is called; all graph construction is
    deferred to the returned closure.

    Args:
        config_path: Path to the BERT config JSON. It is read with
            ``BertConfig.from_json_file`` each time the returned function runs.
        vocab_path: Vocabulary file path. It is stored in the graph as a
            ``tf.string`` constant named ``"vocab_file"`` and added to the
            ``tf.GraphKeys.ASSET_FILEPATHS`` collection.
        do_lower_case: Value exposed as the ``do_lower_case`` output of the
            ``tokenization_info`` signature.

    Returns:
        A function ``bert_module_fn(is_training)`` that builds the BERT graph
        and registers the ``tokens`` and ``tokenization_info`` signatures.
    """

    def bert_module_fn(is_training):
        """Build the BERT graph and register its TF-Hub signatures.

        Args:
            is_training: Forwarded unchanged to ``BertModel``.
        """
        # Three int32 placeholders of shape [None, None]. Their explicit
        # names match the keys of the "tokens" signature inputs below.
        ids_ph = tf.placeholder(shape=[None, None], dtype=tf.int32, name="input_ids")
        mask_ph = tf.placeholder(shape=[None, None], dtype=tf.int32, name="input_mask")
        segment_ph = tf.placeholder(shape=[None, None], dtype=tf.int32, name="segment_ids")

        bert_config = BertConfig.from_json_file(config_path)
        bert = BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=ids_ph,
            input_mask=mask_ph,
            token_type_ids=segment_ph,
        )
        # Output of the last encoder layer, plus the model's pooled output.
        sequence_out = bert.all_encoder_layers[-1]
        pooled_out = bert.get_pooled_output()

        vocab_tensor = tf.constant(value=vocab_path, dtype=tf.string, name="vocab_file")
        lower_case_tensor = tf.constant(do_lower_case)
        # NOTE(review): presumably registered as an asset so the vocab file
        # is packaged with the exported module -- confirm against TF-Hub docs.
        tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_tensor)

        hub.add_signature(
            name="tokens",
            inputs={
                "input_ids": ids_ph,
                "input_mask": mask_ph,
                "segment_ids": segment_ph,
            },
            outputs={
                "pooled_output": pooled_out,
                "sequence_output": sequence_out,
            },
        )
        hub.add_signature(
            name="tokenization_info",
            inputs={},
            outputs={
                "vocab_file": vocab_tensor,
                "do_lower_case": lower_case_tensor,
            },
        )

    return bert_module_fn
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment