import torch.utils.data as tud
import torch
from typing import List
import random
import nlp
def prepare_dataset(tokenizer, split="train", max_length=120, num_datapoints=100_000):
    """Prepares the WikiText-103 dataset."""
    wikitext = nlp.load_dataset("wikitext", "wikitext-103-v1")
    data = [x["text"] for x in wikitext[split]][:num_datapoints]
    data = "".join(data)
    token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(data))
    chunked_token_ids = chunks(token_ids, max_length, tokenizer)
    data = Data(chunked_token_ids, tokenizer)
    return data
def chunks(lst, n, tokenizer):
    """Return successive n-sized chunks of lst, each wrapped in [CLS] ... [SEP]."""
    _chunks = []
    for i in range(0, len(lst), n):
        ids = [tokenizer.cls_token_id] + lst[i : i + n] + [tokenizer.sep_token_id]
        _chunks.append(torch.tensor(ids))
    return _chunks
def noise_text_input(text: str, noise_prob=0.2):
    """Takes a string, returns a noised version of it."""
    splitted = text.split(" ")
    bool_mask = torch.empty(len(splitted)).uniform_() > 1 - noise_prob
    noised = []
    for word, boolean in zip(splitted, bool_mask):
        if boolean and len(word) > 1:
            # split the selected word in two at a random position
            idx = random.randint(1, len(word) - 1)
            noised.append(word[:idx])
            noised.append(word[idx:])
        else:
            noised.append(word)
    return " ".join(noised)
def make_transformer_inputs(
    input_ids, max_length, padding_value, prefix="", make_labels=False, **kwargs
):
    lengths = [s.size(0) for s in input_ids]
    max_len = max(lengths)
    if max_len > max_length:
        max_len = max_length
    out_dims = (len(input_ids), max_len)
    padded_input_ids = input_ids[0].data.new(*out_dims).fill_(padding_value)
    attention_mask = padded_input_ids.clone()
    token_type_ids = padded_input_ids.clone()
    for i, tensor in enumerate(input_ids):
        length = tensor.size(0)
        if length > max_length:
            length = max_length
            tensor = tensor[:length]
        padded_input_ids[i, :length] = tensor
        attention_mask[i, :length] = torch.ones_like(tensor)
    batch = {
        f"{prefix}input_ids": padded_input_ids,
        f"{prefix}attention_mask": attention_mask,
        f"{prefix}token_type_ids": token_type_ids,
    }
    if make_labels:
        lm_labels = padded_input_ids.clone()
        lm_labels[lm_labels == padding_value] = -100
        batch["lm_labels"] = lm_labels
    batch.update(kwargs)
    return batch
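

# Illustrative-only sketch (the `_demo_make_transformer_inputs` helper and the toy
# ids are assumptions, not part of the pipeline): shows the keys and shapes
# produced for a small batch of unequal-length sequences.
def _demo_make_transformer_inputs():
    toy_ids = [torch.tensor([101, 7592, 102]), torch.tensor([101, 7592, 2088, 999, 102])]
    batch = make_transformer_inputs(toy_ids, max_length=8, padding_value=0, make_labels=True)
    # "input_ids", "attention_mask" and "token_type_ids" are all (2, 5) LongTensors,
    # padded to the longest sequence; in "lm_labels" the padded positions are set to
    # -100 so the cross-entropy loss ignores them.
    return batch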
class Data(tud.Dataset):
    def __init__(self, token_ids: List[torch.Tensor], tokenizer, noise_prob=0.2):
        self.token_ids = token_ids
        self.tokenizer = tokenizer
        self.len = len(token_ids)
        self.noise_prob = noise_prob

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        tgt_ids = self.token_ids[idx]
        decoded = self.tokenizer.decode(tgt_ids, skip_special_tokens=True)
        noised = noise_text_input(decoded, self.noise_prob)
        src = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(noised))
        src = [self.tokenizer.cls_token_id] + src + [self.tokenizer.sep_token_id]
        src_ids = torch.tensor(src)
        return dict(src_input_ids=src_ids, tgt_input_ids=tgt_ids)
class Collater:
    def __init__(self, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, batch: List):
        src = [x["src_input_ids"] for x in batch]
        tgt = [x["tgt_input_ids"] for x in batch]
        src_batch = self.collate(src)
        tgt_batch = self.collate(tgt, "decoder_", make_labels=True)
        src_batch.update(tgt_batch)
        return src_batch

    def collate(self, input_ids, prefix="", make_labels=False):
        return make_transformer_inputs(
            input_ids, self.max_length, self.tokenizer.pad_token_id, prefix, make_labels
        )
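

# Sketch of how Collater is meant to be wired into a DataLoader (hypothetical
# `_demo_collater` helper; `dataset` is a `Data` instance and `tokenizer` the same
# BERT tokenizer used further below):
def _demo_collater(dataset, tokenizer):
    collater = Collater(tokenizer, max_length=128)
    loader = tud.DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collater)
    batch = next(iter(loader))
    # batch holds input_ids / attention_mask / token_type_ids for the noised source
    # plus decoder_input_ids / decoder_attention_mask / decoder_token_type_ids and
    # lm_labels for the clean target, i.e. the keyword arguments forwarded to the
    # EncoderDecoderModel in Model.forward.
    return batch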
import pytorch_lightning as pl
from transformers import EncoderDecoderModel, BertTokenizer
import torch
import torch_optimizer
import torch.utils.data as tud
class NoamScheduler(torch.optim.lr_scheduler.LambdaLR):
    def __init__(self, optimizer, num_warmup_steps=1000, last_epoch=-1):
        assert num_warmup_steps > 0
        normalize = 1 / (num_warmup_steps * num_warmup_steps ** -1.5)
        super().__init__(
            optimizer,
            lambda step: normalize
            * min((step + 1) ** -0.5, (step + 1) * num_warmup_steps ** -1.5),
            last_epoch,
        )
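

# Quick sanity sketch (hypothetical `_demo_noam_factors` helper): the LambdaLR factor
# grows linearly during warmup, hits 1.0 after `num_warmup_steps` optimizer steps,
# and then decays proportionally to step ** -0.5.
def _demo_noam_factors(num_warmup_steps=1000):
    normalize = 1 / (num_warmup_steps * num_warmup_steps ** -1.5)
    factor = lambda step: normalize * min(
        (step + 1) ** -0.5, (step + 1) * num_warmup_steps ** -1.5
    )
    # factor(num_warmup_steps - 1) == 1.0 and factor(4 * num_warmup_steps - 1) == 0.5
    return [factor(s) for s in (0, num_warmup_steps - 1, 4 * num_warmup_steps - 1)]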
class Model(pl.LightningModule):
    def __init__(
        self, hparams, train_dataset=None, val_dataset=None, test_dataset=None
    ):
        super().__init__()
        self.hparams = hparams
        self.model = EncoderDecoderModel.from_encoder_decoder_pretrained(
            "bert-base-cased", "bert-base-cased"
        )  # initialize Bert2Bert
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
        self.collater = Collater(self.tokenizer, self.hparams.max_length)

    def setup(self, step) -> None:
        self.train_dataset = prepare_dataset(
            self.tokenizer,
            "validation",  # to save time
            self.hparams.max_length,
            self.hparams.num_datapoints,
        )
        self.val_dataset = prepare_dataset(
            self.tokenizer,
            "validation",
            self.hparams.max_length,
            self.hparams.num_datapoints,
        )
        self.test_dataset = prepare_dataset(
            self.tokenizer, "test", self.hparams.max_length, self.hparams.num_datapoints
        )
    def train_dataloader(self):
        return tud.DataLoader(
            self.train_dataset,
            batch_size=self.hparams.train_bs,
            shuffle=True,
            num_workers=self.hparams.num_workers or 4,
            collate_fn=self.collater,
        )
    def val_dataloader(self):
        return tud.DataLoader(
            self.val_dataset,
            batch_size=self.hparams.val_bs,
            shuffle=False,
            num_workers=4,
            collate_fn=self.collater,
        )

    def test_dataloader(self):
        return tud.DataLoader(
            self.test_dataset,
            self.hparams.val_bs,
            False,
            num_workers=self.hparams.num_workers or 4,
            collate_fn=self.collater,
        )

    def forward(self, batch):
        return self.model(**batch)

    def training_step(self, batch, batch_idx):
        loss, logits, *_ = self(batch)
        self.logger.log_metrics({"loss": loss.cpu()})
        output = {"loss": loss}
        return output

    def validation_step(self, batch, batch_idx):
        return self._shared_val_step(batch, batch_idx, "val")

    def validation_epoch_end(self, output):
        return self._shared_val_end(output, "val")

    def test_step(self, batch, batch_idx):
        return self._shared_val_step(batch, batch_idx, "test")

    def test_epoch_end(self, output):
        return self._shared_val_end(output, "test")

    def _shared_val_step(self, batch, batch_idx, prefix):
        loss, logits, *_ = self(batch)
        preds = logits.argmax(-1)  # bs x seqlen
        lm_labels = batch["lm_labels"]  # bs x seqlen
        acc_mask = lm_labels[:, 1:].ne(-100)
        correct = preds[:, :-1].eq(lm_labels[:, 1:])  # bs x (seqlen - 1)
        frac_tokens_correct = correct.masked_select(acc_mask).float().mean()
        correct[~acc_mask] = True
        frac_seqs_correct = correct.all(1).float().mean()
        logs = {
            f"{prefix}_loss": loss,
            "frac_tokens_correct": frac_tokens_correct,
            "frac_seqs_correct": frac_seqs_correct,
        }
        return logs

    def _shared_val_end(self, output, prefix):
        output = self.collate(output)
        logs = {"log": output, f"{prefix}_loss": output[f"{prefix}_loss"]}
        # self.logger.log_metrics(output)
        return logs

    def configure_optimizers(self):
        opt_class = getattr(torch_optimizer, self.hparams.optimizer)
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p
                    for n, p in self.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": self.hparams.optimizer_kwargs.weight_decay or 1e-7,
            },
            {
                "params": [
                    p
                    for n, p in self.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
        self.optimizer = opt_class(
            optimizer_grouped_parameters, **self.hparams.optimizer_kwargs
        )
        scheduler = NoamScheduler(
            self.optimizer, self.hparams.schedulers_kwargs.num_warmup_steps
        )
        self.scheduler = {"scheduler": scheduler, "interval": "step"}
        return [self.optimizer], [self.scheduler]
    def collate(self, output):
        keys = output[0].keys()
        return_dict = {}
        for key in keys:
            tensor = output[0][key]
            if tensor.dim() == 0:
                return_dict[key] = torch.stack([x[key] for x in output]).mean()
            elif tensor.dim() == 1:
                return_dict[key] = torch.cat([x[key] for x in output]).mean()
        return return_dict
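

# Illustrative-only sketch of the shifted comparison in _shared_val_step above (the
# `_demo_shifted_accuracy` helper and the toy tensors are assumptions): labels at
# position t are compared against the prediction made at position t - 1, and
# positions labelled -100 are masked out of the accuracy.
def _demo_shifted_accuracy():
    preds = torch.tensor([[5, 7, 9, 0]])         # argmax of the logits, bs x seqlen
    lm_labels = torch.tensor([[5, 7, 9, -100]])  # -100 marks padding
    acc_mask = lm_labels[:, 1:].ne(-100)
    correct = preds[:, :-1].eq(lm_labels[:, 1:])
    # here nothing lines up after the shift, so the fraction is 0.0
    return correct.masked_select(acc_mask).float().mean()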
hparams = {
    "name": "MY-WANDB-NAME",
    "project": "MY-WANDB-PROJECT",
    "train_bs": 4,
    "val_bs": 4,
    "num_workers": 4,
    "max_length": 160,
    "num_datapoints": 100_000,
    "optimizer": "Ranger",
    "optimizer_kwargs": {
        "lr": 3e-4,
        "alpha": 0.5,
        "betas": [0.95, 0.999],
        "eps": 1e-5,
        "weight_decay": 1e-3,
        # "use_gc": True,
    },
    "schedulers_kwargs": {"num_warmup_steps": 1000},
    "trainer_kwargs": {
        "gpus": 2,
        "gradient_clip_val": 0.5,
        "accumulate_grad_batches": 4,
        "min_epochs": 5,
        "max_epochs": 100,
        "precision": 32,
"distributed_backend": 'ddp', ### Change this to "ddp" when on multi-gpu to see the bug | |
}, | |
} | |
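

# Minimal sketch (hypothetical `_demo_hparams_access` helper): train() wraps the dict
# above with OmegaConf, which is what makes attribute-style access such as
# hparams.optimizer_kwargs.weight_decay work in configure_optimizers.
def _demo_hparams_access():
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(hparams)
    return cfg.optimizer, cfg.optimizer_kwargs.lr, cfg.trainer_kwargs.distributed_backend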
import wandb

wandb.login()

from omegaconf import OmegaConf
from pytorch_lightning.loggers import WandbLogger
def train(hparams):
    hparams = OmegaConf.create(hparams)
    print(hparams.pretty())
    log = WandbLogger(name=hparams.name, project=hparams.project)
    checkpoint = pl.callbacks.ModelCheckpoint(
        filepath="checkpoints/", verbose=True, monitor="val_loss", mode="min"
    )
    trainer = pl.Trainer(
        logger=log, checkpoint_callback=checkpoint, **hparams.trainer_kwargs
    )
    model = Model(hparams)
    trainer.fit(model)


if __name__ == "__main__":
    train(hparams)