Created November 1, 2021 10:08
-
-
Save kashif/4bf65f5a9cd726718b6ec709f4013fec to your computer and use it in GitHub Desktop.
pt-keras-metrics
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import torch | |
from torchmetrics import Metric | |
def tf2pt(x_tf=None):
    """Convert a TensorFlow tensor into a PyTorch tensor through a DLPack capsule.

    Args:
        x_tf: TensorFlow tensor to convert, or ``None``.

    Returns:
        The tensor produced by ``torch.utils.dlpack.from_dlpack`` for the
        capsule exported by ``tf.experimental.dlpack.to_dlpack``, or ``None``
        when ``x_tf`` is ``None`` (so optional arguments pass straight through).
    """
    if x_tf is not None:
        capsule = tf.experimental.dlpack.to_dlpack(x_tf)
        return torch.utils.dlpack.from_dlpack(capsule)
    return None
def pt2tf(x_torch=None):
    """Convert a PyTorch tensor into a TensorFlow tensor through a DLPack capsule.

    Args:
        x_torch: PyTorch tensor to convert, or ``None``.

    Returns:
        The tensor produced by ``tf.experimental.dlpack.from_dlpack``, or
        ``None`` when ``x_torch`` is ``None`` (so optional arguments such as a
        missing sample weight pass straight through).
    """
    if x_torch is None:
        return None
    # Export a contiguous copy/view; presumably TF's DLPack importer expects a
    # dense row-major layout (e.g. after unsqueeze/transpose) — TODO confirm.
    capsule = torch.utils.dlpack.to_dlpack(x_torch.contiguous())
    return tf.experimental.dlpack.from_dlpack(capsule)
class BatchRecall(Metric):
    """torchmetrics wrapper around ``tf.keras.metrics.Recall`` for in-batch retrieval.

    How ``update`` scores a batch:

    - ``y_pred[i]`` holds scores over a vocabulary, and ``y_true[i]`` is the
      vocabulary index of example ``i``'s label.
    - Only the score columns named by ``y_true`` are kept, giving one column
      per example in the batch.
    - Each row is turned into probabilities with a softmax.
    - The result is scored against an identity matrix, so the "positive" for
      row ``i`` is its own label (the diagonal). Keras then decides what counts
      as predicted positive via ``top_k`` or, when ``top_k`` is None,
      ``thresholds``.

    Assumptions (not validated here — TODO confirm with callers): ``y_true`` is
    a 1-D integer tensor of length ``batch`` and ``y_pred`` is
    ``(batch, vocab)``. Together these make the gathered score matrix square
    and aligned with the identity.

    The running state lives in the wrapped Keras metric's variables. It is not
    registered with ``add_state``, so torchmetrics' distributed sync and
    state-dict handling do not cover it.
    """

    def __init__(
        self,
        thresholds=None,
        top_k=1,
        class_id=None,
        name="batch_recall",
    ):
        """Build the wrapped ``tf.keras.metrics.Recall`` with the given options.

        Args:
            thresholds: Forwarded to ``tf.keras.metrics.Recall``.
            top_k: Forwarded to ``tf.keras.metrics.Recall``.
            class_id: Forwarded to ``tf.keras.metrics.Recall``.
            name: Name given to the Keras metric.
        """
        super().__init__()
        self.metric = tf.keras.metrics.Recall(
            thresholds=thresholds, top_k=top_k, class_id=class_id, name=name
        )

    def _get_batch_similarities(self, batch_label, full_vocab_similarities):
        """Return the columns of ``full_vocab_similarities`` indexed by ``batch_label``.

        ``batch_label`` is expected as a column ``(n, 1)``. Transposing and
        taking row 0 yields the flat ``(n,)`` index vector used by the gather
        along axis 1.
        """
        return tf.gather(full_vocab_similarities, tf.transpose(batch_label)[0], axis=1)

    def _to_keras_inputs(self, y_true, y_pred, sample_weight=None):
        """Convert torch inputs into the ``(y_true, y_pred, sample_weight)`` Keras expects.

        Returns:
            A tuple of three items:

            - an identity matrix of size ``y_pred.shape[0]`` (float32) as labels;
            - the row-wise softmax of the in-batch score matrix as predictions;
            - ``sample_weight`` converted to TF, or ``None``.
        """
        tf_y_true = pt2tf(y_true.unsqueeze(-1))
        tf_y_pred = pt2tf(y_pred)
        tf_sample_weight = pt2tf(sample_weight)
        batch_size = tf.shape(tf_y_pred)[0]
        label_indices = tf.eye(batch_size, dtype=tf.dtypes.float32)
        normalized_logits = tf.nn.softmax(
            self._get_batch_similarities(tf_y_true, tf_y_pred)
        )
        return label_indices, normalized_logits, tf_sample_weight

    def update(self, y_true, y_pred, sample_weight=None):
        """Accumulate one batch into the wrapped Keras metric."""
        self.metric.update_state(*self._to_keras_inputs(y_true, y_pred, sample_weight))

    def compute(self):
        """Return the recall accumulated since the last reset, as a torch tensor."""
        return tf2pt(self.metric.result())

    def forward(self, y_true, y_pred, sample_weight=None):
        """Accumulate the batch and return the recall of this batch alone.

        Why this override exists: torchmetrics' default ``forward`` calls
        ``reset()`` to evaluate the batch in isolation. Afterwards it restores
        only the states registered with ``add_state``. This metric keeps its
        counts in the Keras metric, so the default path would discard
        everything accumulated before the current batch.

        What it does instead: it accumulates via ``update``, then scores the
        batch on a throwaway Keras metric built from the same config. The
        accumulated state is left untouched.
        """
        self.update(y_true, y_pred, sample_weight)
        batch_metric = self.metric.__class__.from_config(self.metric.get_config())
        batch_metric.update_state(*self._to_keras_inputs(y_true, y_pred, sample_weight))
        batch_value = tf2pt(batch_metric.result())
        # Keep the attribute torchmetrics' own forward populates, for loggers
        # that read the last step value from it.
        self._forward_cache = batch_value
        return batch_value

    def reset(self):
        """Clear the Keras metric's variables, then run torchmetrics' reset."""
        self.metric.reset_state()
        return super().reset()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.