Last active
April 7, 2025 13:31
-
-
Save ranftlr/1d6194db2e1dffa0a50c9b0a9549cbd2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def trimmed_mae_loss(prediction, target, mask, trim=0.2):
    """Trimmed mean-absolute-error over valid (masked) pixels.

    The largest ``trim`` fraction of absolute residuals is discarded before
    summing, making the loss robust to outliers.

    Args:
        prediction: predicted values, shape (n, h, w).
        target: ground-truth values, same shape as ``prediction``.
        mask: validity mask, same shape; nonzero marks a valid pixel.
        trim: fraction of the largest residuals to discard (default 0.2).

    Returns:
        Scalar tensor: sum of the kept residuals divided by
        2 * (total number of valid pixels in the batch).
    """
    M = torch.sum(mask, (1, 2))
    res = prediction - target
    res = res[mask.bool()].abs()

    # BUG FIX: the original sliced the (values, indices) namedtuple returned
    # by torch.sort instead of slicing the sorted values, so `trimmed` ended
    # up being ALL sorted residuals and no trimming happened. Sort first,
    # then keep only the smallest (1 - trim) share of residuals.
    sorted_res, _ = torch.sort(res.view(-1), descending=False)
    trimmed = sorted_res[: int(len(res) * (1.0 - trim))]

    return trimmed.sum() / (2 * M.sum())
def reduction_batch_based(image_loss, M):
    """Average per-image losses over every valid pixel in the batch.

    ``M`` holds the valid-pixel count for each image. When the whole batch
    has no valid pixels the loss is defined as 0 — in that case
    sum(image_loss) is 0 as well, so this only avoids a division by zero.
    """
    total_valid = torch.sum(M)
    if total_valid == 0:
        return 0
    return torch.sum(image_loss) / total_valid
def normalize_prediction_robust(target, mask):
    """Normalize each image to zero median and unit mean absolute deviation.

    Statistics are computed per image; images with no valid pixels keep
    shift 0 and scale 1 so they pass through unchanged.

    Args:
        target: tensor of shape (n, h, w).
        mask: validity mask of the same shape (nonzero = valid).

    Returns:
        Normalized tensor, same shape as ``target``.
    """
    valid_counts = torch.sum(mask, (1, 2))
    has_valid = valid_counts > 0

    shift = torch.zeros_like(valid_counts)
    scale = torch.ones_like(valid_counts)

    # Per-image median of the masked values (masked-out pixels contribute 0,
    # matching the original behavior).
    masked_flat = (mask[has_valid] * target[has_valid]).view(has_valid.sum(), -1)
    shift[has_valid] = torch.median(masked_flat, dim=1).values

    target = target - shift.view(-1, 1, 1)

    # Mean absolute deviation over valid pixels, clamped away from zero.
    abs_dev = torch.sum(mask * target.abs(), (1, 2))
    scale[has_valid] = torch.clamp(abs_dev[has_valid] / valid_counts[has_valid], min=1e-6)

    return target / scale.view(-1, 1, 1)
class TrimmedProcrustesLoss(nn.Module):
    """Scale-and-shift-invariant trimmed-MAE loss with multi-scale gradient
    regularization (MiDaS-style).

    NOTE(review): ``TrimmedMAELoss`` is not defined anywhere in this file —
    presumably it is an ``nn.Module`` wrapper around ``trimmed_mae_loss``;
    confirm against the original source before using this class.
    """

    def __init__(self, alpha=0.5, scales=4, reduction="batch-based"):
        # alpha: weight of the gradient-regularization term (0 disables it).
        # scales: number of resolutions used by GradientLoss.
        # reduction: passed through to both sub-losses.
        super(TrimmedProcrustesLoss, self).__init__()

        self.__data_loss = TrimmedMAELoss(reduction=reduction)
        self.__regularization_loss = GradientLoss(scales=scales, reduction=reduction)
        self.__alpha = alpha

        # Holds the most recent normalized prediction (see property below).
        self.__prediction_ssi = None

    def forward(self, prediction, target, mask):
        # Normalize both tensors to zero median / unit spread over valid
        # pixels so the loss is invariant to global scale and shift.
        self.__prediction_ssi = normalize_prediction_robust(prediction, mask)
        target_ = normalize_prediction_robust(target, mask)

        total = self.__data_loss(self.__prediction_ssi, target_, mask)
        if self.__alpha > 0:
            total += self.__alpha * self.__regularization_loss(
                self.__prediction_ssi, target_, mask
            )

        return total

    def __get_prediction_ssi(self):
        # Read-only accessor backing the `prediction_ssi` property.
        return self.__prediction_ssi

    prediction_ssi = property(__get_prediction_ssi)
class GradientLoss(nn.Module):
    """Multi-scale gradient matching loss.

    Evaluates ``gradient_loss`` on progressively strided (downsampled)
    versions of the inputs and sums the results over all scales.
    """

    def __init__(self, scales=4, reduction='batch-based'):
        super().__init__()

        # Choose how per-image losses are reduced across the batch.
        if reduction == 'batch-based':
            self.__reduction = reduction_batch_based
        else:
            self.__reduction = reduction_image_based

        self.__scales = scales

    def forward(self, prediction, target, mask):
        # Stride doubles each scale (1, 2, 4, ...), halving the resolution.
        total = 0
        step = 1
        for _ in range(self.__scales):
            total = total + gradient_loss(
                prediction[:, ::step, ::step],
                target[:, ::step, ::step],
                mask[:, ::step, ::step],
                reduction=self.__reduction,
            )
            step *= 2
        return total
I don't see TrimmedMAELoss declared anywhere, yet it is used on line 45. Is it just a wrapper around trimmed_mae_loss?
Hello, I want to know the meaning of "M = torch.sum(mask, (1, 2))". Is "mask" a tensor of shape [n, c, h, w] with c equal to 1, and does this code reduce the tensor to shape [n, h, w]?
Hi @wch1996 , I think the mask and target should be of shape [n, h, w] instead of [n, 1, h, w]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
As far as my understanding goes, a mask is anything used to specify valid depth (or disparity) values. For example, the mask could be the locations where the depth values aren't 0, i.e.
mask = (depth != 0)
I hope this helps!