Uncomment L109 to test a different `max_input_len`.
from pytorch_lightning import seed_everything
from transformers import AdamW, T5ForConditionalGeneration, T5Tokenizer, AutoConfig, BartTokenizer, BartForConditionalGeneration
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import os

class MyDataset(Dataset):
    """Tokenize raw input/target strings into fixed-length tensor examples."""

    def __init__(self, tokenizer, raw_inputs, raw_targets, max_input_len=128, max_output_len=128):
        self.max_input_len = max_input_len
        self.max_output_len = max_output_len
        self.tokenizer = tokenizer
        self.inputs = []
        self.targets = []
        self._build_examples(raw_inputs, raw_targets)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        source_ids = self.inputs[index]["input_ids"].squeeze()
        target_ids = self.targets[index]["input_ids"].squeeze()
        src_mask = self.inputs[index]["attention_mask"].squeeze()  # might need to squeeze
        target_mask = self.targets[index]["attention_mask"].squeeze()  # might need to squeeze
        return {"source_ids": source_ids, "source_mask": src_mask,
                "target_ids": target_ids, "target_mask": target_mask}

    def _build_examples(self, raw_inputs, raw_targets):
        for i in range(len(raw_inputs)):
            # each input and target is a plain string
            input = raw_inputs[i]
            target = raw_targets[i]
            # tokenize with padding/truncation to the fixed max lengths, returning PyTorch tensors
            tokenized_input = self.tokenizer(
                [input], max_length=self.max_input_len, padding="max_length",
                truncation=True, return_tensors="pt"
            )
            tokenized_target = self.tokenizer(
                [target], max_length=self.max_output_len, padding="max_length",
                truncation=True, return_tensors="pt"
            )
            self.inputs.append(tokenized_input)
            self.targets.append(tokenized_target)

# toy dataset
raw_inputs = [
    "can't wait wait for my next visit.",
    # "their sake list was extensive, but we were looking for purple haze, which wasn't listed but made for us upon request!",
    # "the spicy tuna roll was unusually good and the rock shrimp tempura was awesome, great appetizer to share!",
    # "we love th pink pony."
]
raw_targets = [
    'restaurant general is great because it is NULL',
    # "drinks style options is great because sake list is extensive [SSEP] service general is great because it is NULL",
    # "food quality is great because spicy tuna roll is good [SSEP] food quality is great because rock shrimp tempura is awesome",
    # "restaurant general is great because pink pony is love"
]

class MyT5FineTuner(pl.LightningModule):
    """
    Fine-tune a pre-trained T5 model
    """
    def __init__(self, tfm_model, tokenizer):
        super(MyT5FineTuner, self).__init__()
        self.model = tfm_model
        self.tokenizer = tokenizer

    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None,
                decoder_attention_mask=None, labels=None):
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

    def _step(self, batch):
        lm_labels = batch["target_ids"]
        # ignore pad tokens when computing the loss
        lm_labels[lm_labels[:, :] == self.tokenizer.pad_token_id] = -100
        outputs = self(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            labels=lm_labels,
        )
        loss = outputs[0]
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._step(batch)
        print(loss)
        return loss

    def configure_optimizers(self):
        """ Prepare the optimizer (no LR schedule in this toy example) """
        optimizer = AdamW(self.model.parameters(), lr=3e-4, eps=1e-8)
        return optimizer

    def train_dataloader(self):
        train_dataset = MyDataset(tokenizer, raw_inputs, raw_targets, max_input_len=300, max_output_len=128)
        # train_dataset = MyDataset(tokenizer, raw_inputs, raw_targets, max_input_len=200, max_output_len=128)
        dataloader = DataLoader(train_dataset, batch_size=1,
                                drop_last=False, shuffle=False, num_workers=1)
        return dataloader

if __name__ == '__main__':
    seed_everything(42)
    tokenizer = T5Tokenizer.from_pretrained('t5-base')
    tfm_model = T5ForConditionalGeneration.from_pretrained('t5-base')
    model = MyT5FineTuner(tfm_model, tokenizer)
    trainer = pl.Trainer()
    trainer.fit(model)
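
    # A minimal sketch, not part of the original gist: after fit(), inspect what the
    # fine-tuned model generates for the toy input. The generation settings below
    # (max_length, num_beams) are illustrative assumptions, not values from the gist.
    batch = next(iter(model.train_dataloader()))
    generated_ids = model.model.generate(
        input_ids=batch["source_ids"],
        attention_mask=batch["source_mask"],
        max_length=128,
        num_beams=1,
    )
    print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))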