Skip to content

Instantly share code, notes, and snippets.

@kusal1990
Created June 25, 2022 06:47
Show Gist options
  • Save kusal1990/c0e50d364759118239ee570c7f9faca6 to your computer and use it in GitHub Desktop.
from transformers import RobertaTokenizer, TFRobertaForMultipleChoice
# Load the pretrained 'roberta-base' tokenizer.
# NOTE(review): RoBERTa's byte-level BPE vocabulary is case-sensitive and, as far as
# I can tell, RobertaTokenizer has no `do_lower_case` option, so this kwarg is
# probably ignored (no lower-casing happens) — confirm against the installed
# transformers version before relying on it.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True)
def arc_preprocessor(dataset, tokenizer, max_len=None, max_context_words=400):
    '''
    Convert each (context, question, options) row of ``dataset`` into
    RoBERTa-style multiple-choice inputs, one token sequence per option:

        <s> context </s> </s> question </s> </s> choices[0] </s>
        <s> context </s> </s> question </s> </s> choices[1] </s>
        ...

    Each sequence is right-padded with ``tokenizer.pad_token_id`` to exactly
    ``max_len`` tokens. When a sequence would be longer than ``max_len``, only
    the END of the context is dropped, so the question, the option text and
    every separator token are kept. (Truncating the whole joined string from
    the right would cut off the option first and can make all choices of a
    long-context row identical.)

    Parameters
    ----------
    dataset : pandas.DataFrame
        Rows with columns 'context' (str), 'only_question' (str) and
        'options_list' (sequence of str).
    tokenizer : transformers tokenizer
        Must map the literal tokens '<s>' and '</s>' to ids (e.g.
        RobertaTokenizer), expose ``pad_token_id``, and return a mapping with
        an 'input_ids' list when called.
    max_len : int, optional
        Target sequence length. Defaults to the module-level ``MAX_LEN``.
    max_context_words : int, optional
        The context is cut to this many whitespace-separated words before
        tokenization (default 400).

    Returns
    -------
    all_input_ids, all_attention_mask : list of numpy.ndarray
        One int32 array of shape [num_choices, max_len] per row; the mask is
        1 for real tokens and 0 for padding.
    '''
    import re

    import numpy as np

    if max_len is None:
        max_len = MAX_LEN  # module-level setting; read-only, so no `global` needed

    bos_id = tokenizer.convert_tokens_to_ids('<s>')
    sep_id = tokenizer.convert_tokens_to_ids('</s>')
    pad_id = tokenizer.pad_token_id

    def encode(text):
        # Collapse whitespace runs, then tokenize without adding special tokens;
        # separators are inserted by id below so they can never be truncated.
        text = re.sub(r'\s+', ' ', text).strip()
        return tokenizer(text, add_special_tokens=False)['input_ids']

    all_input_ids = []
    all_attention_mask = []
    for context, question, options in zip(dataset['context'],
                                          dataset['only_question'],
                                          dataset['options_list']):
        context = ' '.join(context.split()[:max_context_words])
        context_ids = encode(context)
        question_ids = encode(question)

        row_ids = []
        row_mask = []
        for option in options:
            tail = ([sep_id, sep_id] + question_ids + [sep_id, sep_id]
                    + encode(option) + [sep_id])
            # Tokens left for the context after <s>, its closing </s>, and the tail.
            budget = max(max_len - 2 - len(tail), 0)
            # The final slice only matters if question+option alone exceed max_len.
            ids = ([bos_id] + context_ids[:budget] + [sep_id] + tail)[:max_len]
            padding_length = max_len - len(ids)
            row_ids.append(ids + [pad_id] * padding_length)
            row_mask.append([1] * len(ids) + [0] * padding_length)

        all_input_ids.append(np.asarray(row_ids, dtype='int32'))
        all_attention_mask.append(np.asarray(row_mask, dtype='int32'))
    return all_input_ids, all_attention_mask
@kusal1990
Copy link
Author

ok

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment