Skip to content

Instantly share code, notes, and snippets.

Last active March 30, 2022 15:05
Show Gist options
  • Save SinclairCoder/f50934368364dac7c2a4337ae0b9016c to your computer and use it in GitHub Desktop.
Save SinclairCoder/f50934368364dac7c2a4337ae0b9016c to your computer and use it in GitHub Desktop.
Uncomment lines 77–78 (line numbers of the original gist) to test `max_input_len`.
import os

import pytorch_lightning as pl
import torch
from pytorch_lightning import seed_everything
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from transformers import AdamW, T5ForConditionalGeneration, T5Tokenizer, AutoConfig, BartTokenizer, BartForConditionalGeneration
class MyDataset(Dataset):
    """Map-style dataset of pre-tokenized (input text, target text) pairs.

    All examples are tokenized once, in ``__init__``, with
    ``padding="max_length"`` and ``truncation=True``. Every source sequence
    therefore has length ``max_input_len`` and every target sequence has
    length ``max_output_len``. Equal lengths let a ``DataLoader`` stack the
    examples into batches.

    Args:
        tokenizer: Tokenizer callable. It is invoked as
            ``tokenizer([text], max_length=..., padding="max_length",
            truncation=True, return_tensors="pt")`` and must return a mapping
            with ``"input_ids"`` and ``"attention_mask"`` tensors.
        raw_inputs: Source strings.
        raw_targets: Target strings, aligned one-to-one with ``raw_inputs``.
        max_input_len: Length that source sequences are padded/truncated to.
        max_output_len: Length that target sequences are padded/truncated to.

    Raises:
        ValueError: If ``raw_inputs`` and ``raw_targets`` differ in length.
    """

    def __init__(self, tokenizer, raw_inputs, raw_targets, max_input_len=128, max_output_len=128):
        self.max_input_len = max_input_len
        self.max_output_len = max_output_len
        self.tokenizer = tokenizer
        self.inputs = []   # one tokenizer output per source string
        self.targets = []  # one tokenizer output per target string
        self._build_examples(raw_inputs, raw_targets)

    def __len__(self):
        """Return the number of (input, target) pairs."""
        return len(self.inputs)

    def __getitem__(self, index):
        """Return the 1-D id/mask tensors for example ``index``.

        Returns:
            dict with keys ``"source_ids"``, ``"source_mask"``,
            ``"target_ids"`` and ``"target_mask"``.
        """
        source = self.inputs[index]
        target = self.targets[index]
        # Each text was tokenized as a one-element list, so the stored tensors
        # are expected to carry a leading batch dim of size 1. squeeze(0)
        # removes only that dim, so a sequence of length 1 stays 1-D instead
        # of collapsing to a 0-d tensor as a bare squeeze() would.
        return {
            "source_ids": source["input_ids"].squeeze(0),
            "source_mask": source["attention_mask"].squeeze(0),
            "target_ids": target["input_ids"].squeeze(0),
            "target_mask": target["attention_mask"].squeeze(0),
        }

    def _build_examples(self, raw_inputs, raw_targets):
        """Tokenize every pair and append the results to inputs/targets."""
        if len(raw_inputs) != len(raw_targets):
            raise ValueError(
                f"raw_inputs and raw_targets must have the same length "
                f"({len(raw_inputs)} != {len(raw_targets)})"
            )
        for source_text, target_text in zip(raw_inputs, raw_targets):
            tokenized_input = self.tokenizer(
                [source_text], max_length=self.max_input_len, padding="max_length",
                truncation=True, return_tensors="pt",
            )
            tokenized_target = self.tokenizer(
                [target_text], max_length=self.max_output_len, padding="max_length",
                truncation=True, return_tensors="pt",
            )
            self.inputs.append(tokenized_input)
            self.targets.append(tokenized_target)
# Toy dataset: parallel lists of review sentences (raw_inputs) and their
# target strings (raw_targets). Only the first pair is active; to use more,
# uncomment the matching entries in BOTH lists so they stay aligned.
raw_inputs = ["can't wait wait for my next visit.",
              # "their sake list was extensive, but we were looking for purple haze, which wasn't listed but made for us upon request!", #
              # "the spicy tuna roll was unusually good and the rock shrimp tempura was awesome, great appetizer to share!",
              # "we love th pink pony."
              ]
raw_targets = ['restaurant general is great because it is NULL',
               # "drinks style options is great because sake list is extensive [SSEP] service general is great because it is NULL",
               # "food quality is great because spicy tuna roll is good [SSEP] food quality is great because rock shrimp tempura is awesome",
               # "restaurant general is great because pink pony is love"
               ]
if __name__ == '__main__':
    # Smoke test: build the toy dataset with a given max_input_len and run it
    # through T5 forward passes that compute the loss.

    # Use a single device for both the model and the batches. The batches were
    # previously moved with .cuda() while the model stayed on CPU, which fails
    # with a device mismatch. This also allows running without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = T5Tokenizer.from_pretrained('t5-base')
    model = T5ForConditionalGeneration.from_pretrained('t5-base')
    model.to(device)
    optimizer = AdamW(model.parameters(), lr=3e-4, eps=1e-8)

    # max_input_len is the length every input text is padded/truncated to.
    train_dataset = MyDataset(tokenizer, raw_inputs, raw_targets, max_input_len=300, max_output_len=128)
    # Alternative setting: uncomment to compare a different max_input_len.
    # train_dataset = MyDataset(tokenizer, raw_inputs, raw_targets, max_input_len=200, max_output_len=128)
    dataloader = DataLoader(train_dataset, batch_size=1, drop_last=False, shuffle=False, num_workers=4)

    for idx, batch in enumerate(dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}

        # Clone so that masking the padding does not overwrite
        # batch["target_ids"] in place.
        lm_labels = batch["target_ids"].clone()
        # Pad positions are set to -100; presumably the Hugging Face ignore
        # index, which excludes them from the loss (see the T5 docs).
        lm_labels[lm_labels == tokenizer.pad_token_id] = -100

        outputs = model(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            labels=lm_labels,
            # decoder_attention_mask=batch['target_mask'],
        )
        # With `labels` supplied, the first element of the output is the loss.
        loss = outputs[0]
        print(f'{idx}: {loss.item():.4f}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment