import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Orientationd,
    Spacingd, CenterSpatialCropd, ScaleIntensityRanged, SpatialPadd,
)


def get_bounding_box(ground_truth_map):
    '''
    Create a perturbed bounding box from the segmentation contours to use as a
    prompt for the SAM model. Each side is padded by a random integer number of
    pixels drawn from np.random.randint(5, 20), i.e. 5 to 19 pixels.
    '''
    if len(np.unique(ground_truth_map)) > 1:
        # get the tight bounding box around the foreground pixels of the mask
        y_indices, x_indices = np.where(ground_truth_map > 0)
        x_min, x_max = np.min(x_indices), np.max(x_indices)
        y_min, y_max = np.min(y_indices), np.max(y_indices)
        # add random perturbation to the bounding box coordinates,
        # clipped to the image bounds
        H, W = ground_truth_map.shape
        x_min = max(0, x_min - np.random.randint(5, 20))
        x_max = min(W, x_max + np.random.randint(5, 20))
        y_min = max(0, y_min - np.random.randint(5, 20))
        y_max = min(H, y_max + np.random.randint(5, 20))
        return [x_min, y_min, x_max, y_max]
    else:
        # if there is no mask in the array, set the bbox to the full image size
        return [0, 0, 256, 256]
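A minimal sketch of how get_bounding_box behaves; the toy mask below is illustrative only and not part of the original dataset:

toy_mask = np.zeros((256, 256), dtype=np.uint8)
toy_mask[100:150, 80:120] = 1  # a single rectangular foreground region
print(get_bounding_box(toy_mask))  # the tight box [80, 100, 119, 149], expanded by 5-19 px per side
print(get_bounding_box(np.zeros((256, 256), dtype=np.uint8)))  # no foreground -> [0, 0, 256, 256]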
class SAMDataset(Dataset):
    def __init__(self, image_paths, mask_paths, processor):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.processor = processor
        self.transforms = Compose([
            # load .nii or .nii.gz files
            LoadImaged(keys=['img', 'label']),
            # add a channel dimension to match PyTorch's channel-first convention
            EnsureChannelFirstd(keys=['img', 'label']),
            # reorient images for consistency and visualization
            Orientationd(keys=['img', 'label'], axcodes='RA'),
            # resample all training images to a fixed spacing
            Spacingd(keys=['img', 'label'], pixdim=(1.5, 1.5), mode=("bilinear", "nearest")),
            # center-crop image and label to 256x256 (SpatialPadd below pads
            # smaller images back up to 256x256)
            CenterSpatialCropd(keys=['img', 'label'], roi_size=(256, 256)),
            # scale image intensities to [0, 255] to match the expected input intensity range
            ScaleIntensityRanged(keys=['img'], a_min=-1000, a_max=2000,
                                 b_min=0.0, b_max=255.0, clip=True),
            # scale label intensities to [0, 1]
            ScaleIntensityRanged(keys=['label'], a_min=0, a_max=255,
                                 b_min=0.0, b_max=1.0, clip=True),
            SpatialPadd(keys=['img', 'label'], spatial_size=(256, 256)),
            # RepeatChanneld(keys=['img'], repeats=3, allow_missing_keys=True)
        ])
    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]
        # create a dict of image and label paths to apply MONAI's dictionary transforms
        data_dict = self.transforms({'img': image_path, 'label': mask_path})
        # squeeze extra dimensions; np.asarray also converts the MetaTensors
        # returned by recent MONAI versions back to plain numpy arrays
        image = np.asarray(data_dict['img'].squeeze())
        ground_truth_mask = np.asarray(data_dict['label'].squeeze())
        # convert to uint8, the integer type expected by Hugging Face's models
        image = image.astype(np.uint8)
        # convert the grayscale array to RGB (3 channels)
        array_rgb = np.dstack((image, image, image))
        # convert to a PIL image to match the expected input of the processor
        image_rgb = Image.fromarray(array_rgb)
        # in this dataset the contours are stored as -1, so remap them to 1
        # (foreground); the background stays 0
        ground_truth_mask[ground_truth_mask < 0] = 1
        # get the bounding box prompt (returns [xmin, ymin, xmax, ymax])
        prompt = get_bounding_box(ground_truth_mask)
        # prepare image and prompt for the model
        inputs = self.processor(image_rgb, input_boxes=[[prompt]], return_tensors="pt")
        # remove the batch dimension which the processor adds by default
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        # add the ground truth segmentation (ground truth image size is 256x256)
        inputs["ground_truth_mask"] = torch.from_numpy(ground_truth_mask.astype(np.int8))
        return inputs
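A minimal usage sketch, assuming paired lists of .nii/.nii.gz image and mask paths (train_image_paths and train_mask_paths are placeholders, not defined in the original gist) and the standard Hugging Face SAM base checkpoint:

from transformers import SamProcessor
from torch.utils.data import DataLoader

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
# placeholder path lists; substitute your own dataset here
train_dataset = SAMDataset(image_paths=train_image_paths,
                           mask_paths=train_mask_paths,
                           processor=processor)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

batch = next(iter(train_loader))
print(batch["pixel_values"].shape)       # torch.Size([2, 3, 1024, 1024]) after SamProcessor resizing
print(batch["input_boxes"].shape)        # torch.Size([2, 1, 4]) bounding-box prompts
print(batch["ground_truth_mask"].shape)  # torch.Size([2, 256, 256])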