import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Orientationd, Spacingd,
    CenterSpatialCropd, ScaleIntensityRanged, SpatialPadd
)


def get_bounding_box(ground_truth_map):
    '''
    Create varying bounding box coordinates from the segmentation contours to use as a prompt for the SAM model.
    The padding is a random integer between 5 and 20 pixels on each side.
    '''
    if len(np.unique(ground_truth_map)) > 1:
        # get bounding box from mask
        y_indices, x_indices = np.where(ground_truth_map > 0)
        x_min, x_max = np.min(x_indices), np.max(x_indices)
        y_min, y_max = np.min(y_indices), np.max(y_indices)
        # add perturbation to bounding box coordinates
        H, W = ground_truth_map.shape
        x_min = max(0, x_min - np.random.randint(5, 20))
        x_max = min(W, x_max + np.random.randint(5, 20))
        y_min = max(0, y_min - np.random.randint(5, 20))
        y_max = min(H, y_max + np.random.randint(5, 20))
        bbox = [x_min, y_min, x_max, y_max]
        return bbox
    else:
        return [0, 0, 256, 256]  # if there is no mask in the array, set the bbox to the image size
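
# Illustrative check (not part of the original gist): the prompt produced for a toy
# 256x256 mask with a square foreground region, and the full-image fallback for an
# empty mask. Uncomment to run; exact values vary with the random padding.
#
# toy_mask = np.zeros((256, 256), dtype=np.uint8)
# toy_mask[100:150, 80:120] = 1
# print(get_bounding_box(toy_mask))              # [x_min, y_min, x_max, y_max], e.g. [68, 88, 131, 161]
# print(get_bounding_box(np.zeros((256, 256))))  # no foreground -> [0, 0, 256, 256]
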
class SAMDataset(Dataset):
    def __init__(self, image_paths, mask_paths, processor):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.processor = processor
        self.transforms = Compose([
            # load .nii or .nii.gz files
            LoadImaged(keys=['img', 'label']),
            # add a channel dimension to match PyTorch's channel-first convention
            EnsureChannelFirstd(keys=['img', 'label']),
            # reorient images for consistency and visualization
            Orientationd(keys=['img', 'label'], axcodes='RA'),
            # resample all training images to a fixed spacing
            Spacingd(keys=['img', 'label'], pixdim=(1.5, 1.5), mode=("bilinear", "nearest")),
            # crop image and label around the center to 256x256
            CenterSpatialCropd(keys=['img', 'label'], roi_size=(256, 256)),
            # rescale image intensities to [0, 255] to match the expected input intensity range
            ScaleIntensityRanged(keys=['img'], a_min=-1000, a_max=2000,
                                 b_min=0.0, b_max=255.0, clip=True),
            # rescale label intensities to [0, 1]
            ScaleIntensityRanged(keys=['label'], a_min=0, a_max=255,
                                 b_min=0.0, b_max=1.0, clip=True),
            # pad back to 256x256 in case a spatial dimension is smaller after cropping
            SpatialPadd(keys=["img", "label"], spatial_size=(256, 256))
            # RepeatChanneld(keys=['img'], repeats=3, allow_missing_keys=True)
        ])
    def __len__(self):
        return len(self.image_paths)
    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        # build a dict of image and label paths to apply MONAI's dictionary transforms
        data_dict = self.transforms({'img': image_path, 'label': mask_path})

        # squeeze extra dimensions
        image = data_dict['img'].squeeze()
        ground_truth_mask = data_dict['label'].squeeze()

        # convert to uint8, the input type expected by the Hugging Face processor
        image = image.astype(np.uint8)

        # stack the grayscale array into 3 channels (RGB)
        array_rgb = np.dstack((image, image, image))

        # convert to a PIL image to match the expected input of the processor
        image_rgb = Image.fromarray(array_rgb)

        # get the bounding box prompt (returns [x_min, y_min, x_max, y_max]);
        # in this dataset the contours are labelled -1, so map them to 1 (foreground) vs 0 (background)
        ground_truth_mask[ground_truth_mask < 0] = 1
        prompt = get_bounding_box(ground_truth_mask)

        # prepare image and prompt for the model
        inputs = self.processor(image_rgb, input_boxes=[[prompt]], return_tensors="pt")

        # remove the batch dimension which the processor adds by default
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}

        # add the ground truth segmentation (ground truth mask size is 256x256)
        inputs["ground_truth_mask"] = torch.from_numpy(ground_truth_mask.astype(np.int8))

        return inputs
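

# Minimal usage sketch (assumptions, not from the original gist): SamProcessor from
# Hugging Face transformers, and `train_image_paths` / `train_mask_paths` as lists of
# NIfTI file paths -- these variable names are placeholders. Uncomment to run.
#
# from transformers import SamProcessor
# from torch.utils.data import DataLoader
#
# processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
# train_dataset = SAMDataset(train_image_paths, train_mask_paths, processor)
# train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
#
# batch = next(iter(train_loader))
# print(batch["pixel_values"].shape)       # (2, 3, 1024, 1024) after SAM preprocessing
# print(batch["input_boxes"].shape)        # (2, 1, 4) bounding-box prompts
# print(batch["ground_truth_mask"].shape)  # (2, 256, 256)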