@philtomson
Created February 2, 2024 02:38
PYTORCH TIP: mixed precision training
"""
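Use torch.cuda.amp (automatic mixed precision) in PyTorch to train with a mix of 32-bit and 16-bit floats.
Inside an autocast() region, eligible ops such as matrix multiplies and linear layers run in float16,
which uses less memory and is faster on GPUs with Tensor Cores, while precision-sensitive ops such as
softmax, sums, and loss functions like cross_entropy run in float32. The model's weights stay in float32.
GradScaler multiplies the loss by a scale factor before backward() so that small float16 gradients do not
underflow to zero, then unscales the gradients before the optimizer step.
The result is faster, lower-memory training, usually with little loss in accuracy.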
"""
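import torch
# In PyTorch 2.4+ these are deprecated aliases for torch.amp.autocast("cuda") and
# torch.amp.GradScaler("cuda"); they still work but emit a FutureWarning.
from torch.cuda.amp import autocast, GradScaler
from torch import nn, optim

# A tiny linear classifier on the GPU; its parameters stay in float32.
model = nn.Linear(10, 2).cuda()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# GradScaler scales the loss so small float16 gradients don't underflow to zero.
scaler = GradScaler()

# A random batch: 5 samples with 10 features each, and class labels in {0, 1}.
data, target = torch.randn(5, 10).cuda(), torch.randint(0, 2, (5,)).cuda()

optimizer.zero_grad()
# Under autocast the linear layer runs in float16; cross_entropy is autocast to float32.
with autocast():
    loss = nn.functional.cross_entropy(model(data), target)
# Backward on the scaled loss, outside the autocast region.
scaler.scale(loss).backward()
# Unscale the gradients, then step (the step is skipped if they contain inf/NaN).
scaler.step(optimizer)
# Adjust the scale factor for the next iteration.
scaler.update()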
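The claim that only some operations drop to 16-bit while the weights stay in 32-bit can be checked directly. A minimal sketch, assuming a CUDA GPU when one is available and otherwise falling back to CPU autocast with bfloat16 (an assumption not in the original tip, which targets CUDA only): the linear layer's output comes out in the lower-precision type, while the layer's own weights remain float32.

```python
import torch
from torch import nn

# Assumption: use the GPU with float16 if available; otherwise fall back to CPU
# with bfloat16, the lower-precision type CPU autocast works with.
device = "cuda" if torch.cuda.is_available() else "cpu"
low_dtype = torch.float16 if device == "cuda" else torch.bfloat16

model = nn.Linear(10, 2).to(device)
data = torch.randn(5, 10, device=device)

with torch.autocast(device_type=device, dtype=low_dtype):
    out = model(data)  # the linear layer is an autocast-eligible op

# The activation is produced in reduced precision...
print("output dtype:", out.dtype)
# ...while the parameters (and the input) are still stored in float32.
print("weight dtype:", model.weight.dtype)
print("input dtype: ", data.dtype)
```

Nothing is converted permanently: autocast casts inputs to eligible ops on the fly, which is why no model.half() call is needed.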
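The tip takes a single optimizer step. A sketch of a full loop built on the same pattern follows, with hypothetical synthetic data, a learning rate of 0.1, and gradient clipping added to show where scaler.unscale_() fits: gradients must be unscaled before clipping, or the max_norm threshold would apply to the scaled values. It also assumes a CPU fallback that uses bfloat16 and disables the scaler, since bfloat16 has the same exponent range as float32 and loss scaling is generally not needed for it.

```python
import torch
from torch import nn, optim

# Assumption: fall back to CPU (bfloat16 autocast, no loss scaling) without a GPU.
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
low_dtype = torch.float16 if use_cuda else torch.bfloat16

torch.manual_seed(0)
model = nn.Linear(10, 2).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.1)
# With enabled=False every GradScaler method becomes a pass-through.
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

# Hypothetical fixed synthetic batch, reused on every step.
data = torch.randn(64, 10, device=device)
target = torch.randint(0, 2, (64,), device=device)

losses = []
for step in range(100):
    optimizer.zero_grad()
    with torch.autocast(device_type=device, dtype=low_dtype):
        loss = nn.functional.cross_entropy(model(data), target)
    scaler.scale(loss).backward()
    # Unscale in place so clipping sees the true gradient magnitudes.
    scaler.unscale_(optimizer)
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)  # skipped automatically if gradients contain inf/NaN
    scaler.update()         # grows or shrinks the scale factor
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

With enabled=False, scale() returns the loss unchanged and step() simply calls optimizer.step(), so the same loop body runs on both devices.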