Last active
July 1, 2019 10:11
-
-
Save freifrauvonbleifrei/9b24ea1a715e2493d68a0660f8ab180e to your computer and use it in GitHub Desktop.
Normalized Risk-Averting Error (NRAE) loss in TensorFlow
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Helpers taken from TensorFlow internals.
def _safe_mean(losses, num_present):
    """Return the mean of `losses`, falling back to zero for an empty count.

    Args:
      losses: `Tensor` holding the individual loss values.
      num_present: Number of measurable elements in `losses`; used as the
        divisor.

    Returns:
      A scalar equal to the sum of `losses` divided by `num_present`. When
      `num_present` is zero the result is zero instead of NaN/inf.
    """
    summed = math_ops.reduce_sum(losses)
    # div_no_nan yields 0 wherever the divisor is 0.
    mean = math_ops.div_no_nan(summed, num_present, name="value")
    return mean
def _num_elements(losses):
    """Return the total element count of `losses`, cast to `losses.dtype`."""
    with ops.name_scope(None, "num_elements", values=[losses]) as scope:
        # The size op itself is named after the enclosing scope.
        element_count = array_ops.size(losses, name=scope)
        return math_ops.cast(element_count, dtype=losses.dtype)
# NRAE (Normalized Risk-Averting Error), see:
# http://www.math.umbc.edu/~jameslo/papers/isnn12nrae.pdf
def loss_nrae(labels, predictions, l_lambda=1e1):
    """Compute the Normalized Risk-Averting Error (NRAE) loss.

    Evaluates, over all elements of the inputs,

        (1 / lambda) * log(mean(exp(lambda * (predictions - labels) ** 2)))

    The log-mean-exp term is computed as `reduce_logsumexp(x) - log(N)`
    instead of exponentiating and averaging directly. The two forms are
    mathematically equal, but a plain float32 `exp` overflows to inf once
    `lambda * error ** 2` exceeds roughly 88, which made the loss infinite
    for moderately large errors.

    Args:
      labels: Ground-truth values; cast to float32.
      predictions: Predicted values with a shape compatible with `labels`;
        cast to float32.
      l_lambda: Non-zero scaling factor (lambda) applied to the squared
        errors before exponentiation. Defaults to 1e1, the value that was
        previously hard-coded.

    Returns:
      A scalar float32 `Tensor` holding the NRAE loss.

    Raises:
      ValueError: If the shapes of `predictions` and `labels` are
        incompatible.
    """
    # Same preamble as the mean-squared-error definition.
    predictions = math_ops.cast(predictions, dtype=dtypes.float32)
    labels = math_ops.cast(labels, dtype=dtypes.float32)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    squared_errors = math_ops.squared_difference(predictions, labels)

    # Weigh the squared errors by lambda.
    lam = math_ops.cast(l_lambda, dtype=dtypes.float32)
    scaled_errors = math_ops.multiply(squared_errors, lam)

    # log(mean(exp(x))) == logsumexp(x) - log(N), without the overflow.
    # NOTE(review): for a zero-element input this presumably yields NaN where
    # the original produced -inf; confirm no caller relies on that case.
    log_sum_exp = math_ops.reduce_logsumexp(scaled_errors)
    log_count = math_ops.log(_num_elements(squared_errors))
    log_mean_exp = math_ops.subtract(log_sum_exp, log_count)

    # Re-normalize by lambda. Dividing tensors (rather than the Python-level
    # `1 / l_lambda`) avoids integer truncation for integer lambdas.
    return math_ops.truediv(log_mean_exp, lam)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment