Created
April 23, 2023 14:09
-
-
Save nicoandmee/b5e4b1fcc840c001f49f8401222918bd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Materialise the validation batches once so they can be reused every epoch.
# A single pass over the dataset keeps each image batch paired with its own
# label batch, even if the dataset yields batches in a different order on each
# iteration (e.g. a shuffled input pipeline), and avoids reading it twice.
validation_images = []
validation_labels = []
for batch in validation_dataset:  # type: ignore
    validation_images.append(batch["image"])
    validation_labels.append(batch["label"])
def calculate_edit_distance(labels, predictions):
    """Return the mean unnormalised edit distance for one batch.

    Args:
        labels: Dense batch of ground-truth label ids (presumably shaped
            (batch, label_length) -- confirm against the dataset pipeline).
        predictions: Batch of model outputs passed to CTC decoding; its first
            two dimensions are used as the batch size and the number of time
            steps.

    Returns:
        A scalar tensor: the per-sample edit distances averaged over the batch.
    """
    # NOTE: tf.sparse.from_dense keeps only the non-zero entries, so zero
    # valued ids are treated as absent on both sides of the comparison.
    sparse_truth = tf.cast(tf.sparse.from_dense(labels), dtype=tf.int64)

    # Every sample is decoded over the full number of time steps.
    batch_size, time_steps = predictions.shape[0], predictions.shape[1]
    sequence_lengths = np.ones(batch_size) * time_steps

    # Greedy CTC decoding; keep the first decoded path, clipped to max_length.
    decoded_paths, _ = tf.keras.backend.ctc_decode(
        predictions, input_length=sequence_lengths, greedy=True
    )
    best_path = decoded_paths[0][:, :max_length]
    sparse_hypothesis = tf.cast(tf.sparse.from_dense(best_path), dtype=tf.int64)

    # Raw (not length-normalised) edit distance per sample, then the batch mean.
    per_sample_distance = tf.edit_distance(
        sparse_hypothesis, sparse_truth, normalize=False
    )
    return tf.reduce_mean(per_sample_distance)
class EditDistanceCallback(tf.keras.callbacks.Callback):
    """Print the mean edit distance on the validation batches after each epoch.

    The callback runs the given prediction model on every batch stored in the
    module-level ``validation_images`` list and scores the output against the
    matching entry of ``validation_labels`` with ``calculate_edit_distance``.
    """

    def __init__(self, pred_model):
        """Store the model used to produce predictions.

        Args:
            pred_model: Object exposing ``predict(batch)``; its output is
                passed to ``calculate_edit_distance``.
        """
        super().__init__()
        self.prediction_model = pred_model

    def on_epoch_end(self, epoch, logs=None):
        """Compute and print the epoch's mean edit distance.

        Args:
            epoch: Zero-based epoch index; reported as ``epoch + 1``.
            logs: Unused; accepted for compatibility with the Keras callback
                signature.
        """
        # One score per validation batch; the printed value is therefore the
        # mean of the per-batch means.
        edit_distances = [
            calculate_edit_distance(
                labels, self.prediction_model.predict(images)
            ).numpy()
            for images, labels in zip(validation_images, validation_labels)
        ]
        print(
            f"Mean edit distance for epoch {epoch + 1}: {np.mean(edit_distances):.4f}"
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment