Skip to content

Instantly share code, notes, and snippets.

@rekalantar
Last active May 30, 2023 10:30
Show Gist options
  • Select an option

  • Save rekalantar/0cef99a871fafd3f01d66a656590d025 to your computer and use it in GitHub Desktop.

Select an option

Save rekalantar/0cef99a871fafd3f01d66a656590d025 to your computer and use it in GitHub Desktop.
# Fine-tune the SAM mask decoder and keep the weights with the lowest mean
# validation loss.
# NOTE(review): this cell relies on names defined earlier in the notebook:
# `model`, `train_dataloader`, `val_dataloader`, `get_bounding_box`, `mean`,
# `Adam`, `tqdm`, `clear_output`, and the torch / numpy / monai / matplotlib
# (`plt`, `patches`) imports.


def _show_prediction(sample_batch, model_outputs):
    """Plot the input image, ground-truth mask (with box) and predicted mask.

    Only the first sample of the batch is drawn in each of the three panels.

    Args:
        sample_batch: dict holding "pixel_values" and "ground_truth_mask".
        model_outputs: model output object exposing ``pred_masks`` (logits).
    """
    gt_mask = sample_batch["ground_truth_mask"][0]
    fig, axs = plt.subplots(1, 3)
    xmin, ymin, xmax, ymax = get_bounding_box(gt_mask)
    rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                             linewidth=1, edgecolor='r', facecolor='none')

    axs[0].set_title('input image')
    # channel 1 of the first sample is shown as a grayscale image
    axs[0].imshow(sample_batch["pixel_values"][0, 1], cmap='gray')
    axs[0].axis('off')

    axs[1].set_title('ground truth mask')
    axs[1].imshow(gt_mask, cmap='copper')
    axs[1].add_patch(rect)
    axs[1].axis('off')

    # apply sigmoid to the logits of sample 0 only, so this panel matches the
    # other two even when the batch holds more than one sample
    seg_prob = torch.sigmoid(model_outputs.pred_masks.squeeze(1)[0])
    # convert soft mask to hard mask
    seg_prob = seg_prob.detach().cpu().numpy().squeeze()
    seg_mask = (seg_prob > 0.5).astype(np.uint8)
    axs[2].set_title('predicted mask')
    axs[2].imshow(seg_mask, cmap='copper')
    axs[2].axis('off')

    plt.tight_layout()
    plt.show()


# define training loop
num_epochs = 100
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# define optimizer (only the mask decoder's parameters are handed to it)
optimizer = Adam(model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)
# define segmentation loss with sigmoid activation applied to predictions from the model
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
# track mean train and validation losses
mean_train_losses, mean_val_losses = [], []
# start from +inf so the first epoch always becomes the initial best
best_val_loss = float("inf")
best_val_epoch = 0

for epoch in range(num_epochs):
    # (re-)enter train mode every epoch, since validation switches to eval mode
    model.train()
    # create temporary list to record training losses
    epoch_losses = []
    for i, batch in enumerate(tqdm(train_dataloader)):
        # forward pass
        outputs = model(pixel_values=batch["pixel_values"].to(device),
                        input_boxes=batch["input_boxes"].to(device),
                        multimask_output=False)
        # compute loss
        predicted_masks = outputs.pred_masks.squeeze(1)
        ground_truth_masks = batch["ground_truth_mask"].float().to(device)
        loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))
        # backward pass (compute gradients of parameters w.r.t. loss)
        optimizer.zero_grad()
        loss.backward()
        # optimize
        optimizer.step()
        epoch_losses.append(loss.item())
        # visualize training predictions every 50 iterations
        if i % 50 == 0:
            # clear jupyter cell output
            clear_output(wait=True)
            _show_prediction(batch, outputs)

    # create temporary list to record validation losses
    val_losses = []
    # set model to eval mode for validation
    model.eval()
    with torch.no_grad():
        for val_batch in tqdm(val_dataloader):
            # forward pass
            val_outputs = model(pixel_values=val_batch["pixel_values"].to(device),
                                input_boxes=val_batch["input_boxes"].to(device),
                                multimask_output=False)
            # calculate val loss against the *validation* ground truth
            predicted_val_masks = val_outputs.pred_masks.squeeze(1)
            val_ground_truth_masks = val_batch["ground_truth_mask"].float().to(device)
            val_loss = seg_loss(predicted_val_masks, val_ground_truth_masks.unsqueeze(1))
            val_losses.append(val_loss.item())

    # visualize the last validation prediction
    _show_prediction(val_batch, val_outputs)

    mean_train_loss = mean(epoch_losses)
    mean_val_loss = mean(val_losses)

    # save the best weights and record the best performing epoch
    if mean_val_loss < best_val_loss:
        # update first so the message below reports the new best value
        best_val_loss = mean_val_loss
        best_val_epoch = epoch
        torch.save(model.state_dict(), "best_weights.pth")
        print(f"Model Was Saved! Current Best val loss {best_val_loss}")
    else:
        print("Model Was Not Saved!")

    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean_train_loss}')
    mean_train_losses.append(mean_train_loss)
    mean_val_losses.append(mean_val_loss)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment