Created May 1, 2019 17:21
-
-
Save sshleifer/77314de11dba384d8a55acb10254c323 to your computer and use it in GitHub Desktop.
Script for converting segmentation masks to bounding boxes (bbox branch)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
import pickle as pkl | |
import nrrd | |
import glob | |
import os | |
import sys | |
def find_bounding_box(mask, point, label):
    """Flood-fill the component that contains ``point`` and return its box.

    Args:
        mask: label array indexed with tuples the same length as ``point``.
        point: index tuple where the fill starts.
        label: value a neighbouring element must equal to join the component.

    Returns:
        ``((min_point, size), visited)``. ``min_point`` is a list with the
        smallest coordinate reached on each axis. ``size`` is a tuple of
        ``max - min`` per axis. ``visited`` is the set of index tuples that
        belong to the component.

    NOTE(review): ``size`` is ``max - min`` rather than ``max - min + 1``, so
    a single-element component has size 0 on every axis. Confirm that this is
    what downstream consumers expect.
    """
    seen = set()
    lower = list(point)
    upper = list(point)
    # The helper grows ``seen`` and updates ``lower``/``upper`` in place.
    find_bounding_box_helper(mask, point, label, seen, lower, upper)
    extent = tuple(hi - lo for lo, hi in zip(lower, upper))
    return (lower, extent), seen
def find_bounding_box_helper(mask, point, label, visited, min_point, max_point):
    """Collect the component at ``point`` into ``visited`` and widen the bounds.

    Starting from ``point``, visit every index reachable through axis-aligned
    neighbours (one step of +1 or -1 along a single axis) whose value equals
    ``label``. Each visited index is added to ``visited`` and folded into
    ``min_point`` / ``max_point``. Those two lists are updated in place.
    Indices already in ``visited`` are skipped.

    Two changes relative to the earlier recursive version:

    * An explicit stack is used instead of recursion, so components larger
      than Python's recursion limit no longer fail.
    * Neighbours outside ``mask`` are never stepped to. Previously a step to
      index -1 silently wrapped to the opposite edge (NumPy negative
      indexing), and a step past the last index raised IndexError.

    Args:
        mask: label array indexed with tuples of length ``len(point)``.
        point: valid starting index tuple into ``mask``.
        label: value an element must equal to belong to the component.
        visited: set of index tuples; extended in place.
        min_point: per-axis lower bounds (list); updated in place.
        max_point: per-axis upper bounds (list); updated in place.

    Returns:
        None. Results are written into ``visited``, ``min_point`` and
        ``max_point``.
    """
    shape = mask.shape
    ndim = len(point)
    stack = [tuple(point)]
    while stack:
        current = stack.pop()
        # Same acceptance test as before: unvisited and carrying ``label``.
        if current in visited or mask[current] != label:
            continue
        visited.add(current)
        for axis in range(ndim):
            coord = current[axis]
            min_point[axis] = min(min_point[axis], coord)
            max_point[axis] = max(max_point[axis], coord)
            for step in (1, -1):
                neighbour = coord + step
                # Only push in-bounds neighbours. Negative indices would wrap
                # around, and indices == shape[axis] would raise IndexError.
                if 0 <= neighbour < shape[axis]:
                    stack.append(
                        current[:axis] + (neighbour,) + current[axis + 1:]
                    )
def generate_boxes(mask):
    """Return one bounding box per labelled connected component of ``mask``.

    ``mask`` may be the full 3D CT volume or a single 2D slice. Each box has
    the same dimensionality as ``mask``. Elements with a value greater than 0
    are foreground. The mask is scanned once in ``np.nditer`` order. Each
    foreground element that no earlier component has claimed seeds a flood
    fill via ``find_bounding_box``.

    Returns:
        List of ``(min_point, size)`` pairs, in the order their components
        were first encountered during the scan.
    """
    boxes = []
    claimed = set()
    cursor = np.nditer(mask, flags=['multi_index'])
    for value in cursor:
        index = cursor.multi_index
        if value > 0 and index not in claimed:
            box, component = find_bounding_box(mask, index, value)
            boxes.append(box)
            # Mark the whole component so its other pixels don't seed again.
            claimed.update(component)
    return boxes
def transform(point, height, width, target_size=512):
    """Rescale a ``(y, x)`` pair from a ``height`` x ``width`` grid to a square grid.

    Used for both box corners and box sizes. Each component is multiplied by
    ``target_size`` and divided by the extent of its axis.

    Args:
        point: two-element sequence unpacked as ``(y, x)``.
        height: source extent along the y axis.
        width: source extent along the x axis.
        target_size: side length of the square output grid. Defaults to 512,
            the value previously hard-coded here.

    Returns:
        ``[scaled_y, scaled_x]`` as a list.
    """
    y, x = point
    return [target_size * y / height, target_size * x / width]
def generate_boxes_2d(mask):
    """Compute per-slice 2D bounding boxes for a 3D volume.

    The shape of ``mask`` is unpacked as ``(depth, height, width)``, and each
    slice ``mask[i, :, :]`` along the first axis is processed with
    ``generate_boxes``. Both the corner and the size of every box are then
    rescaled with ``transform``.

    Returns:
        Dict mapping slice index to a list of ``(corner, size)`` pairs.
        Slices that contain no boxes are omitted.
    """
    depth, height, width = mask.shape
    per_slice = {}
    for slice_index in range(depth):
        boxes = generate_boxes(mask[slice_index, :, :])
        if not boxes:
            continue
        per_slice[slice_index] = [
            (transform(corner, height, width), transform(extent, height, width))
            for corner, extent in boxes
        ]
    return per_slice
def read_masks(directory):
    """Load every ``*_seg.nrrd`` file in ``directory``.

    Returns:
        Dict mapping patient id to the first value returned by
        ``nrrd.read``. The patient id is the file's basename with the
        trailing ``_seg.nrrd`` removed. The second return value of
        ``nrrd.read`` (presumably the header) is discarded.
    """
    suffix = "_seg.nrrd"
    masks = {}
    # Sort the filenames so that load order, and therefore progress output and
    # pickle order, does not depend on the filesystem's listing order.
    for filename in sorted(glob.glob(os.path.join(directory, "*" + suffix))):
        # The glob guarantees the suffix is present. Slice it off rather than
        # using str.replace, which would also strip earlier occurrences.
        patient = os.path.basename(filename)[:-len(suffix)]
        masks[patient], _ = nrrd.read(filename)
    return masks
if __name__ == "__main__":
    # Extra headroom for any deep (recursive) flood fill in
    # find_bounding_box_helper.
    sys.setrecursionlimit(8000)

    # Optional CLI overrides: argv[1] = segmentation directory,
    # argv[2] = output pickle path. With no arguments, the original
    # hard-coded locations are used.
    seg_dir = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/data/ct-cspine/processed-studies/data_20180524_161757/segmentations/new_raw"
    )
    out_path = (
        sys.argv[2]
        if len(sys.argv) > 2
        else "/data/ct-cspine/processed-studies/data_20180524_161757/new_bounding_boxes.pickle"
    )

    patient_to_bboxes = {}
    masks = read_masks(seg_dir)
    for i, (patient, mask) in enumerate(masks.items()):
        patient_to_bboxes[patient] = generate_boxes_2d(mask)
        sys.stdout.write("Generated bounding boxes for {} of {} series.\r".format(i + 1, len(masks)))
        sys.stdout.flush()
    if masks:
        # Terminate the \r-overwritten progress line so later output starts
        # on a fresh line.
        sys.stdout.write("\n")

    with open(out_path, "wb") as handle:
        pkl.dump(patient_to_bboxes, handle, protocol=pkl.HIGHEST_PROTOCOL)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment