Training script for LongGPT; fine-tunes GPT-2 medium (355M) on The Pile dataset with a context size of 8k tokens (requires > 16 GB RAM).
import time
from contextlib import suppress

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cuda as cuda
from torch.utils.data import DataLoader, IterableDataset

import wandb
from tqdm import tqdm
from datasets import load_dataset
from transformers import GPT2TokenizerFast
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Attention

_attn_orig = GPT2Attention._attn
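# Keep a reference to the stock attention implementation so it can be restored
# for generation (see Trainer.generate_samples) and the flash patch re-applied
# afterwards.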

WANDB_STYLE = """
<style>
    html, body {
        padding: 0;
        margin: 0;
        width: 100%;
        height: 100%;
    }
    p {
        font-family: 'Verdana', sans-serif;
    }
</style>
"""


# patch GPT2Attention to use flash_sdp, disable it when doing the inference
def _attn_wrapper(self, query, key, value, attention_mask=None, head_mask=None):
    if head_mask is not None:
        raise NotImplementedError("head_mask is not implemented for flash_sdp")
    is_causal = attention_mask is None
    with cuda.sdp_kernel(
        enable_flash=True,
        enable_math=False,
        enable_mem_efficient=False,
    ):
        attn_out = F.scaled_dot_product_attention(
            query=query.half(),
            key=key.half(),
            value=value.half(),
            is_causal=is_causal,
            attn_mask=attention_mask,
            dropout_p=self.attn_dropout.p,
        ).float()
    return attn_out, None
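

# Notes on the wrapper above: q/k/v are cast to fp16 so the flash kernel can be
# used (the math and mem-efficient fallbacks are disabled), the output is cast
# back to fp32, and None is returned for the attention weights since flash
# attention never materializes them; is_causal is set only when no explicit
# attention_mask is supplied.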


def closest_power_of_2(x):
    return 2 ** (x - 1).bit_length()
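

# closest_power_of_2 rounds up to the next power of two, e.g.
# closest_power_of_2(33) == 64 and closest_power_of_2(32) == 32; train() uses
# it to grow the gradient-accumulation factor as training progresses.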


def make_model(pretrained_name, max_tokens):
    model = GPT2LMHeadModel.from_pretrained(pretrained_name).cuda()
    GPT2Attention._attn = _attn_wrapper
    model.config.update(
        dict(
            n_ctx=max_tokens,
            n_positions=max_tokens,
        )
    )

    # patch model embeddings
    emb = model.transformer.wpe.weight.data
    wpe = nn.Embedding(max_tokens, emb.shape[1])
    wpe.weight.data = emb.repeat(max_tokens // emb.shape[0], 1)
    model.transformer.wpe = wpe

    # also increase mask size
    for block in model.transformer.h:
        block.attn.bias.data = (
            torch.tril(torch.ones((max_tokens, max_tokens), dtype=torch.bool))
            .view(1, 1, max_tokens, max_tokens)
            .cuda()
        )
    return model
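

# make_model extends a pretrained GPT-2 checkpoint to a longer context: the
# learned 1024-position embedding table is tiled to cover max_tokens positions
# (so max_tokens should be a multiple of 1024), each block's causal-mask buffer
# is rebuilt at the new size, and GPT2Attention._attn is swapped for the
# flash-SDP wrapper above.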


class DatasetWrapper(IterableDataset):
    def __init__(self, max_tokens=2**12):
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        self.max_tokens = max_tokens

    def __iter__(self):
        buffer = []
        for sample in load_dataset(
            "the_pile",
            name="all",
            split="train",
            streaming=True,
        ).shuffle(buffer_size=10_000):
            buffer += self.tokenizer(sample["text"])["input_ids"]
            buffer += [self.tokenizer.eos_token_id]
            while len(buffer) > self.max_tokens:
                yield torch.tensor(buffer[: self.max_tokens])
                buffer = buffer[self.max_tokens :]
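

# DatasetWrapper streams The Pile, tokenizes each document, appends an EOS
# token as a separator, and packs the running buffer into fixed-size chunks of
# max_tokens ids. Note that no worker sharding is done, so each DataLoader
# worker streams and yields the dataset independently.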


class Trainer:
    def __init__(self):
        self.max_tokens = 2**13
        self.grad = 1
        self.step = 0
        self.dataset = DatasetWrapper(self.max_tokens)
        self.tokenizer = self.dataset.tokenizer
        self.loader = DataLoader(
            self.dataset,
            batch_size=1,
            num_workers=8,
        )
        self.scaler = torch.cuda.amp.GradScaler()
        self.model = model = make_model("gpt2-medium", self.max_tokens)
        self.opt = optim.Adam(
            params=model.parameters(),
            lr=5e-6,
            weight_decay=1e-1,
            betas=(0.9, 0.95),
            fused=True,
        )
        self.model = torch.compile(model)
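
    # NOTE: one 8192-token sequence per batch; the effective batch size comes
    # from gradient accumulation (self.grad), which train() grows over time.
    # fused=True selects the fused CUDA Adam kernel, and torch.compile
    # (PyTorch 2.x) wraps the model after the optimizer has been built on the
    # original parameters.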

    def train_step(self, batch):
        batch = batch.cuda()
        with torch.autocast(device_type="cuda", enabled=True):
            loss = self.model(batch, labels=batch).loss
        loss = loss / self.grad
        self.scaler.scale(loss).backward()
        return loss

    def generate_samples(self, n_samples=8):
        GPT2Attention._attn = _attn_orig  # back to faster but more memory consuming
        model = self.model
        x = torch.tensor([[self.tokenizer.eos_token_id]] * n_samples).cuda()
        t0 = time.time()
        model.eval()
        y = model.generate(
            inputs=x,
            max_length=self.max_tokens,
            do_sample=True,
        ).tolist()
        model.train()
        t1 = time.time()
        t = [self.tokenizer.decode(z) for z in y]
        t = "<hr>".join(f"<p>{c}</p>" for c in t)
        html = WANDB_STYLE + t
        wandb.log({"samples": wandb.Html(html)}, step=self.step)
        print(f"Generated in {t1-t0:.3f}s")
        GPT2Attention._attn = _attn_wrapper

    def train(self):
        wandb.init(
            project="long-gptx",
            entity="_",
        )
        prog = tqdm(self.loader)
        self.opt.zero_grad()

        for i, batch in enumerate(prog):
            self.step = i + 1
            loss = self.train_step(batch)
            prog.set_description(f"loss: {loss.item():.3f}")
            wandb.log(
                {
                    "loss": loss.item(),
                    "grad": self.grad,
                },
                step=i,
            )
            if i % self.grad == 0:
                self.scaler.step(self.opt)
                self.scaler.update()
                self.opt.zero_grad()
            # grow the gradient-accumulation factor as training progresses
            self.grad = max(1, closest_power_of_2(i + 1) // 32)
            # if i % 1000 == 0:
            #     with suppress(Exception):
            #         self.model.save_pretrained(
            #             "_",
            #             push_to_hub=True,
            #             max_shard_size="500MB",
            #         )
            if i % 1000 == 0:
                self.generate_samples(16)


if __name__ == "__main__":
    trainer = Trainer()
    trainer.train()
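
# Assumed usage (not spelled out in the gist): run on a single CUDA GPU, e.g.
# `python train_longgpt.py` if the file is saved under that name, after
# pointing wandb.init() in Trainer.train() at your own project/entity (the
# entity above is a placeholder).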