EMA and SWA callbacks for different model averaging techniques
from fastai.vision.all import *
from copy import deepcopy

__all__ = ["EMA", "SWA"]
class EMA(Callback):
    "https://fastai.github.io/timmdocs/training_modelEMA"
    order,run_valid = 5,False

    def __init__(self, decay=0.9999):
        super().__init__()
        self.decay = decay
        self.switched = False

    def before_fit(self):
        if not hasattr(self, "ema_model"):
            print("Init EMA model")
            self.ema_model = deepcopy(self.learn.model)
            for param_k in self.ema_model.parameters():
                param_k.requires_grad = False

    @torch.no_grad()
    def _update(self):
        for param_k, param_q in zip(self.ema_model.parameters(), self.learn.model.parameters()):
            param_k.data = param_k.data * self.decay + param_q.data * (1. - self.decay)

    def after_step(self):
        "Momentum update target model"
        self._update()

    def switch_model(self):
        if self.switched:
            self.learn.model = self.original_model
            self.switched = False
            print("Switched to original model")
        else:
            self.original_model = self.learn.model
            self.learn.model = self.ema_model
            self.switched = True
            print("Switched to EMA model")
class SWA(Callback):
    "https://arxiv.org/pdf/1803.05407.pdf (Use with fit_sgdr_*)"
    order,run_valid = 5,False

    def __init__(self, pcts:List[float], swa_start=0):
        """
        pcts: pct_train values at which each annealing cycle ends
        swa_start: at which cycle end to start averaging
        """
        super().__init__()
        self.swa_start = swa_start
        self.pcts = pcts
        self.switched = False
        self.swa_n = 0
        self.curr_pct_idx = 0
        self.curr_pct = self.pcts[self.curr_pct_idx]

    def before_fit(self):
        print("training mode:", self.learn.training)
        if not hasattr(self, "swa_model"):
            print("Init SWA model")
            self.swa_model = deepcopy(self.learn.model)
            for param_k in self.swa_model.parameters():
                param_k.requires_grad = False

    def after_step(self):
        "Update SWA model at given pcts"
        if self.curr_pct_idx < len(self.pcts) and self.pct_train >= self.curr_pct:
            print(f"Updating swa model at pct_train: {self.pct_train}")
            self.update_average_model()
            self.curr_pct_idx += 1
            # guard against indexing past the last pct when it is reached mid-training
            if self.curr_pct_idx < len(self.pcts):
                self.curr_pct = self.pcts[self.curr_pct_idx]

    def after_fit(self):
        "Average final checkpoint if it wasn't"
        if (np.round(self.pct_train, 2) >= self.curr_pct) and (self.curr_pct_idx == len(self.pcts)-1):
            print(f"Updating final swa model at pct_train: {self.pct_train}")
            self.update_average_model()
            self.curr_pct_idx += 1

    def update_average_model(self):
        # update running average of parameters
        for model_param, swa_param in zip(self.model.parameters(), self.swa_model.parameters()):
            swa_param.data = (swa_param.data*self.swa_n + model_param.data) / (self.swa_n + 1)
        self.swa_n += 1

    def switch_model(self):
        if self.switched:
            self.learn.model = self.original_model
            self.switched = False
            print("Switched to original model")
        else:
            self.original_model = self.learn.model
            self.learn.model = self.swa_model
            self.switched = True
            print("Switched to SWA model")