Using 16-bit base weights in huggingface transformers
"""
You can save memory by converting your model weights to bf16, but ONLY if you use a special optimiser (one with Kahan summation or stochastic rounding). Otherwise small weight updates are rounded away and you get worse results.
This is how to use a bfloat16 (brain float 16) base model in huggingface transformers.
@author: wassname
@url: https://gist.github.com/wassname/183153f9245b37ae6d08b3c3c4033bda
Usage:
    model = AutoModelForCausalLM.from_pretrained('gpt2')
    convert_to_bfloat16(model)  # or just model.to(torch.bfloat16) will probably work
    training_args = TrainingArguments(bf16=configs.bf16, bf16_full_eval=configs.bf16)
    trainer = TrainerBF16(model, training_args)
    trainer.train()
"""
from typing import Any, Optional, Tuple

import torch
from torch import nn
from optimi import AdamW  # pip install torch-optimi
from transformers import PreTrainedModel, Trainer, TrainingArguments


class TrainerBF16(Trainer):
    """Trainer that swaps in optimi's AdamW with Kahan summation, so pure-bf16 weights train correctly."""

    @staticmethod
    def get_optimizer_cls_and_kwargs(
        args: TrainingArguments, model: Optional[PreTrainedModel] = None
    ) -> Tuple[Any, Any]:
        # start from the kwargs HF would normally use, then keep only the ones optimi.AdamW accepts
        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args, model)
        allowed_kwargs = {'lr', 'betas', 'weight_decay', 'eps', 'decouple_lr', 'max_lr', 'kahan_sum', 'foreach', 'gradient_release'}
        default_kwargs = {'kahan_sum': True}  # Kahan summation keeps the low-order bits bf16 would otherwise drop
        optimizer_kwargs = {k: v for k, v in optimizer_kwargs.items() if k in allowed_kwargs}
        optimizer_kwargs = {**default_kwargs, **optimizer_kwargs}
        return AdamW, optimizer_kwargs


def convert_to_bfloat16(module):
    """
    Convert only Linear and Conv2d weights to bf16. Some modules like layernorm or embeddings
    might need higher precision, although the latest papers question this.
    """
    for child in module.children():
        if isinstance(child, (nn.Linear, nn.Conv2d)):
            child.to(torch.bfloat16)
        else:
            convert_to_bfloat16(child)
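

# Quick sanity check (illustrative sketch, assumes the gpt2 checkpoint is available locally or from the hub):
# after convert_to_bfloat16, nn.Linear / nn.Conv2d weights should show up as torch.bfloat16 while
# layernorms and embeddings stay torch.float32. Some architectures use other module types
# (e.g. GPT-2's Conv1D), which this helper leaves in full precision, so inspect the counts for your model.
if __name__ == "__main__":
    from collections import Counter

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    convert_to_bfloat16(model)
    print(Counter(p.dtype for p in model.parameters()))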
wassname commented Jan 27, 2025

Links

This paper shows that you can train purely in bf16 if you use stochastic rounding or Kahan summation:

Unfortunately, we show empirically that standard pure 16-bit training does not match 32-bit training on model accuracy across deep learning models. For example, the standard pure 16-bit training algorithm one would run on conventional hardware attains 16% and 7% lower training and validation accuracies than a 32-bit baseline.
Motivated by this, we identify two simple existing techniques, stochastic rounding and Kahan summation, to remedy the model accuracy degradation in pure 16-bit training. We empirically show that these two techniques can enable up to 7% absolute validation accuracy gain in pure 16-bit training. This leads to 0.1% lower to 0.2% higher matching validation accuracy compared to 32-bit precision training across seven deep learning applications.
Stochastic arithmetic lets you perform the addition in such a way that the weights have a non-zero probability of being modified anyway. This avoids the stagnation problem.
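
A minimal sketch of why Kahan summation helps (illustrative only; optimi's AdamW does this internally when kahan_sum=True): the optimizer keeps a small compensation buffer per weight that carries the low-order bits bf16 rounding would otherwise throw away, so many tiny updates still add up.

import torch

w_naive = torch.tensor(1.0, dtype=torch.bfloat16)
w_kahan = torch.tensor(1.0, dtype=torch.bfloat16)
comp = torch.tensor(0.0, dtype=torch.bfloat16)  # compensation buffer for the rounded-off bits

# each update is far below bf16's spacing near 1.0 (1/128), so naive addition floors it to zero
update = torch.tensor(1e-3, dtype=torch.bfloat16)
for _ in range(1000):
    w_naive = w_naive + update          # rounds back to 1.0 every step

    y = update - comp                   # re-inject the bits lost on previous steps
    t = w_kahan + y
    comp = (t - w_kahan) - y            # what actually got added, minus what we asked for
    w_kahan = t

print(w_naive)  # tensor(1., dtype=torch.bfloat16) -- stagnated
print(w_kahan)  # close to 2.0 -- the thousand small updates were preserved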

The paper below doesn't manage to keep the model weights fully in bf16, so it doesn't save much GPU memory:

Our results show that deep learning training using BFLOAT16 tensors achieves the same state-of-the-art (SOTA) results across domains as FP32 tensors in the same number of iterations and with no changes to hyper-parameters.
In our experiments all the input tensors (activations, weights) are converted to BFLOAT16 for convolution layers (in both the generator and the discriminator), while only the input activations are converted to BFLOAT16 for batch normalization layers; all other tensors are maintained in full precision.
src: "A Study of BFLOAT16 for Deep Learning Training"

This reddit post has a useful explanation:

But unfortunately your real problem is numerics. BF16 master weights will be nearly impossible to train because the small mantissa (7 bits) means that any weight update smaller than 1/128 of the weight value is lost. In other words anytime (dW / W) < 1/128, the update is floored to zero. This effect is element-wise.
src: https://old.reddit.com/r/MachineLearning/comments/j2r4kn/d_can_you_train_models_purely_in_bf16/
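
You can see this flooring effect directly in torch; a minimal example:

import torch

w = torch.tensor(1.0, dtype=torch.bfloat16)
print(torch.finfo(torch.bfloat16).eps)  # 0.0078125 == 1/128, the spacing between bf16 values in [1, 2)
print(w + 0.001 == w)                   # tensor(True): the update is floored away
print(w + 0.01 == w)                    # tensor(False): a large enough update survives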
