Skip to content

Instantly share code, notes, and snippets.

@sarthakpati
Created May 15, 2023 21:13
Show Gist options
  • Save sarthakpati/bbef55285960fb523bfc48729b5a3777 to your computer and use it in GitHub Desktop.
Save sarthakpati/bbef55285960fb523bfc48729b5a3777 to your computer and use it in GitHub Desktop.
Easy way to perform one-hot encoding of a medical annotation
import os
from typing import Union, OrderedDict

import numpy as np
import SimpleITK as sitk
def one_hot_encode(
    input_mask: Union[str, os.PathLike, sitk.Image],
    encoding_logic: Union[dict, None] = None,
) -> dict:
    """
    One-hot encode a label mask into one binary mask per region.

    Each entry of ``encoding_logic`` maps an output key to the label value(s)
    that make up that region. A voxel is 1 in a region's mask if its label is
    any of that region's values, and 0 otherwise. Regions may overlap: the
    same label can be listed under several keys.

    Args:
        input_mask (Union[str, os.PathLike, sitk.Image]): The input mask,
            either an already-loaded image or a path (``str`` or path-like)
            that ``sitk.ReadImage`` can read.
        encoding_logic (dict, optional): Mapping of region name to a single
            label or an iterable of labels. Defaults to the BraTS region
            definition of {"NET": [1], "TC": [1, 4], "WT": [1, 2, 4]}.

    Returns:
        dict: The output masks (``sitk.Image``), with the same keys and in the
        same order as ``encoding_logic``. Each mask is built from an array
        with the same dtype as the input's array, and the input's
        meta-information is copied onto it with ``CopyInformation``.
    """
    if encoding_logic is None:
        # Built per call so the default mapping is never shared between calls.
        encoding_logic = {"NET": [1], "TC": [1, 4], "WT": [1, 2, 4]}

    # Accept both str and path-like inputs. os.fspath hands ReadImage a plain
    # path string regardless of which form the caller used.
    if isinstance(input_mask, (str, os.PathLike)):
        input_mask = sitk.ReadImage(os.fspath(input_mask))

    input_mask_array = sitk.GetArrayFromImage(input_mask)

    output_masks = {}
    for key, labels in encoding_logic.items():
        # Normalize to a list: a scalar label becomes a one-element list, and
        # sets are converted because np.isin does not treat them as sequences.
        label_list = [labels] if np.isscalar(labels) else list(labels)
        # A single membership test yields a strict 0/1 mask. Summing one
        # indicator per label would count duplicated labels twice.
        region = np.isin(input_mask_array, label_list).astype(input_mask_array.dtype)
        output_masks[key] = sitk.GetImageFromArray(region)
        output_masks[key].CopyInformation(input_mask)
    return output_masks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment