Skip to content

Instantly share code, notes, and snippets.

@mingrui
Last active October 25, 2022 13:46
Show Gist options
  • Save mingrui/5aa63ca498bbd615f932855c6a6dc724 to your computer and use it in GitHub Desktop.
Save mingrui/5aa63ca498bbd615f932855c6a6dc724 to your computer and use it in GitHub Desktop.
Calculate the Dice score of an MRI NIfTI segmentation.
import os

import nibabel as nib
import numpy as np
def test_nifti_all_labels_dice_score(seg_path, seg_file, truth_path, truth_file):
    """Print and return the mean whole-foreground Dice score over all cases.

    Every sub-directory of *truth_path* is treated as one case ID. For each case,
    ``<seg_path>/<uid>/<seg_file>`` is compared with
    ``<truth_path>/<uid>/<truth_file>`` via
    ``calculate_nifti_all_labels_dice_score``. Per-case scores and the final
    mean are printed.

    Args:
        seg_path: Root directory holding one sub-directory per case with the
            predicted segmentation.
        seg_file: File name of the prediction inside each case directory.
        truth_path: Root directory holding one sub-directory per case with the
            ground truth; its sub-directories define the set of cases.
        truth_file: File name of the ground truth inside each case directory.

    Returns:
        The mean of the per-case Dice scores.

    Raises:
        ValueError: If *truth_path* contains no case directories.
    """
    # os.listdir order is arbitrary; sort for reproducible output, and skip
    # stray files that would otherwise be treated as case directories.
    truth_uids = sorted(
        uid
        for uid in os.listdir(truth_path)
        if os.path.isdir(os.path.join(truth_path, uid))
    )
    print(truth_uids)
    if not truth_uids:
        raise ValueError(f"no case directories found in {truth_path!r}")

    dice_score = 0
    for uid in truth_uids:
        seg_file_path = os.path.join(seg_path, uid, seg_file)
        truth_file_path = os.path.join(truth_path, uid, truth_file)
        # img.get_data() is deprecated (nibabel 3.0) and removed (nibabel 5.0);
        # np.asanyarray(img.dataobj) is its documented drop-in replacement and
        # keeps the on-disk dtype.
        seg_data = np.asanyarray(nib.load(seg_file_path).dataobj)
        truth_data = np.asanyarray(nib.load(truth_file_path).dataobj)
        uid_dice = calculate_nifti_all_labels_dice_score(seg_data, truth_data)
        print(uid_dice)
        dice_score += uid_dice

    mean_dice = dice_score / len(truth_uids)
    print('dice score:', mean_dice)
    return mean_dice
def calculate_nifti_all_labels_dice_score(seg_data, truth_data):
    """Return the slice-averaged whole-foreground Dice score of two volumes.

    The volumes are cut into slices along their LAST axis. The Dice score of
    each slice is computed with ``calculate_slice_all_labels_dice_score``, and
    the unweighted mean over all slices is returned.

    Args:
        seg_data: Predicted segmentation array.
        truth_data: Ground-truth array with the same shape as *seg_data*.

    Returns:
        The mean per-slice Dice score.

    Raises:
        ValueError: If the shapes differ or there are no slices to average.
    """
    if seg_data.shape != truth_data.shape:
        raise ValueError(
            f"shape mismatch: segmentation {seg_data.shape} "
            f"vs truth {truth_data.shape}"
        )
    n_slices = seg_data.shape[-1]
    if n_slices == 0:
        raise ValueError("volumes contain no slices along the last axis")

    dice_sum = 0
    for z in range(n_slices):
        # Index the last axis with [..., z]: it matches the axis whose length
        # drives the loop. [:, :, z] is only correct for 3-D volumes.
        dice_sum += calculate_slice_all_labels_dice_score(
            seg_data[..., z], truth_data[..., z]
        )
    return dice_sum / n_slices
def calculate_slice_all_labels_dice_score(segmentation, truth):
    """Return the whole-foreground Dice score of two label arrays.

    Any value > 0 counts as foreground, whatever its label, so the score
    measures overlap of the union of all labels:
    ``2 * |S & T| / (|S| + |T|)``.

    Voxels are counted rather than their label values summed. Summing values
    (the previous approach) makes multi-label masks score above 1; for
    example, seg=[2] and truth=[1] gave 4/3. For 0/1 masks both approaches
    give the same result.

    When neither array has any foreground the score is 1.0. Two empty masks
    agree perfectly, which is the usual convention for this degenerate case.

    Args:
        segmentation: Predicted label array.
        truth: Ground-truth label array, same shape as *segmentation*.

    Returns:
        Dice score in [0, 1].
    """
    seg_fg = np.asarray(segmentation) > 0
    truth_fg = np.asarray(truth) > 0
    area_sum = np.count_nonzero(seg_fg) + np.count_nonzero(truth_fg)
    if area_sum == 0:
        return 1.0
    return 2.0 * np.count_nonzero(seg_fg & truth_fg) / area_sum
def calculate_slice_one_label_dice_score(segmentation, truth, k):
    """Return the Dice score of the single label *k*.

    Dice = ``2 * |S_k & T_k| / (|S_k| + |T_k|)``, where ``S_k`` / ``T_k`` are
    the voxels equal to *k* in *segmentation* / *truth*. Voxels of other
    labels are ignored entirely.

    When label *k* is absent from both arrays the score is 1.0, consistent
    with the both-empty convention of the all-labels Dice.

    Args:
        segmentation: Predicted label array.
        truth: Ground-truth label array, same shape as *segmentation*.
        k: Label value to score.

    Returns:
        Dice score in [0, 1].
    """
    seg_k = np.asarray(segmentation) == k
    truth_k = np.asarray(truth) == k
    area_sum = np.count_nonzero(seg_k) + np.count_nonzero(truth_k)
    if area_sum == 0:
        # Previously this case divided by zero.
        return 1.0
    return 2.0 * np.count_nonzero(seg_k & truth_k) / area_sum
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    # Predictions and ground truth live in the same per-case directories here.
    _DATA_DIR = '/mnt/DATA/datasets/tumor-segmentation-pipeline/all-t2/1-3d-unet-preprocessed-nii'
    test_nifti_all_labels_dice_score(_DATA_DIR,
                                     'prediction-mask.nii.gz',
                                     _DATA_DIR,
                                     'truth.nii.gz')
@dlabella29
Copy link

Should line 36 be 0 instead of 1?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment