Last active
December 2, 2019 03:35
-
-
Save analyticsindiamagazine/c9c438a336dea2b6cc8447e0a7c2a2ce to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_masks(tokens, max_seq_length):
    """Attention mask for BERT input: 1 per real token, 0 per padding slot.

    Tokens beyond max_seq_length are ignored, so the result is always
    exactly max_seq_length entries long.
    """
    n_real = min(len(tokens), max_seq_length)
    return [1] * n_real + [0] * (max_seq_length - n_real)
def get_segments(tokens, max_seq_length):
    """Segment (token-type) ids for BERT input, 0-padded to max_seq_length.

    Tokens up to and including the first "[SEP]" get segment 0; every
    token after it gets segment 1. Input longer than max_seq_length is
    truncated first.
    """
    clipped = tokens[:max_seq_length]
    segments = []
    seg_id = 0
    for tok in clipped:
        segments.append(seg_id)
        if tok == "[SEP]":
            # everything after the first [SEP] belongs to the second sentence
            seg_id = 1
    segments.extend([0] * (max_seq_length - len(clipped)))
    return segments
def get_ids(tokens, tokenizer, max_seq_length):
    """Vocabulary ids for the tokens, truncated and 0-padded to max_seq_length.

    The tokenizer must expose convert_tokens_to_ids (one id per token);
    0 is used as the padding id, as in BERT vocabularies.
    """
    ids = tokenizer.convert_tokens_to_ids(tokens[:max_seq_length])
    # when the input was truncated the padding term below is simply zero
    return ids + [0] * (max_seq_length - len(ids))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def prep(s, get = 'id'):
    """Turn raw text into one BERT input feature, sized by module-level max_seq_length.

    Uses the module-level `tokenizer`. `get` selects the feature:
    'id' -> token ids, 'mask' -> attention mask, anything else -> segment ids.
    """
    wrapped = ["[CLS]"] + tokenizer.tokenize(s) + ["[SEP]"]
    if get == 'mask':
        return get_masks(wrapped, max_seq_length)
    if get == 'id':
        return get_ids(wrapped, tokenizer, max_seq_length)
    return get_segments(wrapped, max_seq_length)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the WordPiece tokenizer from the vocab shipped with bert_layer.
# NOTE(review): assumes bert_layer is a TF-Hub BERT KerasLayer whose
# resolved_object exposes vocab_file / do_lower_case assets — confirm upstream.
vocab_file = bert_layer.resolved_object.vocab_file.asset_path.numpy()  # filesystem path to vocab.txt
do_lower_case = bert_layer.resolved_object.do_lower_case.numpy()  # lowercasing flag baked into the checkpoint
tokenizer = FullTokenizer(vocab_file, do_lower_case)  # used by prep()/get_ids()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment