Created
May 5, 2020 19:01
-
-
Save a7v8x/32544f0a452e092b2c54705c810b8eb0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def map_example_to_dict(input_ids, attention_masks, token_type_ids, label):
    """Pack tokenizer outputs into the (features, label) pair expected by
    TFBertForSequenceClassification.

    Args:
        input_ids: token id sequence(s) produced by the tokenizer.
        attention_masks: matching attention mask(s).
        token_type_ids: matching segment/token-type id(s).
        label: the classification target, passed through untouched.

    Returns:
        A ``(features, label)`` tuple where ``features`` is the keyword dict
        the model's ``call`` accepts.
    """
    features = {
        "input_ids": input_ids,
        "attention_mask": attention_masks,
        "token_type_ids": token_type_ids,
    }
    return features, label
def encode_examples(ds, limit=-1):
    """Tokenize a tfds text-classification dataset into a tf.data.Dataset
    of BERT model inputs.

    Args:
        ds: a tfds dataset yielding ``(text, label)`` pairs, where ``text``
            arrives as bytes (hence the ``.decode()`` below).
        limit: if positive, only the first ``limit`` examples are encoded;
            any non-positive value (the default -1) means "use everything".

    Returns:
        A ``tf.data.Dataset`` whose elements are the
        ``({input_ids, token_type_ids, attention_mask}, label)`` structure
        produced by ``map_example_to_dict``.
    """
    # Accumulate per-example features as Python lists so the final dataset
    # can be built in one shot from tensor slices.
    ids, segment_ids, masks, labels = [], [], [], []

    if limit > 0:
        ds = ds.take(limit)

    for text, label in tfds.as_numpy(ds):
        # convert_example_to_feature is the tokenizer wrapper defined
        # elsewhere in this gist; it returns the standard BERT feature dict.
        encoded = convert_example_to_feature(text.decode())
        ids.append(encoded['input_ids'])
        segment_ids.append(encoded['token_type_ids'])
        masks.append(encoded['attention_mask'])
        labels.append([label])

    # NOTE: slice order here (ids, masks, segment_ids, labels) must match
    # the positional parameters of map_example_to_dict.
    dataset = tf.data.Dataset.from_tensor_slices((ids, masks, segment_ids, labels))
    return dataset.map(map_example_to_dict)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment