Created November 11, 2019 09:21
from typing import Dict
from functools import partial

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

import lineflow as lf
import lineflow.datasets as lfds
import lineflow.cross_validation as lfcv
from transformers import BertTokenizer


MAX_LEN = 256
def preprocess(tokenizer: BertTokenizer, x: Dict) -> Dict:
    # `x` is one sample from the lineflow MsrParaphrase dataset.
    # It is a dict that contains (at least) the fields used below, e.g.:
    # {
    #     "quality": "1",
    #     "string1": "The first sentence of the pair ...",
    #     "string2": "The second sentence of the pair ...",
    # }
    # Use BertTokenizer to encode (tokenize and convert to token ids) the two
    # sentences as one pair, limited to MAX_LEN tokens.
    inputs = tokenizer.encode_plus(
        x["string1"],
        x["string2"],
        add_special_tokens=True,
        max_length=MAX_LEN,
    )
    # `tokenizer.encode_plus` returns a dictionary.
    input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
    # For BERT, we need an `attention_mask` along with `input_ids` as input.
    attention_mask = [1] * len(input_ids)
    # Pad every sequence up to MAX_LEN.
    padding_length = MAX_LEN - len(input_ids)
    pad_id = tokenizer.pad_token_id
    input_ids = input_ids + ([pad_id] * padding_length)
    attention_mask = attention_mask + ([0] * padding_length)
    # For bert-base-uncased, `pad_token_id` is 0, which is also the padding value
    # expected for `token_type_ids`.
    token_type_ids = token_type_ids + ([pad_id] * padding_length)
    assert len(input_ids) == MAX_LEN, "Error with input length {} vs {}".format(len(input_ids), MAX_LEN)
    assert len(attention_mask) == MAX_LEN, "Error with input length {} vs {}".format(len(attention_mask), MAX_LEN)
    assert len(token_type_ids) == MAX_LEN, "Error with input length {} vs {}".format(len(token_type_ids), MAX_LEN)
    # Convert the plain Python lists (and the label) to `torch.Tensor`s.
    label = torch.tensor(int(x["quality"])).long()
    input_ids = torch.tensor(input_ids)
    attention_mask = torch.tensor(attention_mask)
    token_type_ids = torch.tensor(token_type_ids)
    # The dict we return here becomes one instance of the batch that
    # `LightningModule.training_step` receives.
    return {
        "label": label,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
    }
def nonefilter(dataset):
    # Keep only samples where both sentences of the pair are present.
    filtered = []
    for x in dataset:
        if x["string1"] is None:
            continue
        if x["string2"] is None:
            continue
        filtered.append(x)
    return lf.Dataset(filtered)
def get_dataloader():
    # Load the MSR Paraphrase datasets (the data is downloaded on the first run).
    train = lfds.MsrParaphrase("train")
    test = lfds.MsrParaphrase("test")
    # Some entries have a missing sentence; drop them.
    train = nonefilter(train)
    test = nonefilter(test)
    # Split the train dataset into train and val, so that val can be used for early stopping.
    train, val = lfcv.split_dataset_random(train, int(len(train) * 0.8), seed=42)
    batch_size = 8
    # Now the BERT tokenizer comes in!
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
    # The actual preprocessing happens in the `preprocess` function defined above.
    preprocessor = partial(preprocess, tokenizer)
    # Apply the preprocessing and build PyTorch dataloaders.
    train_dataloader = DataLoader(
        train.map(preprocessor),
        sampler=RandomSampler(train),
        batch_size=batch_size,
    )
    val_dataloader = DataLoader(
        val.map(preprocessor),
        sampler=SequentialSampler(val),
        batch_size=batch_size,
    )
    test_dataloader = DataLoader(
        test.map(preprocessor),
        sampler=SequentialSampler(test),
        batch_size=batch_size,
    )
    return train_dataloader, val_dataloader, test_dataloader
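

# A minimal usage sketch (not part of the original snippet, assuming the defaults
# above: batch_size=8 and MAX_LEN=256): build the dataloaders and inspect one batch.
if __name__ == "__main__":
    train_dataloader, val_dataloader, test_dataloader = get_dataloader()
    batch = next(iter(train_dataloader))
    # The default collate function stacks each field of the per-sample dicts.
    print(batch["input_ids"].shape)       # torch.Size([8, 256])
    print(batch["attention_mask"].shape)  # torch.Size([8, 256])
    print(batch["token_type_ids"].shape)  # torch.Size([8, 256])
    print(batch["label"].shape)           # torch.Size([8])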