A very small and self-contained gist to train a GPT-2 transformer model on wikitext-103
# Copyright (c) 2019-present, Thomas Wolf.
# All rights reserved. This source code is licensed under the MIT-style license.
""" A very small and self-contained gist to train a GPT-2 transformer model on wikitext-103 """
import os
from collections import namedtuple
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage
from ignite.handlers import ModelCheckpoint
from ignite.contrib.handlers import CosineAnnealingScheduler, create_lr_scheduler_with_warmup, ProgressBar
from pytorch_pretrained_bert import BertTokenizer, cached_path
class Transformer(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_embeddings, num_max_positions, num_heads, num_layers, dropout):
        """ Transformer (GPT-2 architecture) """
        super().__init__()
        self.tokens_embeddings = nn.Embedding(num_embeddings, embed_dim)
        self.position_embeddings = nn.Embedding(num_max_positions, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.attentions, self.feed_forwards = nn.ModuleList(), nn.ModuleList()
        self.layer_norms_1, self.layer_norms_2 = nn.ModuleList(), nn.ModuleList()
        for _ in range(num_layers):
            self.attentions.append(nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout))
            self.feed_forwards.append(nn.Sequential(nn.Linear(embed_dim, hidden_dim),
                                                    nn.ReLU(),
                                                    nn.Linear(hidden_dim, embed_dim)))
            self.layer_norms_1.append(nn.LayerNorm(embed_dim, eps=1e-12))
            self.layer_norms_2.append(nn.LayerNorm(embed_dim, eps=1e-12))
    def forward(self, x):
        """ Input x has shape [seq length, batch] """
        positions = torch.arange(len(x), device=x.device).unsqueeze(-1)
        h = self.tokens_embeddings(x)
        h = h + self.position_embeddings(positions).expand_as(h)
        h = self.dropout(h)
        # Causal mask: upper-triangular -inf entries prevent each position from attending to later positions
        attn_mask = torch.full((len(x), len(x)), -float('Inf'), device=h.device, dtype=h.dtype)
        attn_mask = torch.triu(attn_mask, diagonal=1)
        # Each block: layer-norm, masked self-attention, residual, then layer-norm, feed-forward, residual
        for layer_norm_1, attention, layer_norm_2, feed_forward in zip(self.layer_norms_1, self.attentions,
                                                                       self.layer_norms_2, self.feed_forwards):
            h = layer_norm_1(h)
            x, _ = attention(h, h, h, attn_mask=attn_mask, need_weights=False)
            x = self.dropout(x)
            h = x + h
            h = layer_norm_2(h)
            x = feed_forward(h)
            x = self.dropout(x)
            h = x + h
        return h
class TransformerWithLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = Transformer(config.embed_dim, config.hidden_dim, config.num_embeddings,
                                       config.num_max_positions, config.num_heads, config.num_layers,
                                       config.dropout)
        self.lm_head = nn.Linear(config.embed_dim, config.num_embeddings, bias=False)
        self.lm_head.weight = self.transformer.tokens_embeddings.weight  # Tie weights
        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding, nn.LayerNorm)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, (nn.Linear, nn.LayerNorm)) and module.bias is not None:
            module.bias.data.zero_()
    def forward(self, x, labels=None):
        hidden_states = self.transformer(x)
        logits = self.lm_head(hidden_states)
        if labels is not None:
            # Shift so that the logits at position i are trained to predict the token at position i+1
            shift_logits = logits[:-1]
            shift_labels = labels[1:]
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            return logits, loss
        return logits
Config = namedtuple('Config',
                    field_names="embed_dim, hidden_dim, num_max_positions, num_embeddings, num_heads, num_layers,"
                                "dropout, initializer_range, batch_size, lr, max_norm, n_epochs, n_warmup, device,"
                                "gradient_accumulation_steps, log_dir, dataset_cache",
                    defaults=[410, 2100, 256, 50000, 10, 16,
                              0.1, 0.02, 16, 2.5e-4, 0.25, 200, 1000, "cuda",
                              4, "./", "./dataset_cache_small_gist"])
# Load a pre-defined tokenizer (BERT), create config and model
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)
args = Config(num_embeddings=len(tokenizer.vocab), device="cuda" if torch.cuda.is_available() else "cpu")
model = TransformerWithLMHead(args).to(args.device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# Download and tokenize wikitext-103 training dataset
if os.path.isfile(args.dataset_cache):
    dataset = torch.load(args.dataset_cache)
else:
    dataset_file = cached_path("https://s3.amazonaws.com/datasets.huggingface.co/wikitext-103/wiki.train.tokens")
    with open(dataset_file, "r", encoding="utf-8") as f:
        dataset = f.readlines()
    dataset = list(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(
        line.strip(' ').replace('\n', '[SEP]').replace('<unk>', '[UNK]'))) for line in tqdm(dataset))
    dataset = torch.tensor([index for line in dataset for index in line], dtype=torch.long)
    torch.save(dataset, args.dataset_cache)
# Organize the dataset in blocks of num_max_positions tokens for the transformer
num_sequences = (dataset.size(0) // args.num_max_positions) * args.num_max_positions
dataset = dataset.narrow(0, 0, num_sequences).view(-1, args.num_max_positions)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
# Define training function
def update(engine, batch):
    model.train()
    batch = batch.transpose(0, 1).contiguous().to(args.device)  # to shape [seq length, batch]
    logits, loss = model(batch, labels=batch)
    loss = loss / args.gradient_accumulation_steps  # scale the loss since gradients are accumulated
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
    if engine.state.iteration % args.gradient_accumulation_steps == 0:
        # Only take an optimizer step every gradient_accumulation_steps iterations
        optimizer.step()
        optimizer.zero_grad()
    return loss.item()
trainer = Engine(update)
# Add progressbar with loss
RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
ProgressBar(persist=True).attach(trainer, metric_names=['loss'])
# Learning rate schedule: linearly warm up to lr, then decay the learning rate to zero following a cosine curve
cos_scheduler = CosineAnnealingScheduler(optimizer, 'lr', args.lr, 0.0, len(dataloader) * args.n_epochs)
scheduler = create_lr_scheduler_with_warmup(cos_scheduler, 0.0, args.lr, args.n_warmup)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
# Save checkpoints and training config
checkpoint_handler = ModelCheckpoint(args.log_dir, 'checkpoint', save_interval=1, n_saved=5)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': model})
torch.save(args, os.path.join(args.log_dir, 'training_args.bin'))
trainer.run(dataloader, max_epochs=args.n_epochs)
This gist should give a word-level perplexity of about 29 on the wikitext-103 validation set when training for 15 hours on 8 V100 GPUs (a few days on a single GPU). To get word-level perplexity, you need to convert the sub-word NLL into a word-level perplexity (see here for details on the conversion process).
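For reference, here is a minimal sketch of that conversion (not part of the gist itself): the total negative log-likelihood of the corpus is the same whatever the tokenization, so you can sum the sub-word NLL over the validation set and exponentiate it per word rather than per sub-word token. The names val_dataset (a [num sequences, num_max_positions] tensor of sub-word ids built exactly like the training set above) and num_words (the number of whitespace-separated words in the raw validation file) are placeholders for illustration.

# Sketch only: val_dataset and num_words are assumed to be built beforehand as described above
import math
model.eval()
total_nll = 0.0
with torch.no_grad():
    for batch in DataLoader(val_dataset, batch_size=args.batch_size):
        batch = batch.transpose(0, 1).contiguous().to(args.device)  # to shape [seq length, batch]
        logits, loss = model(batch, labels=batch)
        # `loss` is the mean NLL per sub-word prediction; recover the summed NLL for this batch
        total_nll += loss.item() * (batch.size(0) - 1) * batch.size(1)
word_level_ppl = math.exp(total_nll / num_words)  # exponentiate per word, not per sub-word token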
A few words: