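"""
Use torch.cuda.amp in PyTorch for mixed precision training.

Mixed precision combines 16-bit (float16) and 32-bit (float32) floating-point
arithmetic to reduce memory use and speed up model training, usually with
little loss in accuracy. autocast runs operations that benefit from float16,
such as matrix multiplications and convolutions, in 16-bit for speed, while
keeping precision-sensitive operations, such as reductions and loss functions,
in 32-bit. GradScaler scales the loss before the backward pass so that small
float16 gradients do not underflow to zero. This balances training speed
against numerical accuracy.
"""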
import torch
from torch.cuda.amp import autocast, GradScaler
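

# A minimal sketch of a mixed precision training loop, not the original
# author's code. It assumes a hypothetical toy setup (train_amp, a single
# nn.Linear layer, random inputs and targets, plain SGD) purely for
# illustration. autocast wraps the forward pass and loss; GradScaler scales
# the loss for backward(), unscales gradients in step(), and adjusts the
# scale factor in update(). Both are disabled when CUDA is unavailable, so
# the same loop falls back to ordinary float32 training on CPU.
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler


def train_amp(num_steps: int = 5) -> list:
    """Train a hypothetical toy regression model with AMP; return per-step losses."""
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    torch.manual_seed(0)

    # Hypothetical toy model, optimizer, and data.
    model = nn.Linear(16, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    inputs = torch.randn(64, 16, device=device)
    targets = torch.randn(64, 1, device=device)

    # GradScaler is a no-op pass-through when enabled=False (CPU fallback).
    scaler = GradScaler(enabled=use_cuda)
    losses = []
    for _ in range(num_steps):
        optimizer.zero_grad(set_to_none=True)
        # On CUDA, the linear layer's matmul runs in float16, while
        # mse_loss is autocast to float32.
        with autocast(enabled=use_cuda):
            outputs = model(inputs)
            loss = nn.functional.mse_loss(outputs, targets)
        # Backward on the scaled loss; step() unscales the gradients and
        # skips the update if they contain inf/NaN; update() adjusts the scale.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        losses.append(loss.item())
    return losses


if __name__ == "__main__":
    # Usage: run a few steps and print the loss recorded at each step.
    print(train_amp())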