from transformers import AutoTokenizer, T5ForConditionalGeneration
import numpy as np

tokenizer = AutoTokenizer.from_pretrained('t5-small')

# The gist reads these as globals without defining them; the values here are
# assumptions (the inline comment below suggests 23 for the decoder side).
MAX_LEN = 512
decoder_max_len = 23

def arc_preprocessor(dataset, tokenizer):
    '''
    Converts a given context, question and choices into the format:
        input:  question \n options \n context </s>
        target: label </s>
    and tokenizes the result with the given tokenizer.
    Returns 4 arrays: input_ids, attention_mask, decoder_input_ids and labels.
    '''
    all_input_ids = []
    all_attention_mask = []
    all_decoder_input_ids = []
    all_labels = []

    for i in range(len(dataset)):
        context = ' '.join(dataset['context'].iloc[i].split()[:350])  # limit the context to 350 words
        question = dataset['only_question'].iloc[i]
        options = dataset['only_answers'].iloc[i]
        target = dataset['Answer'].iloc[i]

        # Literal two-character '\n' separators, following the UnifiedQA input
        # convention. The T5 tokenizer appends '</s>' (EOS) itself, so it is
        # not added manually here.
        input_string = question + ' ' + '\\n' + ' ' + options + ' ' + '\\n' + ' ' + context
        # T5 uses the pad token as the decoder start token for teacher forcing.
        decoder_input = tokenizer.pad_token + ' ' + target

        input_ids = tokenizer.encode(input_string, truncation=True, max_length=MAX_LEN)
        decoder_input_ids = tokenizer.encode(decoder_input, max_length=decoder_max_len, truncation=True)  # max length of an answer is 23 tokens
        labels = tokenizer.encode(target, max_length=decoder_max_len, truncation=True)
        attention_mask = [1] * len(input_ids)

        # Right-pad everything to the fixed lengths. labels is one token
        # shorter than decoder_input_ids (no leading pad token), so it gets
        # its own padding length.
        padding_id = tokenizer.pad_token_id
        padding_length = MAX_LEN - len(input_ids)
        input_ids = input_ids + [padding_id] * padding_length
        attention_mask = attention_mask + [0] * padding_length
        decoder_padding_length = decoder_max_len - len(decoder_input_ids)
        decoder_input_ids = decoder_input_ids + [padding_id] * decoder_padding_length
        labels = labels + [padding_id] * (decoder_max_len - len(labels))

        assert len(input_ids) == MAX_LEN
        assert len(attention_mask) == MAX_LEN

        all_input_ids.append(np.asarray(input_ids, dtype='int32'))
        all_attention_mask.append(np.asarray(attention_mask, dtype='int32'))
        all_decoder_input_ids.append(np.asarray(decoder_input_ids, dtype='int32'))
        all_labels.append(np.asarray(labels, dtype='int32'))

    return all_input_ids, all_attention_mask, all_decoder_input_ids, all_labels
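
Below is a minimal usage sketch, not part of the original gist: the one-row DataFrame and its contents are illustrative assumptions inferred from the column names the function reads. Note that Hugging Face ignores label positions set to -100 when computing the loss; the gist pads labels with pad_token_id instead, so padded positions contribute to the loss here.

# --- Usage sketch (illustrative; the example row is an assumption) ---
import pandas as pd
import torch

demo = pd.DataFrame({
    'context': ['Plants make their own food from sunlight, water and carbon dioxide.'],
    'only_question': ['How do plants make their food?'],
    'only_answers': ['(A) respiration (B) photosynthesis (C) digestion'],
    'Answer': ['photosynthesis'],
})

input_ids, attention_mask, decoder_input_ids, labels = arc_preprocessor(demo, tokenizer)

model = T5ForConditionalGeneration.from_pretrained('t5-small')
# Passing labels makes the model compute the LM loss internally; it derives
# decoder inputs by shifting labels right, so the explicit decoder_input_ids
# returned above are only needed for a manual forward pass.
outputs = model(
    input_ids=torch.tensor(np.stack(input_ids), dtype=torch.long),
    attention_mask=torch.tensor(np.stack(attention_mask), dtype=torch.long),
    labels=torch.tensor(np.stack(labels), dtype=torch.long),
)
print(float(outputs.loss))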