PyTorch script for fine-tuning the Pegasus Large model
"""Script for fine-tuning Pegasus | |
Example usage: | |
# use XSum dataset as example, with first 1000 docs as training data | |
from datasets import load_dataset | |
dataset = load_dataset("xsum") | |
train_texts, train_labels = dataset['train']['document'][:1000], dataset['train']['summary'][:1000] | |
# use Pegasus Large model as base for fine-tuning | |
model_name = 'google/pegasus-large' | |
train_dataset, _, _ = prepare_data(model_name, train_texts, train_labels) | |
trainer = prepare_fine_tuning(model_name, train_dataset) | |
trainer.train() | |
Reference: | |
https://huggingface.co/transformers/master/custom_datasets.html | |
""" | |
from transformers import PegasusForConditionalGeneration, PegasusTokenizer, Trainer, TrainingArguments
import torch


class PegasusDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings  # tokenized source documents
        self.labels = labels        # tokenized target summaries

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels['input_ids'][idx])  # target summary token ids
        return item

    def __len__(self):
        return len(self.labels['input_ids'])  # number of examples, not number of keys in the encoding
def prepare_data(model_name,
                 train_texts, train_labels,
                 val_texts=None, val_labels=None,
                 test_texts=None, test_labels=None):
    """
    Prepare input data for model fine-tuning
    """
    tokenizer = PegasusTokenizer.from_pretrained(model_name)

    prepare_val = val_texts is not None and val_labels is not None
    prepare_test = test_texts is not None and test_labels is not None

    def tokenize_data(texts, labels):
        encodings = tokenizer(texts, truncation=True, padding=True)
        decodings = tokenizer(labels, truncation=True, padding=True)
        dataset_tokenized = PegasusDataset(encodings, decodings)
        return dataset_tokenized

    train_dataset = tokenize_data(train_texts, train_labels)
    val_dataset = tokenize_data(val_texts, val_labels) if prepare_val else None
    test_dataset = tokenize_data(test_texts, test_labels) if prepare_test else None

    return train_dataset, val_dataset, test_dataset
def prepare_fine_tuning(model_name, train_dataset, val_dataset=None, freeze_encoder=False, output_dir='./results'):
    """
    Prepare configurations and base model for fine-tuning
    """
    torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)

    if freeze_encoder:
        for param in model.model.encoder.parameters():
            param.requires_grad = False

    if val_dataset is not None:
        training_args = TrainingArguments(
            output_dir=output_dir,          # output directory
            num_train_epochs=2000,          # total number of training epochs
            per_device_train_batch_size=1,  # batch size per device during training, can increase if memory allows
            per_device_eval_batch_size=1,   # batch size for evaluation, can increase if memory allows
            save_steps=500,                 # number of update steps between checkpoint saves
            save_total_limit=5,             # limit the total number of checkpoints; older checkpoints are deleted
            evaluation_strategy='steps',    # evaluation strategy to adopt during training
            eval_steps=100,                 # number of update steps between evaluations
            warmup_steps=500,               # number of warmup steps for the learning rate scheduler
            weight_decay=0.01,              # strength of weight decay
            logging_dir='./logs',           # directory for storing logs
            logging_steps=10,
        )

        trainer = Trainer(
            model=model,                    # the instantiated 🤗 Transformers model to be trained
            args=training_args,             # training arguments, defined above
            train_dataset=train_dataset,    # training dataset
            eval_dataset=val_dataset,       # evaluation dataset
        )

    else:
        training_args = TrainingArguments(
            output_dir=output_dir,          # output directory
            num_train_epochs=2000,          # total number of training epochs
            per_device_train_batch_size=1,  # batch size per device during training, can increase if memory allows
            save_steps=500,                 # number of update steps between checkpoint saves
            save_total_limit=5,             # limit the total number of checkpoints; older checkpoints are deleted
            warmup_steps=500,               # number of warmup steps for the learning rate scheduler
            weight_decay=0.01,              # strength of weight decay
            logging_dir='./logs',           # directory for storing logs
            logging_steps=10,
        )

        trainer = Trainer(
            model=model,                    # the instantiated 🤗 Transformers model to be trained
            args=training_args,             # training arguments, defined above
            train_dataset=train_dataset,    # training dataset
        )

    return trainer
if __name__ == '__main__':
    # use XSum dataset as example, with first 1000 docs as training data
    from datasets import load_dataset
    dataset = load_dataset("xsum")
    train_texts, train_labels = dataset['train']['document'][:1000], dataset['train']['summary'][:1000]

    # use Pegasus Large model as base for fine-tuning
    model_name = 'google/pegasus-large'
    train_dataset, _, _ = prepare_data(model_name, train_texts, train_labels)
    trainer = prepare_fine_tuning(model_name, train_dataset)
    trainer.train()
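
    # Optional follow-up, not part of the original gist: a minimal sketch of saving the
    # fine-tuned model and summarizing a new document with it. The save directory and the
    # choice of sample document below are illustrative placeholders.
    save_dir = './pegasus-xsum-finetuned'
    trainer.save_model(save_dir)

    tokenizer = PegasusTokenizer.from_pretrained(model_name)
    tokenizer.save_pretrained(save_dir)  # save the tokenizer alongside the model weights

    finetuned = PegasusForConditionalGeneration.from_pretrained(save_dir).to(trainer.args.device)
    sample = dataset['train']['document'][1000]  # a document outside the 1000 training examples
    batch = tokenizer(sample, truncation=True, padding='longest', return_tensors='pt').to(trainer.args.device)
    summary_ids = finetuned.generate(**batch)
    print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])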