Last active
May 30, 2023 10:30
-
-
Save rekalantar/0cef99a871fafd3f01d66a656590d025 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# define training loop
num_epochs = 100
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# define optimizer: only the mask decoder's parameters are handed to Adam, so
# the optimizer steps nothing else in the model
optimizer = Adam(model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)

# define segmentation loss with sigmoid activation applied to predictions from the model
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

# track mean train and validation losses (one entry appended per epoch)
mean_train_losses, mean_val_losses = [], []

# start from +inf rather than an arbitrary finite value, so the first epoch's
# validation loss always counts as an improvement regardless of its scale
best_val_loss = float("inf")
best_val_epoch = 0

# set model to train mode for gradient updating
model.train()
def _show_prediction(pixel_values, ground_truth_mask, pred_masks):
    """Plot input image, ground-truth mask (with its box) and predicted mask.

    Only the first sample of the batch is shown in all three panels.

    Args:
        pixel_values: batch of input images; channel 1 of sample 0 is displayed.
            Assumed to live on the CPU (it is passed to ``imshow`` directly).
        ground_truth_mask: batch of ground-truth masks; sample 0 is displayed
            and fed to ``get_bounding_box`` to draw the prompt rectangle.
        pred_masks: raw mask logits from the model (``outputs.pred_masks``).
            NOTE(review): assumed to collapse to a 2-D (H, W) array for one
            sample when ``multimask_output=False`` -- confirm.
    """
    fig, axs = plt.subplots(1, 3)

    xmin, ymin, xmax, ymax = get_bounding_box(ground_truth_mask[0])
    rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                             linewidth=1, edgecolor='r', facecolor='none')

    axs[0].set_title('input image')
    axs[0].imshow(pixel_values[0, 1], cmap='gray')
    axs[0].axis('off')

    axs[1].set_title('ground truth mask')
    axs[1].imshow(ground_truth_mask[0], cmap='copper')
    axs[1].add_patch(rect)
    axs[1].axis('off')

    # apply sigmoid to the first sample only, so the plot also works when the
    # batch holds more than one image (squeezing the whole batch would leave
    # a 3-D array that imshow cannot draw)
    medsam_seg_prob = torch.sigmoid(pred_masks[0]).detach().cpu().numpy().squeeze()
    # convert soft mask to hard mask
    medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)

    axs[2].set_title('predicted mask')
    axs[2].imshow(medsam_seg, cmap='copper')
    axs[2].axis('off')

    plt.tight_layout()
    plt.show()


for epoch in range(num_epochs):
    # validation below switches the model to eval mode, so restore train mode
    # at the start of every epoch
    model.train()

    # create temporary list to record training losses
    epoch_losses = []

    for i, batch in enumerate(tqdm(train_dataloader)):
        # forward pass
        outputs = model(pixel_values=batch["pixel_values"].to(device),
                        input_boxes=batch["input_boxes"].to(device),
                        multimask_output=False)

        # compute loss
        predicted_masks = outputs.pred_masks.squeeze(1)
        ground_truth_masks = batch["ground_truth_mask"].float().to(device)
        loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))

        # backward pass (compute gradients of parameters w.r.t. loss)
        optimizer.zero_grad()
        loss.backward()

        # optimize
        optimizer.step()
        epoch_losses.append(loss.item())

        # visualize training predictions every 50 iterations
        if i % 50 == 0:
            # clear jupyter cell output
            clear_output(wait=True)
            _show_prediction(batch["pixel_values"],
                             batch["ground_truth_mask"],
                             outputs.pred_masks)

    # create temporary list to record validation losses
    val_losses = []

    # set model to eval mode for validation (the original comment promised
    # this but never called it) and disable gradient tracking
    model.eval()
    with torch.no_grad():
        for val_batch in tqdm(val_dataloader):
            # forward pass
            outputs = model(pixel_values=val_batch["pixel_values"].to(device),
                            input_boxes=val_batch["input_boxes"].to(device),
                            multimask_output=False)

            # calculate val loss against the *validation* ground truth; the
            # original read the mask from the last training batch here
            predicted_val_masks = outputs.pred_masks.squeeze(1)
            ground_truth_masks = val_batch["ground_truth_mask"].float().to(device)
            val_loss = seg_loss(predicted_val_masks, ground_truth_masks.unsqueeze(1))

            val_losses.append(val_loss.item())

        # visualize the last validation prediction
        _show_prediction(val_batch["pixel_values"],
                         val_batch["ground_truth_mask"],
                         outputs.pred_masks)

    mean_train_loss = mean(epoch_losses)
    mean_val_loss = mean(val_losses)

    # save the best weights and record the best performing epoch
    if mean_val_loss < best_val_loss:
        # update before printing so the message reports the loss just saved,
        # not the previous best
        best_val_loss = mean_val_loss
        best_val_epoch = epoch
        torch.save(model.state_dict(), "best_weights.pth")
        print(f"Model Was Saved! Current Best val loss {best_val_loss}")
    else:
        print("Model Was Not Saved!")

    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean_train_loss}')

    mean_train_losses.append(mean_train_loss)
    mean_val_losses.append(mean_val_loss)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment