@sobamchan
Last active October 18, 2019 08:31
from typing import List, Dict, Callable
from collections import OrderedDict
from functools import partial
import lineflow as lf
import lineflow.datasets as lfds
import lineflow.cross_validation as lfcv
import torch
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from transformers import BertForSequenceClassification, BertTokenizer, AdamW

MAX_LEN = 256
NUM_LABELS = 2


def _pad(x: List[int], pad_idx: int, maxlen: int) -> List[int]:
    padded: List[int]
    if len(x) > maxlen:
        padded = x[:maxlen]
    else:
        padded = x + [pad_idx] * (maxlen - len(x))
    return padded


def preprocess(tokenizer: BertTokenizer, x: Dict) -> Dict[str, torch.Tensor]:
    # Encode the sentence pair with [CLS]/[SEP] special tokens, then pad everything to MAX_LEN.
    inputs = tokenizer.encode_plus(
        x["string1"],
        x["string2"],
        add_special_tokens=True,
        max_length=MAX_LEN,
    )
    input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
    attention_mask = [1] * len(input_ids)

    padding_length = MAX_LEN - len(input_ids)
    pad_id = tokenizer.pad_token_id
    input_ids = input_ids + ([pad_id] * padding_length)
    attention_mask = attention_mask + ([0] * padding_length)
    token_type_ids = token_type_ids + ([pad_id] * padding_length)
    assert len(input_ids) == MAX_LEN, "Error with input length {} vs {}".format(len(input_ids), MAX_LEN)
    assert len(attention_mask) == MAX_LEN, "Error with input length {} vs {}".format(len(attention_mask), MAX_LEN)
    assert len(token_type_ids) == MAX_LEN, "Error with input length {} vs {}".format(len(token_type_ids), MAX_LEN)

    label = torch.tensor(int(x["quality"])).long()
    input_ids = torch.tensor(input_ids)
    attention_mask = torch.tensor(attention_mask)
    token_type_ids = torch.tensor(token_type_ids)

    return {
        "label": label,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids
    }


def nonefilter(dataset):
    filtered = []
    for x in dataset:
        if x["string1"] is None:
            continue
        if x["string2"] is None:
            continue
        filtered.append(x)
    return lf.Dataset(filtered)


def get_dataloader():
    train = lfds.MsrParaphrase("train")
    train = nonefilter(train)
    test = lfds.MsrParaphrase("test")
    test = nonefilter(test)
    # Hold out 20% of the training set for validation.
    train, val = lfcv.split_dataset_random(train, int(len(train) * 0.8), seed=42)
    batch_size = 8
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
    preprocessor = partial(preprocess, tokenizer)
    train_dataloader = DataLoader(
        train.map(preprocessor),
        sampler=RandomSampler(train),
        batch_size=batch_size
    )
    val_dataloader = DataLoader(
        val.map(preprocessor),
        sampler=SequentialSampler(val),
        batch_size=batch_size
    )
    test_dataloader = DataLoader(
        test.map(preprocessor),
        sampler=SequentialSampler(test),
        batch_size=batch_size
    )
    return train_dataloader, val_dataloader, test_dataloader


class Model(pl.LightningModule):

    def __init__(self):
        super(Model, self).__init__()
        model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=NUM_LABELS)
        self.model = model
        train_dataloader, val_dataloader, test_dataloader = get_dataloader()
        self._train_dataloader = train_dataloader
        self._val_dataloader = val_dataloader
        self._test_dataloader = test_dataloader

    def configure_optimizers(self):
        # Exclude biases and LayerNorm parameters from weight decay
        # (transformers' AdamW reads the "weight_decay" key of each param group).
        param_optimizer = list(self.model.named_parameters())
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                "weight_decay": 0.01
            },
            {
                "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0
            },
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=2e-5,
        )
        return optimizer

    def training_step(self, batch, batch_idx):
        labels = batch["label"]
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        token_type_ids = batch["token_type_ids"]
        loss, _ = self.model(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        tqdm_dict = {"train_loss": loss}
        output = OrderedDict({
            "loss": loss,
            "progress_bar": tqdm_dict,
            "log": tqdm_dict
        })
        return output

    def validation_step(self, batch, batch_idx):
        labels = batch["label"]
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        token_type_ids = batch["token_type_ids"]
        loss, logits = self.model(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        labels_hat = torch.argmax(logits, dim=1)
        correct_count = torch.sum(labels == labels_hat)
        if self.on_gpu:
            correct_count = correct_count.cuda(loss.device.index)
        output = OrderedDict({
            "val_loss": loss,
            "correct_count": correct_count,
            "batch_size": len(labels)
        })
        return output

    def validation_end(self, outputs):
        val_acc = sum([out["correct_count"] for out in outputs]).float() / sum(out["batch_size"] for out in outputs)
        val_loss = sum([out["val_loss"] for out in outputs]) / len(outputs)
        tqdm_dict = {
            "val_loss": val_loss,
            "val_acc": val_acc,
        }
        result = {"progress_bar": tqdm_dict, "log": tqdm_dict, "val_loss": val_loss}
        return result

    def test_step(self, batch, batch_idx):
        labels = batch["label"]
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        token_type_ids = batch["token_type_ids"]
        loss, logits = self.model(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        labels_hat = torch.argmax(logits, dim=1)
        correct_count = torch.sum(labels == labels_hat)
        if self.on_gpu:
            correct_count = correct_count.cuda(loss.device.index)
        output = OrderedDict({
            "test_loss": loss,
            "correct_count": correct_count,
            "batch_size": len(labels)
        })
        return output

    def test_end(self, outputs):
        test_acc = sum([out["correct_count"] for out in outputs]).float() / sum(out["batch_size"] for out in outputs)
        test_loss = sum([out["test_loss"] for out in outputs]) / len(outputs)
        tqdm_dict = {
            "test_loss": test_loss,
            "test_acc": test_acc,
        }
        result = {"progress_bar": tqdm_dict, "log": tqdm_dict}
        return result

    @pl.data_loader
    def train_dataloader(self):
        return self._train_dataloader

    @pl.data_loader
    def val_dataloader(self):
        return self._val_dataloader

    @pl.data_loader
    def test_dataloader(self):
        return self._test_dataloader


if __name__ == "__main__":
    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        min_delta=0.0,
        patience=3,
        verbose=True,
        mode="min"
    )
    trainer = pl.Trainer(
        gpus=1,
        early_stop_callback=early_stop_callback,
    )
    model = Model()
    trainer.fit(model)
    trainer.test()
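
A quick way to sanity-check the preprocessing before a full training run is to pull a single batch from the dataloaders built above and inspect the tensor shapes. The snippet below is a minimal sketch that only assumes the functions defined in this file; the expected shapes follow from batch_size=8 and MAX_LEN=256.

    # Sketch: verify that every field is padded and stacked to the expected shape.
    train_dl, val_dl, test_dl = get_dataloader()
    batch = next(iter(val_dl))
    print(batch["input_ids"].shape)       # torch.Size([8, 256]) -> (batch_size, MAX_LEN)
    print(batch["attention_mask"].shape)  # torch.Size([8, 256])
    print(batch["token_type_ids"].shape)  # torch.Size([8, 256])
    print(batch["label"].shape)           # torch.Size([8])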

yasufumy commented Oct 16, 2019

Maybe you can use functools.partial as follows:

from functools import partial


def preprocess(tokenizer: BertTokenizer, x: Dict) -> Dict[str, torch.Tensor]:
    s1_idxs = tokenizer.encode(x["string1"])
    s2_idxs = tokenizer.encode(x["string2"])
    idxs = [tokenizer.cls_token_id] + s1_idxs + [tokenizer.sep_token_id] + s2_idxs + [tokenizer.sep_token_id]

    pad_idx = tokenizer.pad_token_id
    idxs = _pad(idxs, pad_idx, MAX_LEN)
    mask = [float(i > 0) for i in idxs]
    segment = [0.0] + ([0.0] * len(s1_idxs)) + [0.0] + ([1.0] * len(s2_idxs)) + [1.0]
    segment = _pad(segment, pad_idx, MAX_LEN)
    assert len(idxs) == len(mask) == len(segment), f"{len(idxs)}, {len(mask)}, {len(segment)}"

    idxs = torch.tensor(idxs)
    mask = torch.tensor(mask)
    segment = torch.tensor(segment).long()
    label = torch.tensor(int(x["quality"])).long()

    return {
            "label": label,
            "input": idxs,
            "mask": mask,
            "segment": segment
            }


def get_dataloader():
    train = lfds.MsrParaphrase("train").filter(lambda x: x['string1'] and x['string2'])
    test = lfds.MsrParaphrase("test").filter(lambda x: x['string1'] and x['string2'])
    batch_size = 16

    preprocessor = partial(preprocess, BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True))

    train_dataloader = DataLoader(
            train.map(preprocessor),
            sampler=RandomSampler(train),
            batch_size=batch_size
            )
    test_dataloader = DataLoader(
            test.map(preprocessor),
            sampler=SequentialSampler(test),
            batch_size=batch_size
            )

    return train_dataloader, test_dataloader
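
One caveat with this variant: the preprocess here returns the keys "input", "mask", and "segment" rather than "input_ids", "attention_mask", and "token_type_ids", so the LightningModule's step methods would need to unpack the batch with the new names. A minimal sketch of an adapted training_step, assuming the same self.model attribute as in the gist:

    def training_step(self, batch, batch_idx):
        # Sketch only: key names follow the preprocess suggested above.
        loss, _ = self.model(
            batch["input"],
            token_type_ids=batch["segment"],
            attention_mask=batch["mask"],
            labels=batch["label"]
        )
        tqdm_dict = {"train_loss": loss}
        return {"loss": loss, "progress_bar": tqdm_dict, "log": tqdm_dict}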
