Skip to content

Instantly share code, notes, and snippets.

View rekalantar's full-sized avatar

Reza Kalantar rekalantar

View GitHub Profile
@rekalantar
rekalantar / medsegmentanything_finetuning.ipynb
Created May 30, 2023 10:43
MedSegmentAnything_FineTuning.ipynb
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
# create test dataloader
# (batch_size=1 and shuffle=False, so test samples are visited in dataset index order)
# NOTE(review): an identical test loader is built again in a later cell of this copy.
test_dataset = SAMDataset(image_paths=data_paths['test_images'], mask_paths=data_paths['test_masks'], processor=processor)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
# Iterate through test images with gradient tracking disabled (inference only).
# NOTE(review): the loop body lost its indentation in this copy, and the
# model(...) call below is cut off mid-argument-list — restore from the notebook.
# NOTE(review): inputs are moved with a hard-coded .cuda(), whereas the training
# setup selects `device` dynamically; this will fail on a CPU-only machine.
with torch.no_grad():
for batch in tqdm(test_dataloader):
# forward pass
outputs = model(pixel_values=batch["pixel_values"].cuda(),
# create test dataloader: one sample per batch, fixed (unshuffled) order over the test split
test_dataset = SAMDataset(
    image_paths=data_paths['test_images'],
    mask_paths=data_paths['test_masks'],
    processor=processor,
)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=1)
# define training loop
# Number of passes over the training set.
num_epochs = 100

# Run on the GPU when one is available, otherwise fall back to the CPU.
# NOTE(review): `model` must already be loaded here; the loading cell appears
# later in this copy of the notebook.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

# define optimizer: Adam over the mask decoder's parameters only,
# with a small learning rate and no weight decay.
optimizer = Adam(
    model.mask_decoder.parameters(),
    lr=1e-5,
    weight_decay=0,
)

# define segmentation loss with sigmoid activation applied to predictions from the model
# NOTE(review): the loss object this comment announces is not defined in this
# fragment — restore it from the full notebook.
# load the pretrained weights for finetuning
model = SamModel.from_pretrained("facebook/sam-vit-base")

# Freeze every parameter belonging to the vision encoder or the prompt encoder.
# All other parameters (including the mask decoder) keep requires_grad=True.
for name, param in model.named_parameters():
    if name.startswith(("vision_encoder", "prompt_encoder")):
        param.requires_grad_(False)
# Sanity-check one training sample: print the shape of every field, derive its
# box prompt from the ground-truth mask, and display part of the image.
example = train_dataset[50]
for k in example:
    v = example[k]
    print(k, v.shape)

xmin, ymin, xmax, ymax = get_bounding_box(example['ground_truth_mask'])

fig, axs = plt.subplots(1, 2)
# presumably channel 1 of a (C, H, W) pixel tensor — confirm layout
axs[0].imshow(example['pixel_values'][1], cmap='gray')
axs[0].axis('off')
# NOTE(review): axs[1] and the bounding-box values are unused in this visible
# fragment; the rest of the cell appears to be truncated.
# create train and validation dataloaders
# Training batches are reshuffled every epoch. Validation, like the test loader,
# iterates in a fixed order: shuffling it has no effect on aggregate metrics but
# makes per-batch validation output non-reproducible across runs.
train_dataset = SAMDataset(image_paths=data_paths['train_images'], mask_paths=data_paths['train_masks'], processor=processor)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_dataset = SAMDataset(image_paths=data_paths['val_images'], mask_paths=data_paths['val_masks'], processor=processor)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)
# Initialize dictionary for storing image and label paths
# NOTE(review): presumably filled with keys like 'train_images' / 'test_masks'
# (that is how data_paths is read elsewhere) — confirm against the full loop body.
data_paths = {}
# Create directories and print the number of images and masks in each
# NOTE(review): `datasets`, `data_types` and `base_dir` are defined outside this
# fragment; the loop body is truncated after `dir_path` is built, and its
# indentation was lost in this copy.
for dataset in datasets:
for data_type in data_types:
# Construct the directory path, e.g. <base_dir>/<dataset>_<data_type>
dir_path = os.path.join(base_dir, f'{dataset}_{data_type}')
def get_bounding_box(ground_truth_map):
'''
Build a bounding-box prompt for the SAM model from a segmentation mask.

The box is derived from the foreground (> 0) pixels of `ground_truth_map`,
which must be 2-D (np.where's result is unpacked into row and column indices).
Callers unpack the result as (xmin, ymin, xmax, ymax).

NOTE(review): the original description says the box is perturbed by a random
padding of 5-20 pixels so the prompt varies between calls. That part of the
body, and the result for a constant mask (the `if` below is False), are not
visible in this copy — confirm against the full notebook.
'''
if len(np.unique(ground_truth_map)) > 1:
# get bounding box from mask: row indices are y, column indices are x
y_indices, x_indices = np.where(ground_truth_map > 0)
# Instantiate the SAM processor used to preprocess images (and prompts), taken
# from the same "facebook/sam-vit-base" checkpoint the model is loaded from.
processor = SamProcessor.from_pretrained(
    "facebook/sam-vit-base"
)
print(processor)  # echo the processor object for inspection