@tallesairan forked this gist from NaxAlpha/pythia_1b4_8k.py (fork created August 11, 2023 14:19).

Revisions

  1. @NaxAlpha NaxAlpha revised this gist Apr 22, 2023. 1 changed file with 1 addition and 15 deletions.
    16 changes: 1 addition & 15 deletions pythia_1b4_8k.py
    @@ -13,20 +13,6 @@
    from transformers import AutoTokenizer, GPTNeoXForCausalLM
    from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXAttention

    - WANDB_STYLE = """
    - <style>
    - html, body {
    -     padding: 0;
    -     margin: 0;
    -     width: 100%;
    -     height: 100%;
    - }
    - p {
    -     font-family: 'Verdana', sans-serif;
    - }
    - </style>
    - """
    -

    def _attn_wrapper(self, query, key, value, attention_mask=None, head_mask=None):
        assert attention_mask is None and head_mask is None, "Not implemented"
    @@ -39,7 +25,7 @@ def _attn_wrapper(self, query, key, value, attention_mask=None, head_mask=None):
            ).float()
        return out, None


    + # patch attention to save a lot of memory
    GPTNeoXAttention._attn = _attn_wrapper

    -

  2. @NaxAlpha NaxAlpha created this gist Apr 22, 2023.
    158 changes: 158 additions & 0 deletions pythia_1b4_8k.py
    @@ -0,0 +1,158 @@
    import copy

    import torch
    import torch.nn.functional as F
    import torch.backends.cuda as cuda
    from torch.utils.data import DataLoader, IterableDataset

    import wandb
    from tqdm import tqdm
    import bitsandbytes as bnb

    from datasets import load_dataset
    from transformers import AutoTokenizer, GPTNeoXForCausalLM
    from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXAttention

    WANDB_STYLE = """
    <style>
    html, body {
        padding: 0;
        margin: 0;
        width: 100%;
        height: 100%;
    }
    p {
        font-family: 'Verdana', sans-serif;
    }
    </style>
    """


    def _attn_wrapper(self, query, key, value, attention_mask=None, head_mask=None):
        assert attention_mask is None and head_mask is None, "Not implemented"
        with cuda.sdp_kernel(enable_math=False):
            out = F.scaled_dot_product_attention(
                query.half(),
                key.half(),
                value.half(),
                is_causal=True,
            ).float()
        return out, None


    GPTNeoXAttention._attn = _attn_wrapper
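    # The assignment above swaps GPT-NeoX's stock attention for the SDPA wrapper, so
    # attention runs through torch.nn.functional.scaled_dot_product_attention's fused
    # (flash / memory-efficient) kernels in fp16, which sharply cuts attention memory
    # at the 8k context used below. The wrapper asserts that no attention_mask or
    # head_mask is passed, so the model must be called without padding masks.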


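    # DatasetWrapper streams The Pile (deduplicated), tokenizes each document, appends
    # the EOS token as a separator, and packs the running token buffer into fixed
    # blocks of max_tokens so every training sample is a full-length sequence.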
    class DatasetWrapper(IterableDataset):
        def __init__(self, max_tokens):
            self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
            self.max_tokens = max_tokens

        def __iter__(self):
            buffer = []
            for sample in load_dataset(
                "EleutherAI/the_pile_deduplicated",
                # "togethercomputer/RedPajama-Data-1T",
                name="all",
                split="train",
                streaming=True,
            ).shuffle(buffer_size=10_000):
                buffer += self.tokenizer(sample["text"])["input_ids"]
                buffer += [self.tokenizer.eos_token_id]
                while len(buffer) > self.max_tokens:
                    yield torch.tensor(buffer[: self.max_tokens])
                    buffer = buffer[self.max_tokens :]


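    # Trainer wires everything together: Pythia-1.4B-deduped, sequence length
    # 2**13 = 8192, batch size 1 with 64-step gradient accumulation, fp16 autocast
    # plus GradScaler, bitsandbytes' 8-bit Lion optimizer, and torch.compile.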
    class Trainer:
        def __init__(self):
            self.max_tokens = 2**13
            self.grad = 64
            self.step = 0

            self.dataset = DatasetWrapper(self.max_tokens)
            self.tokenizer = self.dataset.tokenizer
            self.loader = DataLoader(
                self.dataset,
                batch_size=1,
                num_workers=8,
            )
            self.scaler = torch.cuda.amp.GradScaler()
            self.model = model = GPTNeoXForCausalLM.from_pretrained(
                "EleutherAI/pythia-1.4b-deduped",
            ).cuda()

            self.show_params()

            self.opt = bnb.optim.Lion(
                params=model.parameters(),
                lr=1e-5,
                weight_decay=1e-1,
                betas=(0.9, 0.95),
                optim_bits=8,
                # fused=True,
            )
            self.model = torch.compile(model)

        def show_params(self):
            model = self.model
            params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            emb_params = list(model.gpt_neox.embed_in.parameters())
            emb_params += list(model.embed_out.parameters())
            emb_params = sum(p.numel() for p in emb_params if p.requires_grad)
            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            print("Params:", params - emb_params)
            print("Params (incl. embeddings):", params)
            print("Trainable params:", trainable_params)

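        # Each batch is a single packed sequence of 8192 tokens (8191 after the
        # next-token shift); the loss is scaled by 1/self.grad so that accumulating
        # 64 micro-batches averages over roughly 64 * 8k ≈ 524k tokens per optimizer step.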
        def train_step(self, batch):
            batch = batch.cuda()
            x, y = batch[:, :-1], batch[:, 1:]
            with torch.autocast(device_type="cuda", enabled=True):
                z = self.model(x).logits
                y = y.reshape(-1)
                z = z.view(-1, z.shape[-1])
                loss = F.cross_entropy(z, y)
            self.scaler.scale(loss / self.grad).backward()
            return loss

        def train(self):
            wandb.init(
                project="pythia",
                entity="<your-wandb-id>",
            )

            prog = tqdm(self.loader)
            self.opt.zero_grad()

            for i, batch in enumerate(prog):
                self.step = i + 1

                loss = self.train_step(batch)
                prog.set_description(f"loss: {loss.item():.3f}")
                wandb.log(
                    {
                        "loss": loss.item(),
                    },
                    step=i,
                )

                if (i + 1) % self.grad == 0:
                    self.scaler.step(self.opt)
                    self.scaler.update()
                    self.opt.zero_grad()

                if i % 1000 == 0:
                    temp_model = copy.deepcopy(self.model).half()
                    temp_model.save_pretrained(
                        "<your-hf-repo-id>",
                        push_to_hub=True,
                        max_shard_size="500MB",
                    )
                    del temp_model
                    torch.cuda.empty_cache()


    if __name__ == "__main__":
        trainer = Trainer()
        trainer.train()