Created
July 29, 2019 17:26
-
-
Save donovanr/ad5373101b624934b03b622f6eb07e1c to your computer and use it in GitHub Desktop.
PyTorch snippet: binary cross-entropy corrected with the continuous-Bernoulli log normalizing constant
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
``` | |
def logC_taylor(x, eps=1e-7, taylor_center=0.5, taylor_radius=0.05):
    """Elementwise log normalizing constant log C(x) of the continuous Bernoulli.

    Computes ``log C(x) = log(2 * atanh(1 - 2x) / (1 - 2x))`` via the
    equivalent closed form ``log(log(x / (1 - x)) / (2x - 1))`` (using
    atanh(z) = 1/2 * log((1 + z) / (1 - z))), and a 4th-order Taylor
    expansion around ``taylor_center`` where the closed form is
    numerically singular (x -> 1/2). The 4th-order series is accurate to
    float precision on [0.45, 0.55].

    Args:
        x: tensor of probabilities. Values are clamped into
            ``[eps, 1 - eps]`` internally (the formula is singular at
            0 and 1). The input tensor is NOT modified in place.
        eps: regularization distance from the singular endpoints.
        taylor_center: expansion point of the Taylor series (1/2).
        taylor_radius: half-width of the interval where the Taylor
            approximation replaces the closed form.

    Returns:
        Tensor of the same shape as ``x`` holding log C(x) elementwise.
    """
    eps = torch.as_tensor(eps).type_as(x)
    taylor_center = torch.as_tensor(taylor_center).type_as(x)
    taylor_radius = torch.as_tensor(taylor_radius).type_as(x)

    # Fix 1: the original mutated the caller's tensor in place
    # (x[mask] = ...). Clamp a copy away from the singularities at 0 and
    # 1 instead; this also covers values slightly outside [0, 1], not
    # just exact equality with the endpoints.
    x = x.clamp(eps, 1.0 - eps)

    # 4th-order Taylor expansion of log C around y_0 = 1/2:
    #   log C = log 2 + (4/3)(y - y_0)^2 + (104/45)(y - y_0)^4 + O((y - y_0)^6)
    def taylor(y, y_0):
        c_0 = torch.log(torch.tensor(2.0).type_as(y))
        c_2 = 4.0 / 3.0
        c_4 = 104.0 / 45.0
        diff2 = (y - y_0) ** 2
        return c_0 + c_2 * diff2 + c_4 * diff2 ** 2

    near_half = torch.abs(x - taylor_center) < taylor_radius

    # Fix 2: feeding x ~ 1/2 through log(.../(2x - 1)) yields NaN/Inf
    # whose gradient still leaks through a masked assignment
    # (0 * nan = nan in autograd). Replace near-singular entries with a
    # harmless dummy value before the closed-form evaluation, then
    # select the Taylor branch for them with torch.where.
    x_safe = torch.where(near_half, torch.full_like(x, 0.25), x)
    logC = torch.log(torch.log(x_safe / (1.0 - x_safe)) / (2.0 * x_safe - 1.0))
    return torch.where(near_half, taylor(x, taylor_center), logC)
class ContinuousBCELoss(nn.BCELoss):
    r"""Binary cross-entropy corrected for the continuous Bernoulli.

    Implements the negative log-likelihood of the continuous Bernoulli
    distribution from "The continuous Bernoulli: fixing a pervasive error
    in variational autoencoders" (https://arxiv.org/abs/1907.06845v1):
    the standard BCE term minus the log normalizing constant
    ``logC_taylor(input)``, reduced consistently with the BCE term.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, input, target):
        # Standard (already reduced) binary cross-entropy term.
        bce = super().forward(input, target)
        # Log normalizing constant of the continuous Bernoulli; apply the
        # same reduction as the parent loss so both terms are commensurate.
        log_const = logC_taylor(input)
        if self.reduction == "sum":
            log_const = torch.sum(log_const)
        elif self.reduction == "mean":
            log_const = torch.mean(log_const)
        # reduction == "none": keep elementwise values.
        return bce - log_const
``` |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment