Skip to content

Instantly share code, notes, and snippets.

@negedng
Created October 30, 2019 11:47
Show Gist options
  • Save negedng/e9a8e15a76eb9dc6dfb05a579f181be0 to your computer and use it in GitHub Desktop.
Save negedng/e9a8e15a76eb9dc6dfb05a579f181be0 to your computer and use it in GitHub Desktop.
def get_masks(tokens, max_seq_length):
    """Build the padding mask for a token sequence.

    The result has exactly ``max_seq_length`` entries: ``1`` for every
    position occupied by a token, followed by ``0`` for every padding
    position.

    Raises:
        IndexError: if ``tokens`` has more than ``max_seq_length`` items.
    """
    n_tokens = len(tokens)
    n_padding = max_seq_length - n_tokens
    # A negative padding count means the sequence does not fit.
    if n_padding < 0:
        raise IndexError("Token length more than max seq length!")
    mask = [1] * n_tokens
    mask.extend([0] * n_padding)
    return mask
def get_segments(tokens, max_seq_length):
    """Build the segment-id list for a token sequence.

    Every token up to and including the first ``"[SEP]"`` gets id ``0``;
    every token after it gets id ``1``. The list is then right-padded with
    ``0`` up to ``max_seq_length`` entries.

    Raises:
        IndexError: if ``tokens`` has more than ``max_seq_length`` items.
    """
    n_tokens = len(tokens)
    if n_tokens > max_seq_length:
        raise IndexError("Token length more than max seq length!")

    segment_ids = []
    past_first_sep = False
    for tok in tokens:
        # The separator itself still belongs to the segment it closes,
        # so record the id before updating the flag.
        segment_ids.append(1 if past_first_sep else 0)
        if tok == "[SEP]":
            past_first_sep = True

    segment_ids.extend([0] * (max_seq_length - n_tokens))
    return segment_ids
def get_ids(tokens, tokenizer, max_seq_length):
    """Convert tokens to vocabulary ids, right-padded to ``max_seq_length``.

    Args:
        tokens: sequence of token strings to look up.
        tokenizer: object exposing ``convert_tokens_to_ids(tokens)`` that
            returns a list of ids.
        max_seq_length: required length of the returned list.

    Returns:
        The token ids followed by ``0`` padding, ``max_seq_length`` long.
        NOTE(review): assumes id 0 is the vocabulary's padding id; confirm
        against the tokenizer's vocab.

    Raises:
        IndexError: if there are more ids than ``max_seq_length``.
    """
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    n_padding = max_seq_length - len(token_ids)
    # Without this guard a negative count multiplies to an empty pad and an
    # over-long list is returned silently, out of step with get_masks and
    # get_segments, which both reject over-long input.
    if n_padding < 0:
        raise IndexError("Token length more than max seq length!")
    return token_ids + [0] * n_padding
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment