Skip to content

Instantly share code, notes, and snippets.

@rekalantar
Last active May 30, 2023 10:33
Show Gist options
  • Select an option

  • Save rekalantar/e544bdb407948f9188a35c07ce93cfa6 to your computer and use it in GitHub Desktop.

Select an option

Save rekalantar/e544bdb407948f9188a35c07ce93cfa6 to your computer and use it in GitHub Desktop.
# Create the test dataloader. batch_size=1 gives exactly one image per figure
# below, and shuffle=False keeps the test images in a stable order.
test_dataset = SAMDataset(image_paths=data_paths['test_images'], mask_paths=data_paths['test_masks'], processor=processor)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Probability above which a predicted pixel is counted as foreground.
MASK_THRESHOLD = 0.5

# Switch to inference mode: the model may still be in training mode from a
# preceding training loop, which would leave dropout/normalization layers
# behaving stochastically during evaluation.
model.eval()
# Move inputs to whatever device the model lives on instead of hard-coding .cuda().
device = next(model.parameters()).device

# Iterate through the test images: predict a mask for each one and plot it
# next to the input image and its ground-truth mask.
with torch.no_grad():
    for batch in tqdm(test_dataloader):
        # Forward pass with the bounding-box prompt(s) from the dataset.
        # multimask_output=False presumably yields a single mask per prompt,
        # which is why dimension 1 is squeezed away below — confirm against
        # the model's output shape.
        outputs = model(pixel_values=batch["pixel_values"].to(device),
                        input_boxes=batch["input_boxes"].to(device),
                        multimask_output=False)
        predicted_masks = outputs.pred_masks.squeeze(1)

        # Logits -> probabilities -> hard (0/1) mask. With batch_size=1 the
        # final .squeeze() drops the remaining singleton dims so the result
        # can be passed straight to imshow (assumes a single 2-D mask per image).
        medsam_seg_prob = torch.sigmoid(predicted_masks).cpu().numpy().squeeze()
        medsam_seg = (medsam_seg_prob > MASK_THRESHOLD).astype(np.uint8)

        fig, axes = plt.subplots(1, 3, figsize=(12, 4))
        # Only channel 1 of the processor output is shown, in grayscale
        # (these are presumably normalized pixel values, not raw intensities).
        axes[0].imshow(batch["pixel_values"][0, 1], cmap='gray')
        axes[0].set_title("Image")
        axes[1].imshow(batch["ground_truth_mask"][0], cmap='copper')
        axes[1].set_title("Ground truth")
        axes[2].imshow(medsam_seg, cmap='copper')
        axes[2].set_title("Prediction")
        for ax in axes:
            ax.axis('off')
        fig.tight_layout()
        plt.show()
        # Release the figure so figures do not accumulate across the test set
        # (matters for non-interactive backends where show() does not block).
        plt.close(fig)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment