This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Ralamb(Optimizer):
    """Optimizer storing Adam-style hyperparameters plus a small per-instance cache.

    The name suggests a RAdam + LAMB hybrid, but the update rule (``step``) is
    not part of this class body as shown. NOTE(review): confirm where ``step``
    is defined.

    Args:
        params: iterable of tensors or param-group dicts, forwarded unchanged
            to ``Optimizer.__init__``.
        lr: learning rate; must be >= 0.
        betas: pair of coefficients, each in the range [0, 1).
        eps: numerical-stability term; must be >= 0.
        weight_decay: weight-decay coefficient; must be >= 0.

    Raises:
        ValueError: if any hyperparameter is outside its valid range.
    """

    # Number of rows in ``self.buffer``. Presumably step() indexes it with
    # something like ``step % 10``; TODO confirm against step().
    _BUFFER_SLOTS = 10

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        # Reject out-of-range hyperparameters up front, as torch.optim.Adam does,
        # rather than silently producing a meaningless update later.
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        self.buffer = self._empty_buffer()
        super().__init__(params, defaults)

    @classmethod
    def _empty_buffer(cls):
        """Return a fresh cache with one independent ``[None, None, None]`` row per slot."""
        return [[None, None, None] for _ in range(cls._BUFFER_SLOTS)]

    def __setstate__(self, state):
        """Restore pickled or deep-copied state, rebuilding the cache if it was dropped.

        Assuming ``Optimizer`` is ``torch.optim.Optimizer``: its ``__getstate__``
        serializes only ``defaults``, ``state`` and ``param_groups``, so
        ``buffer`` is absent after a ``pickle``/``copy.deepcopy`` round-trip.
        Recreate it empty so later ``self.buffer`` accesses don't raise
        AttributeError. This treats the buffer as a recomputable cache
        (TODO confirm against step()).
        """
        super().__setstate__(state)
        if not hasattr(self, "buffer"):
            self.buffer = self._empty_buffer()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# fastai integration of Accelerate
from accelerate import Accelerator | |
from fastai.callback.core import Callback, CancelBatchException, CancelStepException | |
from fastai.learner import Learner, Metric | |
from fastai.metrics import AccumMetric | |
from fastai.optimizer import Optimizer, _update | |
from fastai.distributed import DistributedDL | |
from fastai.torch_core import to_device |