Skip to content

Instantly share code, notes, and snippets.

@sadimanna
Last active August 29, 2023 12:29
Show Gist options
  • Save sadimanna/5ffcda35463dc1e1b521bb9497d973a7 to your computer and use it in GitHub Desktop.
Save sadimanna/5ffcda35463dc1e1b521bb9497d973a7 to your computer and use it in GitHub Desktop.
@tf.keras.saving.register_keras_serializable(name="WeightedBinaryCrossentropy")
class WeightedBinaryCrossentropy:
    """Binary crossentropy loss built around a pluggable, weight-aware function.

    Each call casts ``y_true`` to ``y_pred``'s dtype and optionally applies
    label smoothing. It then evaluates ``fn(y_true, y_pred, weights)`` and
    averages that result over ``axis``.
    """

    def __init__(
        self,
        label_smoothing=0.0,
        weights=None,
        axis=-1,
        name="weighted_binary_crossentropy",
        fn=None,
    ):
        """Initializes a `WeightedBinaryCrossentropy` instance.

        Args:
            label_smoothing: Float in [0, 1]. When 0, no smoothing occurs.
                When > 0, the loss is computed against a smoothed version of
                the true labels, where the smoothing squeezes the labels
                towards 0.5. Larger values correspond to heavier smoothing.
            weights: Passed unchanged as the third positional argument of
                ``fn``. Defaults to ``[1.0, 1.0]`` (a fresh list per
                instance). NOTE(review): presumably per-class weights for the
                negative/positive class -- confirm against ``fn``.
            axis: The axis over which the output of ``fn`` is averaged.
                Defaults to -1.
            name: Name for the loss. Defaults to
                'weighted_binary_crossentropy'.
            fn: Callable invoked as ``fn(y_true, y_pred, weights)``. Defaults
                to ``weighted_binary_crossentropy``. Whether ``y_pred`` must
                hold probabilities or logits is decided by ``fn``.
        """
        # Kept for cooperative multiple inheritance (mixins after this class).
        super().__init__()
        # A None sentinel avoids one mutable default list being shared by
        # every instance; the effective default is unchanged.
        self.weights = [1.0, 1.0] if weights is None else weights
        self.label_smoothing = label_smoothing
        self.axis = axis
        self.name = name
        self.fn = weighted_binary_crossentropy if fn is None else fn

    def __call__(self, y_true, y_pred):
        """Computes the loss for a batch.

        Args:
            y_true: Ground-truth labels; cast to ``y_pred``'s dtype.
            y_pred: Predictions, converted to a tensor.

        Returns:
            ``fn(y_true, y_pred, weights)`` averaged over ``self.axis``.
        """
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)
        # Keep the tensor form local. Assigning it back to ``self`` would
        # overwrite the configured value (breaking ``get_config``) and could
        # leak a graph tensor out of a traced function.
        label_smoothing = tf.convert_to_tensor(
            self.label_smoothing, dtype=y_pred.dtype
        )

        def _smooth_labels():
            return y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing

        # smart_cond picks a branch statically when ``label_smoothing`` is a
        # known constant, so no smoothing ops are added when it is 0.
        y_true = tf.__internal__.smart_cond.smart_cond(
            label_smoothing, _smooth_labels, lambda: y_true
        )
        return tf.reduce_mean(
            self.fn(y_true, y_pred, self.weights), axis=self.axis
        )

    def get_config(self):
        """Returns the constructor arguments needed to rebuild this loss.

        Returns:
            A dict with keys ``name``, ``weights``, ``fn``,
            ``label_smoothing`` and ``axis``, consumable by ``from_config``.
        """
        # NOTE(review): ``fn`` is stored as the callable itself rather than a
        # registered name; confirm the saving path can serialize it.
        return {
            "name": self.name,
            "weights": self.weights,
            "fn": self.fn,
            "label_smoothing": self.label_smoothing,
            "axis": self.axis,
        }

    @classmethod
    def from_config(cls, config):
        """Instantiates the loss from its config (output of `get_config()`).

        Args:
            config: Output of `get_config()`. The dict is not modified.

        Returns:
            A new instance of ``cls`` built from ``config``.
        """
        # Work on a copy so popping "fn" does not mutate the caller's dict.
        config = dict(config)
        if saving_lib.saving_v3_enabled():
            fn_name = config.pop("fn", None)
            if fn_name:
                # Resolve the stored ``fn`` entry through ``get`` before
                # passing it to the constructor.
                config["fn"] = get(fn_name)
        return cls(**config)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment