Skip to content

Instantly share code, notes, and snippets.

@ankitshekhawat
Last active May 9, 2019 19:28
Show Gist options
  • Save ankitshekhawat/1e97db819e27c3edba992d840acf60ef to your computer and use it in GitHub Desktop.
Save ankitshekhawat/1e97db819e27c3edba992d840acf60ef to your computer and use it in GitHub Desktop.
Clip a binary segmentation map to superpixel boundaries using an IoU-style overlap score
import numpy as np
from fast_slic.avx2 import SlicAvx2
from keras.utils import to_categorical
from scipy.ndimage.morphology import binary_erosion
def clip_to_superpixel(image, z, num_segments=100, erosion=12, threshold=0.5,
                       iou_thresh=0.5, segments=None):
    """Snap a soft segmentation map ``z`` to superpixel boundaries.

    The image is over-segmented into superpixels with fast_slic's
    ``SlicAvx2`` (compactness=10, quantize_level=1) unless ``segments`` is
    supplied. ``z`` is binarised at ``(max(z) + min(z)) * threshold``. Each
    superpixel is then scored by its smoothed foreground coverage::

        (foreground pixels in superpixel + 1) / (superpixel area + 1)

    This is an IoU-style ratio, but the denominator is the superpixel area
    rather than a true union. Superpixels scoring above
    ``(max_score + min_score) * iou_thresh`` are kept, and their pixels form
    the mask, which is finally eroded ``erosion`` times.

    Args:
        image: Image handed to SLIC. When ``segments`` is given, only
            ``image.shape[:2]`` is used.
        z: Per-pixel scores with ``image.shape[0] * image.shape[1]``
            elements (presumably an (H, W) probability map -- confirm with
            callers).
        num_segments: Requested number of superpixels.
        erosion: Number of binary-erosion iterations. A value <= 0 skips
            erosion.
        threshold: Fraction of ``max(z) + min(z)`` used to binarise ``z``.
        iou_thresh: Fraction of ``max_score + min_score`` that a
            superpixel's score must exceed to be kept.
        segments: Optional precomputed non-negative integer label map of
            shape ``image.shape[:2]``. When given, SLIC is skipped.

    Returns:
        ``(mask, segments)``: a boolean mask of shape ``image.shape[:2]`` and
        the label map that was used.
    """
    if segments is None:
        slic = SlicAvx2(num_components=num_segments, compactness=10,
                        quantize_level=1)
        segments = slic.iterate(image)
    segments = np.asarray(segments)
    z = np.asarray(z)

    labels = segments.ravel().astype(np.intp, copy=False)
    # Binarise z at (max + min) * threshold; threshold=0.5 is the midpoint.
    foreground = z.ravel() > (np.max(z) + np.min(z)) * threshold

    # Size the per-label counts from the data as well as num_segments, so a
    # label >= num_segments is counted instead of raising.
    n_labels = max(num_segments, int(labels.max()) + 1)
    area = np.bincount(labels, minlength=n_labels)
    overlap = np.bincount(labels[foreground], minlength=n_labels)
    score = (overlap + 1) / (area + 1)  # +1 smoothing on both terms

    # Only labels that actually occur take part in the adaptive cut. An
    # unused label would score (0 + 1) / (0 + 1) = 1 and inflate the maximum.
    used = score[area > 0]
    score_cut = (used.max() + used.min()) * iou_thresh
    keep = score > score_cut

    # Every pixel carries exactly one label, so the mask is a table lookup.
    msk = keep[segments].reshape((image.shape[0], image.shape[1]))
    # scipy repeats erosion until nothing changes when iterations < 1, which
    # would wipe the mask, so erosion <= 0 means "do not erode".
    if erosion > 0:
        msk = binary_erosion(msk, iterations=erosion)
    return msk, segments
def clip_to_superpixel_batch(self, images, z, n_segments=100, threshold=0.5,
                             iou_thresh=0.5, segments=None):
    """Snap a batch of soft segmentation maps to superpixel boundaries.

    For each sample ``b``:

    1. ``images[b]`` is over-segmented with fast_slic's ``SlicAvx2``
       (compactness=10, quantize_level=1), unless ``segments`` is supplied.
    2. ``z[b]`` is binarised at ``(max(z[b]) + min(z[b])) * threshold``.
    3. Each superpixel is scored by
       ``(foreground pixels in it + 1) / (its area + 1)``. This is an
       IoU-style coverage ratio; the denominator is the superpixel area,
       not a true union.
    4. Superpixels scoring above ``(max_score + min_score) * iou_thresh``
       make up that sample's mask.

    No erosion is applied.

    Args:
        self: Unused; kept so existing call sites keep working.
        images: Array of shape ``(B, H, W, ...)``.
        z: Per-pixel scores with ``H * W`` elements per sample (presumably
            ``(B, H, W)`` -- confirm with callers).
        n_segments: Requested number of superpixels per image.
        threshold: Fraction of ``max + min`` of each ``z[b]`` used to
            binarise it.
        iou_thresh: Fraction of ``max_score + min_score`` that a
            superpixel's score must exceed to be kept.
        segments: Optional precomputed non-negative integer label maps of
            shape ``(B, H, W)``. When given, SLIC is skipped.

    Returns:
        Boolean mask array of shape ``(B, H, W)``.
    """
    batch_size = len(images)
    height, width = images.shape[1], images.shape[2]
    if batch_size == 0:
        return np.zeros((0, height, width), dtype=bool)
    if segments is None:
        slic = SlicAvx2(num_components=n_segments, compactness=10,
                        quantize_level=1)
        segments = np.array([slic.iterate(image) for image in images])
    labels = np.asarray(segments).reshape(batch_size, -1).astype(np.intp, copy=False)
    z_flat = np.asarray(z).reshape(batch_size, -1)

    # Per-sample binarisation at (max + min) * threshold.
    z_cut = (z_flat.max(axis=1) + z_flat.min(axis=1)) * threshold
    foreground = z_flat > z_cut[:, None]

    # Count from the data as well, so labels >= n_segments are not dropped.
    n_labels = max(n_segments, int(labels.max()) + 1)
    # Shift sample b's labels by b * n_labels so one bincount yields
    # per-(sample, label) counts.
    keys = labels + n_labels * np.arange(batch_size)[:, None]
    size = batch_size * n_labels
    area = np.bincount(keys.ravel(), minlength=size).reshape(batch_size, n_labels)
    overlap = np.bincount(keys[foreground], minlength=size).reshape(batch_size, n_labels)
    score = (overlap + 1) / (area + 1)  # +1 smoothing on both terms

    # Only labels present in a sample take part in its adaptive cut. An
    # unused label would score (0 + 1) / (0 + 1) = 1 and inflate the maximum.
    used = area > 0
    score_max = np.where(used, score, -np.inf).max(axis=1)
    score_min = np.where(used, score, np.inf).min(axis=1)
    score_cut = (score_max + score_min) * iou_thresh
    keep = score > score_cut[:, None]

    # Every pixel carries exactly one label: look up its keep flag.
    msk = keep[np.arange(batch_size)[:, None], labels]
    return msk.reshape(-1, height, width)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment