@aiwithshekhar
Created December 13, 2019 19:35
calculate dice scores
import numpy as np


def dice_score(pred, targs):
    '''Calculate the dice score of one batch; called by the Scores class.

    pred > 0 binarizes the raw outputs (for logits this equals sigmoid(pred) > 0.5);
    targs is expected to be a 0/1 mask. Returns nan if both are empty.
    '''
    pred = (pred > 0).float()
    return 2. * (pred * targs).sum() / (pred + targs).sum()


class Scores:
    '''Initialize an empty list when Scores is created, append the dice score
    of every batch, and compute the mean of the dice scores at the end of the epoch.'''

    def __init__(self, phase, epoch):
        # phase and epoch are accepted but not used
        self.base_dice_scores = []

    def update(self, targets, outputs):
        dice = dice_score(outputs, targets)
        # .item() stores a plain float, so np.mean also works for CUDA tensors
        self.base_dice_scores.append(dice.item())

    def get_metrics(self):
        '''Return the mean dice score over all batches of the epoch.'''
        return np.mean(self.base_dice_scores)


def epoch_log(epoch_loss, measure):
    '''Log the loss and mean dice score at the end of an epoch and return the dice score.'''
    dice = measure.get_metrics()
    print("Loss: %0.4f | dice: %0.4f" % (epoch_loss, dice))
    return dice
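The per-batch score is the binary Dice coefficient, 2|P ∩ T| / (|P| + |T|), and Scores averages it over the batches of an epoch. The following is a minimal framework-free sketch of that workflow in plain Python, assuming flattened lists instead of torch tensors. The batch values are hypothetical, and this simplified Scores drops the unused phase/epoch arguments.

```python
def dice_score(logits, target):
    """Binary Dice for one batch of flattened values.

    Thresholds raw outputs at 0, then computes
    2 * |pred AND target| / (|pred| + |target|).
    """
    pred = [1.0 if x > 0 else 0.0 for x in logits]
    inter = sum(p * t for p, t in zip(pred, target))
    return 2.0 * inter / (sum(pred) + sum(target))


class Scores:
    """Collect one dice score per batch and average them at epoch end."""

    def __init__(self):
        self.base_dice_scores = []

    def update(self, targets, outputs):
        self.base_dice_scores.append(dice_score(outputs, targets))

    def get_metrics(self):
        return sum(self.base_dice_scores) / len(self.base_dice_scores)


# Two hypothetical batches of raw model outputs (logits) and 0/1 masks.
batches = [
    ([2.0, -1.0, 0.5, -3.0], [1, 0, 1, 0]),  # perfect overlap: dice = 1.0
    ([1.0, 1.0, -2.0, -2.0], [1, 0, 1, 0]),  # half overlap: dice = 0.5
]
meter = Scores()
for outputs, targets in batches:
    meter.update(targets, outputs)
print(meter.get_metrics())  # 0.75
```

Averaging per-batch scores this way weights every batch equally, regardless of how many foreground pixels each one contains.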
@akshatgarg99

In dice_score() you didn't square the elements in the denominator. The dice score becomes dimensionally wrong.
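Whether squaring matters depends on the values being summed. dice_score binarizes pred with pred > 0, and targs is assumed to be a 0/1 mask, so every element is 0 or 1 and x * x == x; the plain and squared denominators then agree. They differ only for soft (probability-valued) predictions, where the squared form appears in some soft Dice loss formulations. A small sketch in plain Python; the function names and values are hypothetical:

```python
def dice_plain(pred, targ):
    """Dice with a plain-sum denominator, as in dice_score."""
    inter = sum(p * t for p, t in zip(pred, targ))
    return 2.0 * inter / (sum(pred) + sum(targ))


def dice_squared(pred, targ):
    """Dice with squared elements in the denominator (soft-Dice style)."""
    inter = sum(p * t for p, t in zip(pred, targ))
    return 2.0 * inter / (sum(p * p for p in pred) + sum(t * t for t in targ))


targ = [1, 0, 1, 0]

# Binary prediction: every value is 0 or 1, so both forms agree.
pred_bin = [1, 1, 0, 0]
print(dice_plain(pred_bin, targ), dice_squared(pred_bin, targ))  # 0.5 0.5

# Soft prediction (probabilities): the two denominators differ.
pred_soft = [0.9, 0.6, 0.2, 0.1]
print(dice_plain(pred_soft, targ), dice_squared(pred_soft, targ))
```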