Skip to content

Instantly share code, notes, and snippets.

@donovanr
Created July 29, 2019 17:26
Show Gist options
  • Save donovanr/ad5373101b624934b03b622f6eb07e1c to your computer and use it in GitHub Desktop.
Save donovanr/ad5373101b624934b03b622f6eb07e1c to your computer and use it in GitHub Desktop.
Continuous BCE (continuous Bernoulli) loss snippet
```
def logC_taylor(x, eps=1e-7, taylor_center=0.5, taylor_radius=0.05):
    """Return the elementwise log normalizing constant of the continuous Bernoulli.

    For ``lam`` in (0, 1) the normalizer is
    ``C(lam) = 2 * atanh(1 - 2*lam) / (1 - 2*lam)``.  Using
    ``2 * atanh(1 - 2*lam) == log((1 - lam) / lam)`` it is evaluated as
    ``log(log(lam / (1 - lam)) / (2*lam - 1))``, which needs only ``torch.log``.
    That expression is 0/0 at ``lam = 0.5`` and cancels badly near it, so
    inside ``|lam - taylor_center| < taylor_radius`` a fourth-order Taylor
    series is used instead.

    Args:
        x: Tensor of values in [0, 1] (presumably probabilities, e.g. sigmoid
            outputs -- confirm at call sites); values outside [0, 1] yield NaN.
            ``x`` is NOT modified.
        eps: Entries exactly equal to 0 or 1 (where the formula is singular)
            are moved inward by this amount.
        taylor_center: Center of the Taylor window.  The series coefficients
            below are those of the expansion about 0.5, so only 0.5 is
            meaningful here.
        taylor_radius: Half-width of the Taylor window.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    eps = torch.tensor(eps).type_as(x)
    taylor_center = torch.tensor(taylor_center).type_as(x)
    taylor_radius = torch.tensor(taylor_radius).type_as(x)

    # Nudge the singular endpoints inward with out-of-place ops.  Indexed
    # assignment (x[mask] = ...) would mutate the caller's tensor, which fails
    # for leaves that require grad and corrupts tensors saved for backward
    # (e.g. by BCE or sigmoid).
    x = torch.where(x == 0, x + eps, x)
    x = torch.where(x == 1, x - eps, x)

    near_center = torch.abs(x - taylor_center) < taylor_radius

    # Closed form away from the center.  Entries inside the Taylor window are
    # replaced by a harmless point (0.25) before evaluation.  Selecting with
    # torch.where alone is not enough: the unused branch would still
    # back-propagate 0 * NaN = NaN at x == 0.5.
    safe_x = torch.where(near_center, torch.full_like(x, 0.25), x)
    closed_form = torch.log(
        torch.log(safe_x / (1.0 - safe_x)) / (2.0 * safe_x - 1.0)
    )

    # Taylor series about 0.5 with d = x - 0.5:
    #   log C = log 2 + (4/3) d^2 + (104/45) d^4 + O(d^6),
    # accurate to float precision on [0.45, 0.55].
    d2 = (x - taylor_center) ** 2
    series = (
        torch.log(torch.tensor(2.0).type_as(x))
        + (4.0 / 3.0) * d2
        + (104.0 / 45.0) * d2 ** 2
    )

    return torch.where(near_center, series, closed_form)
class ContinuousBCELoss(nn.BCELoss):
    r"""Negative log-likelihood of the continuous Bernoulli distribution.

    Implements the loss from "The continuous Bernoulli: fixing a pervasive
    error in variational autoencoders", https://arxiv.org/abs/1907.06845v1.

    The per-element loss is ``BCE(input, target) - log C(input)``, where
    ``log C`` is the log normalizing constant returned by ``logC_taylor``.
    Keyword arguments are forwarded unchanged to :class:`torch.nn.BCELoss`.
    Its ``weight`` and ``reduction`` settings are applied to BOTH terms, so
    the result is the weighted and reduced per-element NLL.
    """

    def __init__(self, **kwargs):
        super(ContinuousBCELoss, self).__init__(**kwargs)

    def forward(self, input, target):
        # BCELoss already applies self.weight and self.reduction to its term,
        # and rejects invalid reduction values.
        bce = super(ContinuousBCELoss, self).forward(input, target)
        log_c = logC_taylor(input)
        # Weight the normalizer exactly like BCELoss weights its elementwise
        # loss, so that weight * (bce_i - log_c_i) is what gets reduced.
        if self.weight is not None:
            log_c = log_c * self.weight
        if self.reduction == "mean":
            log_c = torch.mean(log_c)
        elif self.reduction == "sum":
            log_c = torch.sum(log_c)
        # reduction == "none": keep the elementwise tensor.
        return bce - log_c
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment