Skip to content

Instantly share code, notes, and snippets.

@AlessandroMondin
Created July 22, 2022 12:20
Show Gist options
  • Save AlessandroMondin/8619c1803d9fc72d07557d06bf2d199b to your computer and use it in GitHub Desktop.
Save AlessandroMondin/8619c1803d9fc72d07557d06bf2d199b to your computer and use it in GitHub Desktop.
def main(*, start_epoch=None, batch_size=8, num_workers=4, num_images=10):
    """Train the UNET model, evaluating and checkpointing after every epoch.

    Builds the loss, gradient scaler, model and Adam optimizer, and optionally
    restores model/optimizer state from ``CHECKPOINT``. Then, for each epoch in
    ``range(start_epoch, EPOCHS)``, it:

    1. runs one ``train_loop`` pass;
    2. calls ``evalution_metrics`` on the validation loader;
    3. saves ``checkpoint_epoch_{epoch+1}.pth.tar`` under ``SAVE_MODEL_PATH``;
    4. calls ``save_images`` into ``SAVE_IMAGES_PATH``.

    Args:
        start_epoch: Zero-based epoch index to start from. When omitted it is
            10 if ``CHECKPOINT`` is set (the resume index that used to be
            hard-coded) and 0 otherwise, so a fresh run trains all ``EPOCHS``
            epochs instead of silently skipping the first 10.
        batch_size: Batch size forwarded to ``get_loaders``.
        num_workers: Worker count forwarded to ``get_loaders``.
        num_images: Number of images requested from ``save_images``.

    Raises:
        ValueError: If ``start_epoch`` is negative.
    """
    if start_epoch is None:
        # The checkpoint payload saved below does not record its epoch, so the
        # resume point cannot be recovered from the file itself; 10 is kept as
        # the historical resume index. Without a checkpoint, start at 0.
        start_epoch = 10 if CHECKPOINT else 0
    if start_epoch < 0:
        raise ValueError(f"start_epoch must be >= 0, got {start_epoch}")

    loss_fn = torch.nn.BCEWithLogitsLoss()
    # Handed to train_loop, presumably for mixed-precision loss scaling.
    # NOTE(review): the scaler's state is not part of the checkpoint dict
    # below, so a resumed run starts with a fresh scale factor.
    scaler = torch.cuda.amp.GradScaler()
    # NOTE(review): positional args presumably mean in_channels=3,
    # base width=64, out_channels=1 -- confirm against UNET's signature.
    model = UNET(3, 64, 1, padding=0, downhill=4).to(DEVICE)
    optim = Adam(model.parameters(), lr=LEARNING_RATE)

    if CHECKPOINT:
        load_model_checkpoint(CHECKPOINT, model)
        load_optim_checkpoint(CHECKPOINT, optim)

    train_loader, val_loader = get_loaders(
        db_root_dir=ROOT_DIR,
        batch_size=batch_size,
        train_transform=train_transform,
        val_transform=val_transforms,
        num_workers=num_workers,
    )

    for epoch in range(start_epoch, EPOCHS):
        print(f"Training epoch {epoch+1}/{EPOCHS}")
        train_loop(model=model, loader=train_loader, loss_fn=loss_fn,
                   optim=optim, scaler=scaler, pos_weight=False)

        print("Computing dice_loss on val_loader...")
        evalution_metrics(model, val_loader, loss_fn, device=DEVICE)

        # Only model and optimizer state are persisted (no epoch, no scaler).
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optim.state_dict(),
        }
        # File names are 1-based to match the "Training epoch N/EPOCHS" output.
        save_checkpoint(checkpoint, folder_path=SAVE_MODEL_PATH,
                        filename=f"checkpoint_epoch_{epoch+1}.pth.tar")
        save_images(model=model, loader=val_loader, folder=SAVE_IMAGES_PATH,
                    epoch=epoch, device=DEVICE, num_images=num_images,
                    pad_mirroring=PAD_MIRRORING)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment