    Loss function of MiDaS
import torch
import torch.nn as nn


def compute_scale_and_shift(prediction, target, mask):
    # system matrix: A = [[a_00, a_01], [a_10, a_11]]
    a_00 = torch.sum(mask * prediction * prediction, (1, 2))
    a_01 = torch.sum(mask * prediction, (1, 2))
    a_11 = torch.sum(mask, (1, 2))

    # right hand side: b = [b_0, b_1]
    b_0 = torch.sum(mask * prediction * target, (1, 2))
    b_1 = torch.sum(mask * target, (1, 2))

    # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b
    x_0 = torch.zeros_like(b_0)
    x_1 = torch.zeros_like(b_1)

    det = a_00 * a_11 - a_01 * a_01
    valid = det.nonzero()

    x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid]
    x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]

    return x_0, x_1


def reduction_batch_based(image_loss, M):
    # average of all valid pixels of the batch
    # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0)
    divisor = torch.sum(M)

    if divisor == 0:
        return 0
    else:
        return torch.sum(image_loss) / divisor


def reduction_image_based(image_loss, M):
    # mean of average of valid pixels of an image
    # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0)
    valid = M.nonzero()

    image_loss[valid] = image_loss[valid] / M[valid]

    return torch.mean(image_loss)


def mse_loss(prediction, target, mask, reduction=reduction_batch_based):
    M = torch.sum(mask, (1, 2))
    res = prediction - target
    image_loss = torch.sum(mask * res * res, (1, 2))

    return reduction(image_loss, 2 * M)


def gradient_loss(prediction, target, mask, reduction=reduction_batch_based):
    M = torch.sum(mask, (1, 2))

    diff = prediction - target
    diff = torch.mul(mask, diff)

    grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
    mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
    grad_x = torch.mul(mask_x, grad_x)

    grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
    mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
    grad_y = torch.mul(mask_y, grad_y)

    image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))

    return reduction(image_loss, M)


class MSELoss(nn.Module):
    def __init__(self, reduction='batch-based'):
        super().__init__()

        if reduction == 'batch-based':
            self.__reduction = reduction_batch_based
        else:
            self.__reduction = reduction_image_based

    def forward(self, prediction, target, mask):
        return mse_loss(prediction, target, mask, reduction=self.__reduction)


class GradientLoss(nn.Module):
    def __init__(self, scales=4, reduction='batch-based'):
        super().__init__()

        if reduction == 'batch-based':
            self.__reduction = reduction_batch_based
        else:
            self.__reduction = reduction_image_based

        self.__scales = scales

    def forward(self, prediction, target, mask):
        total = 0

        for scale in range(self.__scales):
            step = pow(2, scale)

            total += gradient_loss(prediction[:, ::step, ::step], target[:, ::step, ::step],
                                   mask[:, ::step, ::step], reduction=self.__reduction)

        return total


class ScaleAndShiftInvariantLoss(nn.Module):
    def __init__(self, alpha=0.5, scales=4, reduction='batch-based'):
        super().__init__()

        self.__data_loss = MSELoss(reduction=reduction)
        self.__regularization_loss = GradientLoss(scales=scales, reduction=reduction)
        self.__alpha = alpha

        self.__prediction_ssi = None

    def forward(self, prediction, target, mask):
        scale, shift = compute_scale_and_shift(prediction, target, mask)
        self.__prediction_ssi = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1)

        total = self.__data_loss(self.__prediction_ssi, target, mask)
        if self.__alpha > 0:
            total += self.__alpha * self.__regularization_loss(self.__prediction_ssi, target, mask)

        return total

    def __get_prediction_ssi(self):
        return self.__prediction_ssi

    prediction_ssi = property(__get_prediction_ssi)
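
For reference, a minimal usage sketch (not part of the gist; the shapes and the all-ones mask are illustrative). It assumes prediction, target, and mask are (B, H, W) tensors, with mask equal to 1 where ground truth is valid, and relies on the definitions above:

loss_fn = ScaleAndShiftInvariantLoss(alpha=0.5, scales=4)

prediction = torch.rand(2, 384, 384)   # network output (predicted inverse depth)
target = torch.rand(2, 384, 384)       # ground-truth disparity
mask = torch.ones(2, 384, 384)         # 1 = valid ground truth, 0 = invalid

loss = loss_fn(prediction, target, mask)
aligned = loss_fn.prediction_ssi       # prediction after scale/shift alignment
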
Isn't it possible for the scale that comes back from compute_scale_and_shift to be negative?
Hi! What do I have to do with the target depths (maybe some preprocessing of the data) before calculating the ssi_loss on the MiDaS output?
Hello @dvdhfnr, thank you for the great work.
This is the L_ssimse loss from the paper. In Section 5 they say that they use L_ssitrim for all experiments, which, I assume, includes the released models.
For those who want the trim loss, here is my version (unmasked):
def ssi_trim_loss(residual):
    b, _, h, w = residual.shape
    m = h * w
    u_m = int(0.8 * m)
    abs_residual = torch.abs(residual)
    flat_abs_residual = abs_residual.view(b, -1)
    # Indices that sort the absolute residuals in ascending order
    _, sorted_idx = torch.sort(flat_abs_residual.detach(), dim=1)
    # Keep the smallest 80% of residuals, i.e. trim the largest 20%
    top_80_idx = sorted_idx[:, :u_m]
    top_80_abs_residual = torch.gather(flat_abs_residual, 1, top_80_idx)
    # Sum the kept residuals
    sum_top_80_abs_residual = torch.sum(top_80_abs_residual, dim=1)
    # Divide by 2M (M = number of pixels per image), as in the paper
    loss_per_batch = sum_top_80_abs_residual / (2 * m)
    # Average over the batch
    loss = torch.mean(loss_per_batch)
    return loss
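
A possible way to wire this up with the gist above (a sketch of my own, not part of the comment; it assumes (B, H, W) prediction/target/mask tensors and reuses compute_scale_and_shift for the alignment, while the trimming itself stays unmasked like the version above):

# Align with the least-squares scale/shift, then trim the residuals.
scale, shift = compute_scale_and_shift(prediction, target, mask)
prediction_ssi = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1)
residual = (prediction_ssi - target).unsqueeze(1)   # (B, 1, H, W), as ssi_trim_loss expects
trim_loss = ssi_trim_loss(residual)
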
  
Thanks for sharing.
One question: is the target in disparity space or depth space? Thanks!