import gc

from fastai.vision.all import *
# AMP utilities (GradScaler is not used below: SAM's two-step update runs plain, unscaled backward passes)
from torch.cuda.amp import autocast, GradScaler
from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
from sam import SAM
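
# Note: `SAM` comes from a local sam.py providing the sharpness-aware minimization optimizer.
# The `first_step`/`second_step` API used below matches the widely used
# https://github.com/davda54/sam implementation (an assumption; adjust if your sam.py differs).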

class FastaiSched:
    "Linear LR warmup over the first 10% of training, then cosine decay back towards zero."
    def __init__(self, optimizer, max_lr):
        self.optimizer = optimizer
        self.lr_sched = combine_scheds([0.1, 0.9], [SchedLin(1e-8, max_lr), SchedCos(max_lr, 1e-8)])
        self.update(0)

    def update(self, pos):
        "Set the lr of every param group; `pos` is the fraction of training completed."
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = self.lr_sched(pos)
            # print("lr set to:", param_group["lr"])

class ProgressTracker:
    "Track the global iteration count as a fraction of all training iterations."
    def __init__(self, dls, total_epochs):
        self.iter = 0
        self.epoch = 0
        self.total_iter = len(dls.train) * total_epochs

    @property
    def train_pct(self): return self.iter / self.total_iter

from torch.distributions import Beta

def mixup_batch(xb, yb, lam):
    "Mix each sample in `xb` with a randomly shuffled partner, weighted by the per-sample `lam`."
    shuffle = torch.randperm(yb.size(0)).to(xb.device)
    xb1, yb1 = xb[shuffle], yb[shuffle]
    nx_dims = len(xb.size())
    # lerp(start=shuffled, end=original, weight=lam): lam weights the original image
    xb = torch.lerp(xb1, xb, weight=unsqueeze(lam, n=nx_dims-1))
    return xb, yb, yb1
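
# Quick CPU sanity check for mixup_batch (illustrative only): the mixed batch keeps the input
# shape, and `yb1` holds the labels of each sample's mixing partner.
# _x, _y = torch.randn(4, 3, 8, 8), torch.tensor([0, 1, 2, 3])
# _lam = Beta(tensor(0.4), tensor(0.4)).sample((4,)).squeeze()
# _xm, _y0, _y1 = mixup_batch(_x, _y, _lam)
# assert _xm.shape == _x.shape and _y1.shape == _y0.shape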

def epoch_train(model, dls, loss_fn, optimizer, scheduler, progress_tracker,
                grad_clip_max_norm=None, mixup=False, mixup_alpha=0.4):
    "One training epoch with SAM: two forward-backward passes per batch."
    model.train()
    losses = []
    if mixup: beta_distrib = Beta(tensor(mixup_alpha), tensor(mixup_alpha))
    for xb, yb in progress_bar(dls.train):
        if mixup:
            lam = beta_distrib.sample((yb.size(0),)).squeeze().to(xb.device)
            xb, yb, yb1 = mixup_batch(xb, yb, lam)

        # first forward-backward pass: gradients at the current weights, used by SAM
        # to climb to the nearby sharpness-maximizing point
        with torch.cuda.amp.autocast():
            out = model(xb)
            if mixup:
                with NoneReduce(loss_fn) as lf:
                    loss = torch.lerp(lf(out, yb1), lf(out, yb), lam).mean()
            else:
                loss = loss_fn(out, yb)
        loss.backward()
        if grad_clip_max_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_max_norm)
        optimizer.first_step(zero_grad=True)

        # second forward-backward pass: gradients at the perturbed weights, used for the
        # actual parameter update
        with torch.cuda.amp.autocast():
            out = model(xb)
            if mixup:
                with NoneReduce(loss_fn) as lf:
                    loss = torch.lerp(lf(out, yb1), lf(out, yb), lam).mean()
            else:
                loss = loss_fn(out, yb)
        loss.backward()
        if grad_clip_max_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_max_norm)
        optimizer.second_step(zero_grad=True)

        losses.append(to_detach(loss))
        # update the lr from the overall training progress
        progress_tracker.iter += 1
        scheduler.update(progress_tracker.train_pct)
    return losses

def epoch_validate(model, loss_fn, dls):
    "Run inference on the validation set and return (loss, mean average precision)."
    model.eval()
    preds, targs = [], []
    with torch.no_grad():
        for xb, yb in progress_bar(dls.valid):
            pred = model(xb).cpu()
            preds += [pred]
            targs += [yb.cpu()]
    preds, targs = torch.cat(preds), torch.cat(targs)
    loss = loss_fn(preds, targs)
    score = sklearn_mean_ap(preds.softmax(dim=1), targs)
    return loss, score
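
# `sklearn_mean_ap` is defined elsewhere in the original notebook. A minimal sketch of a
# compatible implementation (an assumption, not the author's code) could be:
# from sklearn.metrics import average_precision_score
# def sklearn_mean_ap(probs, targs):
#     "Macro-averaged average precision over classes, with integer targets one-hot encoded."
#     onehot = torch.zeros_like(probs).scatter_(1, targs.view(-1, 1), 1.)
#     return average_precision_score(onehot.numpy(), probs.numpy(), average='macro')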

def train_sam(arch, lr, epochs=15, bs=32, size=512, fastai_augs=False, cropped=False, re=False, easy=False,
              mixup=False, save=False, debug=False, folds_range=None, ckpt=None, clip_max_norm=3.):
    "Cross-validated training loop: SAM optimizer, optional mixup, early stopping on mean AP."
    if folds_range is None:
        if debug: folds_range = range(0, 1)
        else:     folds_range = range(5)
    if debug:
        # overwrite during debug
        save = False
        epochs = 1

    for fold_idx in folds_range:
        print(f"training fold {fold_idx}")
        # dls
        dls = get_dls(fold_idx, train_df, study_ids, study_label_ids, bs=bs, size=size,
                      fastai_augs=fastai_augs, cropped=cropped, re=re, easy=easy, debug=debug)
        # model
        model = get_classification_model(arch)
        model.to(default_device())

        # run name, used for checkpoints and the training log
        fname = f"{arch}"
        fname += f"-sz{size}"
        if re: fname += "-re"
        if mixup: fname += "-mixup"
        if cropped: fname += "-cropped"
        fname += f"-SAM-fold{fold_idx}"

        base_optimizer = torch.optim.SGD
        optimizer = SAM(model.parameters(), base_optimizer, rho=0.05, adaptive=True, lr=lr, momentum=0.9, weight_decay=5e-4)
        loss_fn = LabelSmoothingCrossEntropyFlat()
        scheduler = FastaiSched(optimizer, max_lr=lr)
        progress_tracker = ProgressTracker(dls, epochs)
        if ckpt:
            load_model(f"models/{ckpt}-fold{fold_idx}.pth", model, None, device=default_device())

        # training
        best_score = 0
        score_not_improved = 0
        patience = 3
        res = []
        for epoch in range(epochs):
            # train and validate epoch
            train_losses = epoch_train(model, dls, loss_fn, optimizer, scheduler, progress_tracker,
                                       grad_clip_max_norm=clip_max_norm, mixup=mixup)
            valid_loss, valid_score = epoch_validate(model, loss_fn, dls)
            # print logs
            train_loss = torch.stack(train_losses).mean()
            row = [epoch, train_loss.item(), valid_loss.item(), valid_score]
            res.append(row)
            print(f"epoch: {row[0]} train_loss: {row[1]:.4f} valid_loss: {row[2]:.4f} valid_score: {row[3]:.4f}")
            # save model when the validation score improves
            if valid_score > best_score:
                if save: save_model(f"models/{fname}.pth", model, None)
                best_score = valid_score
            else:
                score_not_improved += 1
            # early stop
            if score_not_improved > patience: break

        # save logs
        res_df = pd.DataFrame(res, columns=['epoch', 'train_loss', 'valid_loss', 'sklearn_mean_ap'])
        res_df.to_csv(history/f"{fname}.csv", index=False)
        del model, optimizer, dls
        gc.collect()
        torch.cuda.empty_cache()
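
# Example invocation (illustrative; the arch name is hypothetical): assumes the globals used
# above — `train_df`, `study_ids`, `study_label_ids`, `history` — and the helpers `get_dls`,
# `get_classification_model`, `sklearn_mean_ap` are defined elsewhere in the notebook.
# train_sam('resnet34d', lr=1e-2, epochs=15, bs=32, size=512, mixup=True, save=True)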