Skip to content

Instantly share code, notes, and snippets.

@RyanZurrin
Created May 9, 2024 19:10
Show Gist options
  • Save RyanZurrin/a74f31d79ebcda838d7427a48ed84382 to your computer and use it in GitHub Desktop.
Save RyanZurrin/a74f31d79ebcda838d7427a48ed84382 to your computer and use it in GitHub Desktop.
dice coeff
import nibabel as nib
import numpy as np
import csv
import sys
from tqdm import tqdm
def dice_coefficient(mask1, mask2, empty_score=1.0):
    """
    Compute the Dice coefficient, a measure of set similarity.

    Parameters
    ----------
    mask1, mask2 : array-like
        Arrays of the same shape. Every nonzero element is treated as a
        member of the set (True); zero elements are non-members.
    empty_score : float, optional
        Value returned when both masks are entirely empty, where the Dice
        formula would be 0/0. Defaults to 1.0, i.e. two empty masks are
        considered to agree perfectly.

    Returns
    -------
    dice : float
        Dice coefficient as a float on range [0,1].
        Maximum similarity = 1
        No similarity = 0
        ``empty_score`` is returned when both masks are empty.

    Raises
    ------
    ValueError
        If ``mask1`` and ``mask2`` do not have the same shape.
    """
    # Coerce to boolean arrays so lists are accepted and non-bool inputs
    # (e.g. float label volumes) are counted as set membership, not summed
    # by value.
    mask1 = np.asarray(mask1, dtype=bool)
    mask2 = np.asarray(mask2, dtype=bool)
    # Reject mismatched shapes explicitly; np.logical_and would otherwise
    # broadcast them and silently produce a meaningless score.
    if mask1.shape != mask2.shape:
        raise ValueError(
            "Shape mismatch: mask1 has shape {} but mask2 has shape {}".format(
                mask1.shape, mask2.shape
            )
        )
    total = mask1.sum() + mask2.sum()
    if total == 0:
        # Both masks empty: the formula is 0/0, so return the caller's choice.
        return empty_score
    intersection = np.logical_and(mask1, mask2).sum()
    return float(2.0 * intersection / total)
def process_pair(pair):
    """Load two mask images and score their overlap with the Dice coefficient.

    Parameters
    ----------
    pair : sequence of two paths
        ``(mask1_path, mask2_path)``; each path is passed to
        ``nibabel.load`` (presumably NIfTI masks — confirm with callers).

    Returns
    -------
    tuple
        ``(mask1_path, mask2_path, dice)``, where ``dice`` is the result of
        ``dice_coefficient`` on the two binarized masks.
    """
    first_path, second_path = pair

    def _binarize(path):
        # Load the image's data as floats, then treat every nonzero voxel
        # as foreground.
        return nib.load(path).get_fdata().astype(bool)

    first_mask = _binarize(first_path)
    second_mask = _binarize(second_path)
    score = dice_coefficient(first_mask, second_mask)
    return first_path, second_path, score
def main(argv=None):
    """Print the Dice coefficient for every mask pair listed in a CSV file.

    Each non-blank CSV row must contain exactly two fields: the path to the
    first mask and the path to the second mask. Surrounding whitespace in a
    field is ignored. Blank rows are skipped silently; rows with any other
    number of fields are reported on stderr and skipped.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments excluding the program name. Defaults to
        ``sys.argv[1:]``.

    Returns
    -------
    int
        Exit status: 0 on success, 1 on a usage error.
    """
    if argv is None:
        argv = sys.argv[1:]
    if not argv:
        print("Usage: python dice_coeff.py <csv_file_with_paths>", file=sys.stderr)
        return 1
    csv_file_path = argv[0]
    # The csv module requires newline='' so quoted fields are parsed correctly.
    with open(csv_file_path, 'r', newline='') as file:
        reader = csv.reader(file)
        for row in tqdm(reader, desc="Processing pairs"):
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not row or all(not field.strip() for field in row):
                continue
            if len(row) != 2:
                tqdm.write(
                    "Skipping line {}: expected 2 paths, got {} field(s)".format(
                        reader.line_num, len(row)
                    ),
                    file=sys.stderr,
                )
                continue
            mask1_path, mask2_path = (field.strip() for field in row)
            _, _, dice_score = process_pair((mask1_path, mask2_path))
            # tqdm.write prints above the progress bar instead of garbling it.
            tqdm.write(
                "Dice coefficient between\n{}\nand\n{}:\n{:.4f}\n".format(
                    mask1_path, mask2_path, dice_score
                )
            )
    return 0


if __name__ == "__main__":
    sys.exit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment