Created October 22, 2020 11:16
-
-
Save burnpiro/4aa07d6e031ce0b7f2c55f8c759a9ae8 to your computer and use it in GitHub Desktop.
List of helpers to generate images for TensorBoard.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def plot_confusion_matrix(cm, class_names=None):
    """Return a matplotlib figure containing the plotted confusion matrix.

    Cells are coloured by the raw counts in ``cm`` and annotated with the
    row-normalized value (count divided by the row total), rounded to
    2 decimals. Rows are plotted on the "True label" axis and columns on the
    "Predicted label" axis.

    Args:
        cm (array-like, shape = [n, n]): a confusion matrix of integer classes.
        class_names (sequence, shape = [n], optional): string names of the
            integer classes. When omitted, the module-level ``class_names`` is
            looked up at call time; if none exists, the class indices are used.

    Returns:
        matplotlib.figure.Figure: the newly created figure.
    """
    cm = np.asarray(cm)
    if class_names is None:
        # Late lookup: binding the global as a default argument evaluated it
        # once at definition time (NameError if defined later, stale if rebound).
        class_names = globals().get("class_names")
        if class_names is None:
            class_names = [str(k) for k in range(cm.shape[0])]
    figure = plt.figure(figsize=(8, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title("Confusion matrix")
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)
    # Compute the labels from the row-normalized confusion matrix.
    labels = _row_normalized_labels(cm)
    # The colormap is driven by raw counts, so pick text colour from them too:
    # white text on dark squares, otherwise black.
    threshold = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        color = "white" if cm[i, j] > threshold else "black"
        plt.text(j, i, labels[i, j], horizontalalignment="center", color=color)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Lay out only after the axis labels exist so they are not clipped.
    plt.tight_layout()
    return figure


def _row_normalized_labels(cm):
    """Return ``cm`` divided by its row sums, rounded to 2 decimals.

    Rows summing to zero (no true samples for that class) yield 0.0 instead
    of the NaN a plain 0/0 division would produce.
    """
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    normalized = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)
    return np.around(normalized, decimals=2)
def plot_to_image(figure):
    """Convert the matplotlib figure ``figure`` to a PNG-decoded image tensor.

    The figure is closed afterwards (even if rendering fails). Closing it
    keeps it from being displayed inline in a notebook, and it is
    inaccessible after this call.

    Args:
        figure: the matplotlib figure to render.

    Returns:
        Tensor of shape [1, height, width, 4]: the RGBA PNG decoded by
        ``tf.image.decode_png`` with a leading batch dimension added.
    """
    try:
        # Save the plot to a PNG in memory.
        with io.BytesIO() as buf:
            # Render *this* figure explicitly; plt.savefig would render
            # whichever figure happens to be current, which may differ.
            figure.savefig(buf, format='png')
            png_bytes = buf.getvalue()
    finally:
        plt.close(figure)
    # Convert PNG bytes to a TF image, then add the batch dimension.
    image = tf.image.decode_png(png_bytes, channels=4)
    return tf.expand_dims(image, 0)
def image_grid(labels, preds, miss_class, class_names, images,
               image_width=None, columns=2):
    """Return a figure with one titled subplot per selected sample.

    Each subplot title shows the predicted and the actual class name.

    Args:
        labels: true integer labels, indexable by sample index.
        preds: predicted integer labels, indexable by sample index.
        miss_class: iterable of sample indices to plot.
        class_names: maps an integer class to its display name.
        images: per-sample images, each reshapeable to
            [image_width, image_width, 3]. Values are multiplied by 255 and
            cast to uint8, so they presumably lie in [0, 1]. TODO: confirm
            against the caller.
        image_width: side length of each square image; defaults to the
            module-level ``IMAGE_WIDTH``.
        columns: number of subplot columns (default 2, as before).

    Returns:
        matplotlib.figure.Figure: the newly created figure.
    """
    miss_class = list(miss_class)
    if image_width is None:
        image_width = IMAGE_WIDTH
    # Size the grid to the images actually drawn. The old n x 2 grid for
    # n images left its bottom half empty.
    rows = _grid_rows(len(miss_class), columns)
    figure = plt.figure(figsize=(10, 3 * rows))
    for position, idx in enumerate(miss_class, start=1):
        label = f"Predicted: {class_names[preds[idx]]}, Acc label: {class_names[labels[idx]]}"
        plt.subplot(rows, columns, position, title=label)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        img = tf.cast(tf.reshape(images[idx], [image_width, image_width, 3]) * 255, tf.uint8)
        plt.imshow(img, cmap=plt.cm.binary)
    return figure


def _grid_rows(n_items, columns):
    """Return the subplot rows needed for n_items in `columns` columns (min 1)."""
    # Integer ceiling division; at least one row so an empty figure is valid.
    return max(1, (n_items + columns - 1) // columns)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment