Uncomment L71-L75 to test different PTMs, and L79-L80 to test different `max_input_len` values.
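For quick reference, the two variations mentioned above are sketched here; they simply mirror the commented-out alternatives in the script below (`MyDataset`, `raw_inputs`, and `raw_targets` are the ones defined there):

from transformers import T5Tokenizer, T5ForConditionalGeneration

# Variant 1: test a different PTM, e.g. T5 instead of BART.
tokenizer = T5Tokenizer.from_pretrained('t5-base')
model = T5ForConditionalGeneration.from_pretrained('t5-base')

# Variant 2: build the toy dataset with a different max_input_len (e.g. 200 instead of 300).
train_dataset = MyDataset(tokenizer, raw_inputs, raw_targets, max_input_len=200, max_output_len=128)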
from pytorch_lightning import seed_everything
from transformers import AdamW, T5ForConditionalGeneration, T5Tokenizer, AutoConfig, BartTokenizer, BartForConditionalGeneration
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import os
class MyDataset(Dataset):
    def __init__(self, tokenizer, raw_inputs, raw_targets, max_input_len=128, max_output_len=128):
        self.max_input_len = max_input_len
        self.max_output_len = max_output_len
        self.tokenizer = tokenizer
        self.inputs = []
        self.targets = []
        self._build_examples(raw_inputs, raw_targets)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        source_ids = self.inputs[index]["input_ids"].squeeze()
        target_ids = self.targets[index]["input_ids"].squeeze()
        src_mask = self.inputs[index]["attention_mask"].squeeze()  # might need to squeeze
        target_mask = self.targets[index]["attention_mask"].squeeze()  # might need to squeeze
        return {"source_ids": source_ids, "source_mask": src_mask,
                "target_ids": target_ids, "target_mask": target_mask}

    def _build_examples(self, raw_inputs, raw_targets):
        for i in range(len(raw_inputs)):
            # each raw input and target is a plain string
            input = raw_inputs[i]
            target = raw_targets[i]
            # batch-encode with fixed-length padding and truncation
            tokenized_input = self.tokenizer(
                [input], max_length=self.max_input_len, padding="max_length",
                truncation=True, return_tensors="pt"
            )
            tokenized_target = self.tokenizer(
                [target], max_length=self.max_output_len, padding="max_length",
                truncation=True, return_tensors="pt"
            )
            self.inputs.append(tokenized_input)
            self.targets.append(tokenized_target)
# toy dataset
raw_inputs = ["can't wait wait for my next visit.",
              # "their sake list was extensive, but we were looking for purple haze, which wasn't listed but made for us upon request!",
              # "the spicy tuna roll was unusually good and the rock shrimp tempura was awesome, great appetizer to share!",
              # "we love th pink pony."
              ]
raw_targets = ['restaurant general is great because it is NULL',
               # "drinks style options is great because sake list is extensive [SSEP] service general is great because it is NULL",
               # "food quality is great because spicy tuna roll is good [SSEP] food quality is great because rock shrimp tempura is awesome",
               # "restaurant general is great because pink pony is love"
               ]
if __name__ == '__main__':
    seed_everything(42)

    # tokenizer = T5Tokenizer.from_pretrained('t5-base')
    # model = T5ForConditionalGeneration.from_pretrained('t5-base')
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
    model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
    model.cuda()

    optimizer = AdamW(model.parameters(), lr=3e-4, eps=1e-8)

    # max_input_len is the maximum length of the input text.
    train_dataset = MyDataset(tokenizer, raw_inputs, raw_targets, max_input_len=300, max_output_len=128)
    # train_dataset = MyDataset(tokenizer, raw_inputs, raw_targets, max_input_len=200, max_output_len=128)
    dataloader = DataLoader(train_dataset, batch_size=1, drop_last=False, shuffle=False, num_workers=4)

    for idx, batch in enumerate(dataloader):
        batch = {k: v.cuda() for k, v in batch.items()}

        # replace pad token ids in the labels with -100 so they are ignored by the loss
        lm_labels = batch["target_ids"]
        lm_labels[lm_labels[:, :] == tokenizer.pad_token_id] = -100

        optimizer.zero_grad()
        outputs = model(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            labels=lm_labels,
            # decoder_attention_mask=batch['target_mask']
        )
        # print(f'{idx}: {outputs[0]}')
        # if idx > 1:
        #     break
        loss = outputs[0]
        # print(outputs[1])
        loss.backward()
        optimizer.step()
        print(f'{idx}: {loss}')