Last active
February 9, 2018 15:21
-
-
Save edraizen/d29b2a3f46e40e0de945ba098604b942 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from itertools import groupby | |
from torch.nn.modules.loss import _Loss | |
class DiceLoss(_Loss):
    """Negative smoothed Dice coefficient, usable as a loss to minimise.

    For a prediction ``p`` and a target ``t`` the smoothed Dice coefficient is
    ``(2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``.

    When ``size_average`` is true the coefficient is computed separately for
    each sample of the batch and the per-sample values are averaged; otherwise
    one coefficient is computed over every element of the batch at once.
    """

    def __init__(self, size_average=True, smooth=1.):
        """Store the averaging mode and the smoothing constant.

        Args:
            size_average: if true, ``forward`` averages per-sample Dice
                coefficients; if false it computes one batch-wide coefficient.
            smooth: constant added to numerator and denominator so the ratio
                stays defined when both prediction and target sum to zero.
        """
        super(DiceLoss, self).__init__(size_average)
        # Recent torch releases translate ``size_average`` into ``reduction``
        # and no longer keep it on the module, but ``forward`` branches on it.
        self.size_average = size_average
        self.smooth = smooth

    def forward(self, input, target, locations, weights=None):
        """Return the negated Dice coefficient of ``input`` against ``target``.

        Args:
            input: predictions; in per-sample mode it is sliced by row.
            target: ground truth with the same layout as ``input``.
            locations: 2-D array/tensor with one row per row of ``input``;
                column 3 is read as the sample index of that row. Only used
                when ``size_average`` is true.
            weights: optional weighting, see ``dice_coef_samples`` and
                ``dice_coef_batch``.
        """
        if self.size_average:
            return -self.dice_coef_samples(input, target, locations, weights)
        return -self.dice_coef_batch(input, target, weights)

    def _dice(self, prediction, truth):
        """Return the smoothed Dice coefficient of two equally sized tensors."""
        # ``contiguous()`` keeps ``view`` valid for sliced/transposed inputs.
        pflat = prediction.contiguous().view(-1)
        tflat = truth.contiguous().view(-1)
        intersection = (pflat * tflat).sum()
        return (2. * intersection + self.smooth) / (
            pflat.sum() + tflat.sum() + self.smooth)

    def dice_coef_batch(self, input, target, weights=None):
        """Return one Dice coefficient over all elements of the batch.

        Args:
            input: predictions.
            target: ground truth, same number of elements as ``input``.
            weights: optional factor the coefficient is multiplied by.
        """
        dice = self._dice(input, target)
        if weights is not None:
            # Out-of-place: leaves the autograd graph of ``dice`` untouched
            # and lets a tensor-valued weight broadcast.
            dice = dice * weights
        return dice

    def dice_coef_samples(self, input, target, locations, weights=None):
        """Return the mean of the per-sample Dice coefficients.

        Rows belonging to the same sample must be contiguous: every run of
        equal values in ``locations[:, 3]`` is treated as one sample. The
        number of samples is taken to be the last sample index plus one, so
        indices are assumed to start at 0 and be ascending.

        Args:
            input: predictions, one row per row of ``locations``.
            target: ground truth with the same layout as ``input``.
            locations: 2-D array/tensor whose column 3 holds sample indices.
            weights: ``None``, a list/tuple with one weight per sample
                (indexed by sample index), or any other value, which is then
                multiplied onto the summed coefficient.

        Raises:
            ValueError: if ``locations`` has no rows, or a list/tuple of
                weights does not have exactly one entry per sample.
        """
        # Plain ints make grouping and weight lookup independent of whether
        # iterating the column yields Python numbers or 0-d tensors.
        sample_ids = [int(s) for s in locations[:, 3]]
        if not sample_ids:
            raise ValueError("locations must contain at least one row")
        num_samples = sample_ids[-1] + 1

        use_sample_weights = isinstance(weights, (list, tuple))
        if use_sample_weights and len(weights) != num_samples:
            raise ValueError(
                "expected {} per-sample weights, got {}".format(
                    num_samples, len(weights)))

        total = None
        start = 0
        for sample_id, rows in groupby(sample_ids):
            # ``stop`` is exclusive, so the next sample starts exactly there.
            stop = start + sum(1 for _ in rows)
            dice_val = self._dice(input[start:stop], target[start:stop])
            if use_sample_weights:
                dice_val = dice_val * weights[sample_id]
            total = dice_val if total is None else total + dice_val
            start = stop

        if weights is not None and not use_sample_weights:
            # A single weight scales the accumulated sum, not one sample.
            total = total * weights
        return total / float(num_samples)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment