PYTORCH TIP: mixed precision training
""" | |
Use torch.cuda.amp in PyTorch for mixed precision training. | |
This method mixes 32-bit and 16-bit data to reduce memory use and speed up model training, | |
without much loss in accuracy. | |
It takes advantage of the quick computing of 16-bit data and controls precision by handling specific operations in 32-bit. | |
This approach offers a balance between speed and accuracy in training models. | |
""" | |
import torch
from torch import nn, optim
from torch.cuda.amp import autocast, GradScaler

model = nn.Linear(10, 2).cuda()
optimizer = optim.SGD(model.parameters(), lr=0.01)
scaler = GradScaler()

data = torch.randn(5, 10).cuda()
target = torch.randint(0, 2, (5,)).cuda()

optimizer.zero_grad()
# Run the forward pass and loss computation in mixed precision.
with autocast():
    loss = nn.functional.cross_entropy(model(data), target)
# Scale the loss before backward to prevent float16 gradient underflow,
# then step the optimizer and update the scale factor.
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
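
For context, here is a minimal sketch of the same autocast/GradScaler pattern inside a full training loop. The toy TensorDataset, batch size, and epoch count are hypothetical placeholders, not part of the original tip; the GradScaler calls follow the standard per-iteration order (scale, backward, step, update).

import torch
from torch import nn, optim
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy dataset standing in for real training data.
dataset = TensorDataset(torch.randn(100, 10), torch.randint(0, 2, (100,)))
loader = DataLoader(dataset, batch_size=16, shuffle=True)

model = nn.Linear(10, 2).cuda()
optimizer = optim.SGD(model.parameters(), lr=0.01)
scaler = GradScaler()

for epoch in range(3):
    for data, target in loader:
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        # Forward pass and loss run with selected ops in float16.
        with autocast():
            loss = nn.functional.cross_entropy(model(data), target)
        # Scale the loss so small float16 gradients do not underflow,
        # then unscale/step and adjust the scale factor each iteration.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()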