# create test dataloader
test_dataset = SAMDataset(image_paths=data_paths['test_images'], mask_paths=data_paths['test_masks'], processor=processor)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
# Iterate through test images
with torch.no_grad():
    for batch in tqdm(test_dataloader):
        # forward pass
        outputs = model(pixel_values=batch["pixel_values"].cuda(),
# define training loop
num_epochs = 100
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# define optimizer
optimizer = Adam(model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)
# define segmentation loss with sigmoid activation applied to predictions from the model
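The snippet is cut off before the loss itself is defined, so the loss the gist actually uses is not shown. As a minimal sketch of the idea in the last comment, assuming a soft Dice loss computed on sigmoid probabilities (the Dice choice, the class name `SigmoidDiceLoss`, and the `eps` smoothing term are assumptions for illustration):

```python
import torch
from torch import nn


class SigmoidDiceLoss(nn.Module):
    """Hypothetical soft Dice loss that applies a sigmoid to raw model logits."""

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps  # smoothing term, avoids division by zero on empty masks

    def forward(self, logits, target):
        # Turn raw logits into per-pixel foreground probabilities.
        probs = torch.sigmoid(logits)
        # Reduce over every dimension except the batch dimension.
        dims = tuple(range(1, logits.dim()))
        intersection = (probs * target).sum(dim=dims)
        denominator = probs.sum(dim=dims) + target.sum(dim=dims)
        dice = (2 * intersection + self.eps) / (denominator + self.eps)
        # Loss is 1 - Dice, averaged over the batch.
        return 1 - dice.mean()
```

Applying the sigmoid inside the loss means the model can keep emitting raw logits, which is what the comment above describes.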
# load the pretrained weights for finetuning
model = SamModel.from_pretrained("facebook/sam-vit-base")
# make sure we only compute gradients for mask decoder (encoder weights are frozen)
for name, param in model.named_parameters():
    if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
        param.requires_grad_(False)
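One way to confirm that the freeze worked is to list the parameters that still require gradients. The sketch below does this on a tiny stand-in module rather than the real model; `TinySam` and `freeze_by_prefix` are hypothetical names, and the stand-in only mirrors the three top-level submodule names used in the snippet above.

```python
from torch import nn


class TinySam(nn.Module):
    """Hypothetical stand-in that mirrors the submodule names used above."""

    def __init__(self):
        super().__init__()
        self.vision_encoder = nn.Linear(4, 4)
        self.prompt_encoder = nn.Linear(4, 4)
        self.mask_decoder = nn.Linear(4, 1)


def freeze_by_prefix(model, prefixes):
    """Freeze parameters whose names start with any prefix; return the trainable names."""
    for name, param in model.named_parameters():
        if name.startswith(tuple(prefixes)):
            param.requires_grad_(False)
    return [name for name, p in model.named_parameters() if p.requires_grad]


toy = TinySam()
trainable = freeze_by_prefix(toy, ("vision_encoder", "prompt_encoder"))
print(trainable)  # ['mask_decoder.weight', 'mask_decoder.bias']
```

The same list comprehension can be run on the real model to check that only `mask_decoder` parameters remain trainable before building the optimizer.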
example = train_dataset[50]
for k, v in example.items():
    print(k, v.shape)
xmin, ymin, xmax, ymax = get_bounding_box(example['ground_truth_mask'])
fig, axs = plt.subplots(1, 2)
axs[0].imshow(example['pixel_values'][1], cmap='gray')
axs[0].axis('off')
# create train and validation dataloaders
train_dataset = SAMDataset(image_paths=data_paths['train_images'], mask_paths=data_paths['train_masks'], processor=processor)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_dataset = SAMDataset(image_paths=data_paths['val_images'], mask_paths=data_paths['val_masks'], processor=processor)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=True)
# Initialize dictionary for storing image and label paths
data_paths = {}
# Create directories and print the number of images and masks in each
for dataset in datasets:
    for data_type in data_types:
        # Construct the directory path
        dir_path = os.path.join(base_dir, f'{dataset}_{data_type}')
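The loop above is cut off after the directory path is built. A minimal sketch of how such a loop could fill `data_paths` follows, assuming the keys match the ones used elsewhere in the gist (`train_images`, `val_masks`, and so on) and that each directory holds one file per image or mask. The function name `collect_paths` and the default values for `datasets` and `data_types` are assumptions.

```python
import os
from glob import glob


def collect_paths(base_dir, datasets=("train", "val", "test"), data_types=("images", "masks")):
    """Map '<dataset>_<data_type>' to the sorted list of files in that directory."""
    data_paths = {}
    for dataset in datasets:
        for data_type in data_types:
            key = f"{dataset}_{data_type}"
            dir_path = os.path.join(base_dir, key)
            # Create the directory if it does not exist yet.
            os.makedirs(dir_path, exist_ok=True)
            # Sorting keeps images and masks aligned, assuming matching file names.
            data_paths[key] = sorted(
                p for p in glob(os.path.join(dir_path, "*")) if os.path.isfile(p)
            )
            print(f"{key}: {len(data_paths[key])} files")
    return data_paths
```

Sorting matters here because the dataset pairs images and masks by position, so both lists need the same order.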
def get_bounding_box(ground_truth_map):
    '''
    Create a randomly padded bounding box around the segmentation mask, to be used as the prompt for the SAM model.
    The padding on each side is a random integer between 5 and 20 pixels.
    '''
    if len(np.unique(ground_truth_map)) > 1:
        # get bounding box from mask
        y_indices, x_indices = np.where(ground_truth_map > 0)
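The function above is cut off after the `np.where` call. As a minimal sketch of what the docstring describes, assuming the box is the tight extent of the nonzero mask pixels, padded on each side by an independent random integer in [5, 20] and clipped to the image bounds. The clipping, the whole-image fallback for an empty mask, and the name `padded_bounding_box` are assumptions, not taken from the gist.

```python
import numpy as np


def padded_bounding_box(mask, min_pad=5, max_pad=20, rng=None):
    """Return [xmin, ymin, xmax, ymax] around the nonzero pixels of a 2D mask."""
    rng = np.random.default_rng() if rng is None else rng
    height, width = mask.shape
    y_indices, x_indices = np.where(mask > 0)
    if y_indices.size == 0:
        # Assumption: with no foreground, fall back to the whole image.
        return [0, 0, width, height]
    # Tight box around the foreground pixels.
    xmin, xmax = int(x_indices.min()), int(x_indices.max())
    ymin, ymax = int(y_indices.min()), int(y_indices.max())
    # Independent random padding per side, inclusive of max_pad.
    pads = rng.integers(min_pad, max_pad + 1, size=4)
    return [
        max(0, xmin - int(pads[0])),
        max(0, ymin - int(pads[1])),
        min(width, xmax + int(pads[2])),
        min(height, ymax + int(pads[3])),
    ]
```

Random padding makes the box prompt vary between epochs, so the decoder does not learn to rely on a perfectly tight box.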
# create an instance of the processor for image preprocessing
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
print(processor)