Last active
August 29, 2023 12:29
-
-
Save sadimanna/5ffcda35463dc1e1b521bb9497d973a7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@tf.keras.saving.register_keras_serializable(name="WeightedBinaryCrossentropy")
class WeightedBinaryCrossentropy:
    """Binary crossentropy loss with per-class weights and label smoothing.

    Delegates the actual loss computation to `fn` (by default the
    module-level `weighted_binary_crossentropy`) after optionally smoothing
    the true labels towards 0.5, then averages over the last axis.
    """

    def __init__(
        self,
        label_smoothing=0.0,
        weights=None,
        axis=-1,
        name="weighted_binary_crossentropy",
        fn=None,
    ):
        """Initializes `WeightedBinaryCrossentropy` instance.

        Args:
            label_smoothing: Float in [0, 1]. When 0, no smoothing occurs.
                When > 0, we compute the loss between the predicted labels
                and a smoothed version of the true labels, where the
                smoothing squeezes the labels towards 0.5. Larger values of
                `label_smoothing` correspond to heavier smoothing.
            weights: Two-element sequence of class weights passed through to
                `fn` (presumably [negative_weight, positive_weight] — confirm
                against `weighted_binary_crossentropy`). Defaults to
                [1.0, 1.0], i.e. unweighted.
            axis: The axis along which to compute crossentropy (the features
                axis). Defaults to -1.
            name: Name for the op. Defaults to 'weighted_binary_crossentropy'.
            fn: Callable `(y_true, y_pred, weights) -> loss`. Defaults to
                `weighted_binary_crossentropy`.
        """
        super().__init__()
        # Avoid the mutable-default-argument pitfall: build a fresh list per
        # instance instead of sharing one list object across all instances.
        self.weights = [1.0, 1.0] if weights is None else weights
        self.label_smoothing = label_smoothing
        # Previously `axis` was accepted but silently dropped; store it so it
        # survives get_config()/from_config() round trips.
        # NOTE(review): the reduction below still hard-codes axis=-1 because
        # `fn` takes no axis argument — forwarding self.axis there would
        # change behavior; confirm intended semantics before wiring it in.
        self.axis = axis
        self.name = name
        self.fn = weighted_binary_crossentropy if fn is None else fn

    def __call__(self, y_true, y_pred):
        """Computes the weighted loss, mean-reduced over the last axis."""
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)
        # Use a local tensor instead of overwriting self.label_smoothing:
        # the original clobbered the instance attribute with a tensor on the
        # first call, which is a hidden side effect and breaks get_config()
        # serialization afterwards.
        label_smoothing = tf.convert_to_tensor(
            self.label_smoothing, dtype=y_pred.dtype
        )

        def _smooth_labels():
            # Squeeze labels towards 0.5 proportionally to label_smoothing.
            return y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing

        # smart_cond skips the smoothing graph entirely when smoothing is 0.
        y_true = tf.__internal__.smart_cond.smart_cond(
            label_smoothing, _smooth_labels, lambda: y_true
        )
        return tf.reduce_mean(self.fn(y_true, y_pred, self.weights), axis=-1)

    def get_config(self):
        """Returns the config dict used for (de)serialization.

        Now includes `label_smoothing` and `axis`, which were previously
        omitted — a reloaded loss would silently reset them to defaults.
        """
        return {
            "name": self.name,
            "weights": self.weights,
            "label_smoothing": self.label_smoothing,
            "axis": self.axis,
            "fn": self.fn,
        }

    @classmethod
    def from_config(cls, config):
        """Instantiates a `WeightedBinaryCrossentropy` from its config.

        Args:
            config: Output of `get_config()`.
        """
        if saving_lib.saving_v3_enabled():
            # Under Keras v3 saving, `fn` is serialized by name; resolve it
            # back to the callable before constructing the instance.
            fn_name = config.pop("fn", None)
            if fn_name:
                config["fn"] = get(fn_name)
        return cls(**config)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment