Created April 10, 2017 16:00
Gist dwf/b2e1d8d575cb9e7365f302c90d909893
Numerically stable binary cross-entropy that operates on logit predictions, instead of relying on Theano's optimizer to correctly stabilize the cross-entropy of a sigmoid output.
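The stable form follows from two identities relating the log-sigmoid to softplus. Writing $a$ for a logit and $t \in [0, 1]$ for its target (the same symbols the function uses), a short derivation, included here as a worked sketch:

```latex
\begin{align*}
\sigma(a) &= \frac{1}{1 + e^{-a}}, \qquad \operatorname{softplus}(x) = \log\left(1 + e^{x}\right) \\
\log \sigma(a) &= -\log\left(1 + e^{-a}\right) = -\operatorname{softplus}(-a) \\
\log\left(1 - \sigma(a)\right) &= \log \frac{e^{-a}}{1 + e^{-a}} = -\log\left(1 + e^{a}\right) = -\operatorname{softplus}(a) \\
\ell(a, t) &= -t \log \sigma(a) - (1 - t) \log\left(1 - \sigma(a)\right)
            = t \operatorname{softplus}(-a) + (1 - t) \operatorname{softplus}(a)
\end{align*}
```

Because $\sigma(a)$ is never computed explicitly, there is no intermediate value that can round to exactly 0 or 1 before the logarithm is taken.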
```python
from theano import tensor


def binary_crossentropy_from_logits(logits, targets):
    """Binary cross-entropy computed from model logits.

    Parameters
    ----------
    logits : TensorVariable
        The logits (log-odds) of a probabilistic binary classifier.
    targets : TensorVariable
        The targets for the classifier, in [0, 1].

    Returns
    -------
    TensorVariable
        The elementwise cross-entropy between Bernoulli(targets) and
        Bernoulli(sigmoid(logits)). For hard 0/1 targets this is the
        negative log probability of each target under the model.
    """
    a, t = logits, targets
    # -t * log(sigmoid(a))           == t * softplus(-a)
    # -(1 - t) * log(1 - sigmoid(a)) == (1 - t) * softplus(a)
    # so the sigmoid output is never materialized.
    return t * tensor.nnet.softplus(-a) + (1 - t) * tensor.nnet.softplus(a)
```
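To see the difference without a Theano install, here is a minimal pure-Python sketch of the same formula for scalar inputs, next to a naive "sigmoid, then log" version. The helper names `softplus`, `bce_from_logits`, and `bce_naive` are hypothetical, and the stable softplus used here, `max(x, 0) + log1p(exp(-|x|))`, is an assumed implementation choice; it is not necessarily how Theano computes `tensor.nnet.softplus` internally.

```python
import math


def softplus(x: float) -> float:
    """Stable log(1 + exp(x)).

    Uses softplus(x) = max(x, 0) + log1p(exp(-|x|)), so exp() only ever
    sees a non-positive argument and cannot overflow.
    """
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bce_from_logits(logit: float, target: float) -> float:
    """Binary cross-entropy from a logit, mirroring the gist's formula."""
    return target * softplus(-logit) + (1.0 - target) * softplus(logit)


def bce_naive(logit: float, target: float) -> float:
    """Naive version: squash with a sigmoid first, then take logs."""
    p = 1.0 / (1.0 + math.exp(-logit))  # rounds to exactly 1.0 for large logits
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


# A confidently wrong prediction: logit 40 for a target of 0.
print(bce_from_logits(40.0, 0.0))  # 40.0
# bce_naive(40.0, 0.0) instead raises ValueError, because 1 - p == 0.0
# and math.log(0.0) is undefined.
```

For moderate logits the two versions agree to within rounding error; they only diverge once the sigmoid saturates in floating point.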