Skip to content

Instantly share code, notes, and snippets.

@karolzak
Last active January 24, 2019 09:02
Show Gist options
  • Save karolzak/d250d2c029f9d0e5adbd146e57c8cc09 to your computer and use it in GitHub Desktop.
Save karolzak/d250d2c029f9d0e5adbd146e57c8cc09 to your computer and use it in GitHub Desktop.
Image plotting for semantic segmentation data
import numpy as np
import matplotlib.pyplot as plt
def mask_to_red(mask, img_size=1024):
    '''
    Turn a binary segmentation mask into a red RGBA image.

    The mask values are used both as the red channel and as the alpha
    channel, while green and blue are zero. With a 0/1 mask, foreground
    pixels become opaque red and the background becomes fully transparent.

    Args:
        mask: array containing exactly img_size * img_size elements.
        img_size: side length of the square output image.

    Returns:
        Array of shape (img_size, img_size, 4) in RGBA channel order.
    '''
    side = (img_size, img_size)
    red = mask.reshape(side)
    blank = np.zeros(side)
    opacity = mask.reshape(side)
    return np.stack([red, blank, blank, opacity], axis=-1)
def mask_to_rgba(mask, img_size=1024, color='red'):
    '''
    Convert a binary segmentation mask into a colored RGBA image.

    The mask values are copied into the color channels selected by `color`
    and into the alpha channel; the remaining color channels are zero. With
    a 0/1 mask, foreground pixels are painted in `color` and the background
    is fully transparent.

    Args:
        mask: array containing exactly img_size * img_size elements.
        img_size: side length of the square output image.
        color: one of 'red', 'green', 'blue', 'yellow', 'magenta', 'cyan'.

    Returns:
        Array of shape (img_size, img_size, 4) in RGBA channel order.

    Raises:
        ValueError: if `color` is not one of the supported names.
    '''
    # For each supported color, which of the R, G, B channels carry the mask.
    channel_flags = {
        'red': (True, False, False),
        'green': (False, True, False),
        'blue': (False, False, True),
        'yellow': (True, True, False),
        'magenta': (True, False, True),
        'cyan': (False, True, True),
    }
    try:
        flags = channel_flags[color]
    except KeyError:
        # Previously an unknown color silently returned None, which only
        # failed later inside imshow; fail fast with a clear message instead.
        raise ValueError(
            "Unsupported color {!r}; expected one of: {}".format(
                color, ', '.join(sorted(channel_flags)))) from None
    zeros = np.zeros((img_size, img_size))
    ones = mask.reshape(img_size, img_size)
    rgb = [ones if use_mask else zeros for use_mask in flags]
    # The alpha channel is always the mask itself, so background is transparent.
    return np.stack(rgb + [ones], axis=-1)
def plot_imgs(org_imgs,
              mask_imgs,
              pred_imgs=None,
              nm_img_to_plot=10,
              figsize=4,
              img_size=1024,
              alpha=0.5):
    '''
    Image plotting for semantic segmentation data.

    Draws one row per sample with the columns: original image, ground-truth
    mask, (optionally) predicted mask, and an overlay. The last column is
    always the original image with a red mask drawn on top of it: the
    prediction when `pred_imgs` is given, otherwise the ground truth.

    Args:
        org_imgs: indexable sequence of images; each must reshape to
            (img_size, img_size, 3).
        mask_imgs: indexable sequence of ground-truth masks; each must
            reshape to (img_size, img_size).
        pred_imgs: optional indexable sequence of predicted masks with the
            same shape requirement as `mask_imgs`.
        nm_img_to_plot: maximum number of rows to draw; clamped to the
            number of samples available in every provided sequence.
        figsize: width/height in inches of each subplot cell.
        img_size: side length of the square images and masks.
        alpha: opacity of the red mask in the overlay column.
    '''
    has_pred = pred_imgs is not None
    sources = [org_imgs, mask_imgs] + ([pred_imgs] if has_pred else [])
    # Clamp so a short dataset yields fewer rows instead of an IndexError
    # halfway through drawing the figure.
    nm_img_to_plot = min([nm_img_to_plot] + [len(src) for src in sources])
    cols = 4 if has_pred else 3

    # squeeze=False keeps `axes` 2-D even for a single row; without it
    # plt.subplots returns a 1-D array and `axes[0, 0]` raises.
    _, axes = plt.subplots(nm_img_to_plot, cols,
                           figsize=(cols * figsize, nm_img_to_plot * figsize),
                           squeeze=False)

    titles = ["original", "ground truth"]
    if has_pred:
        titles.append("prediction")
    titles.append("overlay")
    for col, title in enumerate(titles):
        axes[0, col].set_title(title, fontsize=15)

    img_shape = (img_size, img_size, 3)
    mask_shape = (img_size, img_size)
    for row in range(nm_img_to_plot):
        row_axes = axes[row]
        img = org_imgs[row].reshape(img_shape)
        gt_mask = mask_imgs[row].reshape(mask_shape)

        row_axes[0].imshow(img)
        row_axes[1].imshow(gt_mask, cmap='gray')
        if has_pred:
            overlay_mask = pred_imgs[row].reshape(mask_shape)
            row_axes[2].imshow(overlay_mask, cmap='gray')
        else:
            overlay_mask = gt_mask

        # Last column: original image with the chosen mask in translucent red.
        row_axes[-1].imshow(img)
        row_axes[-1].imshow(mask_to_red(overlay_mask, img_size=img_size),
                            alpha=alpha)

        for ax in row_axes:
            ax.set_axis_off()
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment