Skip to content

Instantly share code, notes, and snippets.

@crowsonkb
Last active January 12, 2021 17:12
Show Gist options
  • Save crowsonkb/74781ccde9b89debd2188d2818c108f4 to your computer and use it in GitHub Desktop.
Save crowsonkb/74781ccde9b89debd2188d2818c108f4 to your computer and use it in GitHub Desktop.
from torch import nn
class Lambda(nn.Module):
    """Expose an arbitrary callable as an :class:`nn.Module`.

    The callable becomes this module's ``forward``. It is written straight into
    the instance ``__dict__`` instead of going through
    :meth:`nn.Module.__setattr__`, so even when ``func`` is itself an
    :class:`nn.Module` it is *not* registered as a child module.

    Args:
        func: Callable invoked with whatever arguments this module is called with.
    """

    def __init__(self, func):
        super().__init__()
        # Skip nn.Module.__setattr__ (which would register Module values as
        # submodules); the instance attribute shadows the class-level forward.
        vars(self)['forward'] = func

    def extra_repr(self):
        """Render the wrapped callable inside this module's repr."""
        wrapped = self.forward
        return repr(wrapped)
class WeightedLoss(nn.ModuleList):
    """A weighted combination of multiple loss functions.

    Each entry of ``losses`` is stored as a submodule: :class:`nn.Module`
    instances are kept as-is, and any other callable is wrapped in
    :class:`Lambda`. Calling this module evaluates every loss on the same
    arguments, multiplies each result by its weight, and returns the sum.

    Args:
        losses: Iterable of loss modules or plain callables.
        weights: Per-loss multipliers, one per loss, in the same order. An
            iterable without a length (e.g. a generator) is materialized into
            a list so it can be reused on every call.

    Raises:
        ValueError: If the number of weights differs from the number of losses.
    """

    def __init__(self, losses, weights):
        super().__init__(
            [loss if isinstance(loss, nn.Module) else Lambda(loss) for loss in losses]
        )
        # A one-shot iterable would be exhausted by the first forward(), making
        # every later call silently return 0; materialize anything without len().
        if not hasattr(weights, '__len__'):
            weights = list(weights)
        # zip() in forward() silently truncates, so a count mismatch would
        # quietly drop losses (or weights); reject it up front instead.
        if len(weights) != len(self):
            raise ValueError(
                f'expected {len(self)} weights (one per loss), got {len(weights)}'
            )
        self.weights = weights

    def forward(self, *args, **kwargs):
        """Return the weighted sum of every loss evaluated on ``*args, **kwargs``.

        With no losses this returns the integer ``0`` (the start value of
        :func:`sum`).
        """
        return sum(
            loss(*args, **kwargs) * weight for loss, weight in zip(self, self.weights)
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment