Skip to content

Instantly share code, notes, and snippets.

@sshleifer
Created May 1, 2019 17:21
Show Gist options
  • Save sshleifer/77314de11dba384d8a55acb10254c323 to your computer and use it in GitHub Desktop.
Save sshleifer/77314de11dba384d8a55acb10254c323 to your computer and use it in GitHub Desktop.
Script for going from mask to bboxes (bbox branch)
import numpy as np
import pandas as pd
import pickle as pkl
import nrrd
import glob
import os
import sys
def find_bounding_box(mask, point, label):
    """
    Finds the bounding box of the region of `label` that contains `point`.

    The search itself is delegated to find_bounding_box_helper, which fills in
    the set of reached points and widens the two corner lists in place.

    Returns a pair `((min_point, size), visited)`:
      * `min_point` - list with the smallest coordinate reached along each axis
      * `size`      - tuple of `max - min` along each axis (a region that is a
                      single point therefore has size 0 on every axis)
      * `visited`   - set of every point the helper marked as part of the region
    """
    lower_corner = list(point)
    upper_corner = list(point)
    reached = set()
    find_bounding_box_helper(mask, point, label, reached, lower_corner, upper_corner)
    extent = tuple(high - low for low, high in zip(lower_corner, upper_corner))
    return (lower_corner, extent), reached
def find_bounding_box_helper(mask, point, label, visited, min_point, max_point):
    """
    Flood-fills the region of `label` connected to `point` along the array axes.

    Every reached point is added to `visited`, and `min_point` / `max_point`
    (mutable per-axis lists) are widened in place so that they enclose all
    reached points. Nothing is returned. If `mask[point]` is not `label`, or
    `point` is already in `visited`, the arguments are left untouched.

    The fill uses an explicit stack instead of recursion, so the size of a
    region is not limited by the interpreter's recursion depth. Neighbours
    outside the array are skipped: stepping past the far edge would raise
    IndexError, and stepping to -1 would silently wrap around to the opposite
    edge through negative indexing.
    """
    shape = mask.shape
    pending = [tuple(point)]
    while pending:
        current = pending.pop()
        if current in visited or mask[current] != label:
            continue
        visited.add(current)
        for axis, coord in enumerate(current):
            min_point[axis] = min(min_point[axis], coord)
            max_point[axis] = max(max_point[axis], coord)
            for neighbour in (coord + 1, coord - 1):
                # Only follow neighbours that really exist inside the array.
                if 0 <= neighbour < shape[axis]:
                    pending.append(current[:axis] + (neighbour,) + current[axis + 1:])
def generate_boxes(mask):
    """
    Generates bounding boxes for the given segmentation mask, which can either be
    the full 3D CT scan or a single 2D slice. The dimension of the boxes is the same
    as the dimension of the mask. Returns the list of bounding boxes found.

    Each entry is the `(min_point, size)` pair produced by find_bounding_box for
    one region whose value is greater than 0. An all-zero or zero-sized mask
    yields an empty list.
    """
    # np.nditer rejects zero-sized operands by default; an empty mask has no boxes.
    if np.size(mask) == 0:
        return []
    bboxes = []
    all_visited = set()
    it = np.nditer(mask, flags=['multi_index'])
    for value in it:
        point = it.multi_index
        # Points already claimed by an earlier region must not start a new box.
        if value > 0 and point not in all_visited:
            bbox, visited = find_bounding_box(mask, point, value)
            bboxes.append(bbox)
            all_visited |= visited
    return bboxes
def transform(point, height, width, target_size=512):
    """
    Rescales a 2D `(y, x)` pair from a `height` x `width` grid onto a square
    grid with `target_size` units per side.

    `y` is scaled by `target_size / height` and `x` by `target_size / width`.
    `target_size` defaults to 512, the value this function previously had
    hard-coded. Returns the rescaled pair as a two-element list `[y, x]`.
    """
    y, x = point
    return [target_size * y / height, target_size * x / width]
def generate_boxes_2d(mask):
    """
    Builds per-slice 2D bounding boxes for a 3D mask indexed as
    (slice, row, column).

    Each slice along the first axis is passed to generate_boxes on its own. The
    corner and the size of every box found are then rescaled with transform,
    using the mask's second and third dimensions as height and width.

    Returns a dict mapping slice index -> list of `(min_point, size)` pairs.
    Slices in which no box was found do not appear in the dict.
    """
    depth, height, width = mask.shape
    boxes_by_slice = {}
    for index in range(depth):
        found = generate_boxes(mask[index, :, :])
        if not found:
            continue
        boxes_by_slice[index] = [
            (transform(corner, height, width), transform(extent, height, width))
            for corner, extent in found
        ]
    return boxes_by_slice
def read_masks(directory):
    """
    Loads every `*_seg.nrrd` file found directly inside `directory`.

    Returns a dict mapping patient id (the file name with the `_seg.nrrd`
    suffix removed) to the data array returned by `nrrd.read`. Files are
    processed in sorted name order so the dict order is reproducible. A
    directory with no matching files yields an empty dict.

    Raises ValueError if a file's data is not 3-dimensional.
    """
    suffix = "_seg.nrrd"
    masks = {}
    for filename in sorted(glob.glob(os.path.join(directory, "*" + suffix))):
        volume, _ = nrrd.read(filename)
        # Explicit check in place of the former unused three-way shape unpack.
        if len(volume.shape) != 3:
            raise ValueError(
                "expected a 3D mask in {}, got shape {}".format(filename, volume.shape)
            )
        # Strip only the trailing suffix; str.replace would also remove the
        # same text if it occurred earlier in the file name.
        patient = os.path.basename(filename)[:-len(suffix)]
        masks[patient] = volume
    return masks
if __name__ == "__main__":
    # Raise the interpreter's recursion limit before any boxes are generated.
    sys.setrecursionlimit(8000)

    masks = read_masks("/data/ct-cspine/processed-studies/data_20180524_161757/segmentations/new_raw")
    total = len(masks)

    # Patient id -> {slice index -> list of boxes}, as returned by generate_boxes_2d.
    patient_to_bboxes = {}
    for done, (patient, volume) in enumerate(masks.items(), start=1):
        patient_to_bboxes[patient] = generate_boxes_2d(volume)
        # The trailing "\r" makes the progress line overwrite itself on a terminal.
        sys.stdout.write("Generated bounding boxes for {} of {} series.\r".format(done, total))
        sys.stdout.flush()

    # Persist the whole mapping once every series has been processed.
    with open("/data/ct-cspine/processed-studies/data_20180524_161757/new_bounding_boxes.pickle", "wb") as handle:
        pkl.dump(patient_to_bboxes, handle, protocol=pkl.HIGHEST_PROTOCOL)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment