Using 16-bit base weights in huggingface transformers
""" | |
You can save memory by converting your model to bf16, but ONLY if you use a special optimiser. Otherwise you round away small changes and get worse result. | |
This is how to use a brainfloat16 base model in huggingface transformers | |
@author:wassname | |
@url: https://gist.github.com/wassname/183153f9245b37ae6d08b3c3c4033bda | |
Usage: | |
model = AutoModelForCausalLM.from_pretrained('gpt2') | |
convert_to_bfloat16(model) # or just model.to(torch.bfloat16) will probobly work | |
args = training_args = TrainingArguments(bf16=configs.bf16, bf16_full_eval=configs.bf16) | |
trainer = TrainerBF16(model, args) | |
trainer.train() | |
""" | |
from transformers import Trainer | |
from optimi import AdamW | |
from transformers import TrainingArguments, PreTrainedModel | |
from typing import Any, Optional, Tuple | |
from torch import nn | |
import torch | |
class TrainerBF16(Trainer): | |
@staticmethod | |
def get_optimizer_cls_and_kwargs( | |
args: TrainingArguments, model: Optional[PreTrainedModel] = None | |
) -> Tuple[Any, Any]: | |
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args, model) | |
allowed_kwargs = {'lr', 'betas', 'weight_decay', 'eps', 'decouple_lr', 'max_lr', 'kahan_sum', 'foreach', 'gradient_release'} | |
default_kwargs = {'kahan_sum': True} | |
optimizer_kwargs = {k: v for k, v in optimizer_kwargs.items() if k in allowed_kwargs} | |
optimizer_kwargs = {**default_kwargs, **optimizer_kwargs} | |
return AdamW, optimizer_kwargs | |
def convert_to_bfloat16(module): | |
""" | |
Some modules like layernorm or embeddings might need higher precision? Although the latest papers question this | |
""" | |
for child in module.children(): | |
if isinstance(child, (nn.Linear, nn.Conv2d)): | |
child.to(torch.bfloat16) | |
else: | |
convert_to_bfloat16(child) |
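Why the special optimiser matters: here is a quick illustration (not part of the original gist) of how a typical fine-tuning update can be smaller than bf16 can resolve, so a plain bf16 weight is left unchanged by the step.

import torch

# Just below 1.0 the gap between representable bf16 values is about 0.004 (2**-8),
# so an update of 1e-3 is rounded away entirely when the weight is stored in bf16.
w_bf16 = torch.tensor(1.0, dtype=torch.bfloat16)
w_fp32 = torch.tensor(1.0, dtype=torch.float32)

print(w_bf16 - 1e-3 == w_bf16)  # tensor(True): the update is lost
print(w_fp32 - 1e-3 == w_fp32)  # tensor(False): fp32 still sees it

With fp32 weights, or bf16 weights plus Kahan summation as in TrainerBF16 above, these small updates still accumulate.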
Links
This paper shows that you can train in bf16 if you use stochastic rounding or Kahan summation (a rough sketch of the Kahan idea follows below).
The paper below doesn't manage to use full bf16 for the model weights, so it doesn't save much GPU memory.
This reddit post has a useful explanation.
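For intuition, here is a minimal, hypothetical sketch of a Kahan-compensated update on bf16 weights. It is not optimi's actual implementation, just an illustration of the bookkeeping that kahan_sum=True enables: a compensation buffer carries the low-order bits that a bf16 weight update would otherwise lose.

import torch

def kahan_sgd_step(param, grad, compensation, lr=1e-3):
    """Hypothetical Kahan-compensated SGD step on bf16 weights (illustrative only)."""
    # Add back the rounding error carried over from previous steps.
    update = -lr * grad + compensation
    new_param = param + update
    # Store the part of the update that bf16 rounding just discarded.
    compensation.copy_(update - (new_param - param))
    param.copy_(new_param)

w = torch.ones(4, dtype=torch.bfloat16)
comp = torch.zeros_like(w)
grad = torch.ones_like(w)  # each step should subtract lr = 1e-3

for _ in range(100):
    kahan_sgd_step(w, grad, comp)

print(w)  # drifts down towards roughly 0.9 instead of being stuck at 1.0

Without the compensation buffer, a 1e-3 step near a weight of 1.0 is smaller than bf16's spacing (~0.004), so the weight never moves; the buffer carries those lost fractions until they add up to a representable change. This is, roughly, the trick that kahan_sum=True asks optimi's AdamW to perform during its parameter updates.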