Last active
August 11, 2020 10:31
-
-
Save royshil/52bd3a0e21e9e7f6a8747bde52a3f86b to your computer and use it in GitHub Desktop.
Segmentation based on graph-cut and SLIC superpixels. See more details: http://www.morethantechnical.com/2017/10/30/revisiting-graph-cut-segmentation-with-slic-and-color-histograms-wpython/
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cv2
import matplotlib.pyplot as plt
import maxflow
import numpy as np
from scipy.spatial import Delaunay
from skimage.data import astronaut
from skimage.segmentation import mark_boundaries
from skimage.segmentation import slic
from skimage.util import img_as_float
def superpixels_histograms_neighbors(img):
    """Calculate the SLIC superpixels, their histograms and neighbors.

    Args:
        img: Colour image array. It is cast to ``float32`` and converted
            with ``cv2.COLOR_BGR2HSV`` before histogramming.
            NOTE(review): the script entry point passes an skimage image,
            which is presumably RGB rather than BGR -- confirm whether
            ``COLOR_RGB2HSV`` was intended.

    Returns:
        A 4-tuple ``(centers, colors_hists, segments, neighbors)``:
        ``centers`` holds the mean (row, col) position of each superpixel,
        ``colors_hists`` the flattened 20x20 H-S histogram of each
        superpixel (float32, not normalized), ``segments`` the per-pixel
        label image, and ``neighbors`` the ``vertex_neighbor_vertices``
        pair of the Delaunay triangulation built on ``centers``.
    """
    # SLIC
    segments = slic(img, n_segments=500, compactness=20)

    # Relabel to a contiguous 0..N-1 range.  Labels are used elsewhere as row
    # indices into ``colors_hists``/``centers``; depending on the library
    # version SLIC labels may start at 1, which would shift every lookup by
    # one row (and overflow on the last label).
    unique_labels, inverse = np.unique(segments, return_inverse=True)
    segments = inverse.reshape(segments.shape)
    segments_ids = np.arange(len(unique_labels))

    # centers: mean pixel coordinate of every superpixel
    centers = np.array(
        [np.mean(np.nonzero(segments == i), axis=1) for i in segments_ids]
    )

    # H-S histograms for all superpixels
    hsv = cv2.cvtColor(img.astype('float32'), cv2.COLOR_BGR2HSV)
    bins = [20, 20]  # H = S = 20
    ranges = [0, 360, 0, 1]  # H: [0, 360], S: [0, 1]
    colors_hists = np.float32([
        cv2.calcHist([hsv], [0, 1], np.uint8(segments == i), bins, ranges).flatten()
        for i in segments_ids
    ])

    # neighbors via Delaunay tesselation
    tri = Delaunay(centers)
    return (centers, colors_hists, segments, tri.vertex_neighbor_vertices)
def find_superpixels_under_marking(marking, superpixels):
    """Get superpixel IDs for FG and BG from a marking image.

    Args:
        marking: 3-channel image array with the same height and width as
            ``superpixels``.  Pixels whose channel 0 differs from 255 are
            treated as foreground strokes; pixels whose channel 2 differs
            from 255 are treated as background strokes.
        superpixels: 2-D array of per-pixel superpixel labels.

    Returns:
        ``(fg_segments, bg_segments)``: sorted arrays of the unique labels
        found under the foreground and background strokes respectively.
    """
    fg_mask = marking[:, :, 0] != 255
    bg_mask = marking[:, :, 2] != 255
    fg_segments = np.unique(superpixels[fg_mask])
    bg_segments = np.unique(superpixels[bg_mask])
    return (fg_segments, bg_segments)
def cumulative_histogram_for_superpixels(ids, histograms):
    """Sum up the histograms for a selection of superpixel IDs and normalize.

    Args:
        ids: Indices (rows of ``histograms``) of the selected superpixels.
        histograms: 2-D array with one histogram per row.

    Returns:
        The element-wise sum of the selected rows divided by its total.
        If the selection is empty or the summed histogram has zero mass,
        an all-zero histogram is returned instead of dividing by zero
        (which would yield NaNs).
    """
    h = np.sum(histograms[ids], axis=0)
    total = h.sum()
    if total == 0:
        return np.zeros_like(h)
    return h / total
def pixels_for_segment_selection(superpixels_labels, selection):
    """Get a bool mask of the pixels for a given selection of superpixel IDs.

    Args:
        superpixels_labels: Array of per-pixel superpixel labels.
        selection: Array-like of the selected superpixel IDs.

    Returns:
        Boolean array with the shape of ``superpixels_labels`` that is True
        wherever the pixel's label is contained in ``selection``.
    """
    # np.isin already yields a boolean array; no np.where(..., True, False)
    # round-trip is needed.
    return np.isin(superpixels_labels, selection)
def normalize_histograms(histograms):
    """Get a normalized version of the given histograms (divide by sum).

    Args:
        histograms: 2-D array-like with one histogram per row.

    Returns:
        ``float32`` array of the same shape in which every row sums to 1.
        Rows whose sum is zero are left as all zeros rather than being
        divided by zero (which would produce NaN rows).
    """
    hists = np.asarray(histograms, dtype=np.float32)
    if hists.size == 0:
        return hists
    totals = hists.sum(axis=1, keepdims=True)
    # Divide only where the row has mass; the pre-zeroed output covers the rest.
    return np.divide(hists, totals, out=np.zeros_like(hists), where=totals > 0)
def do_graph_cut(fgbg_hists, fgbg_superpixels, norm_hists, neighbors):
    """Perform graph cut using superpixels histograms.

    Args:
        fgbg_hists: Pair ``(fg_hist, bg_hist)`` of histograms that each
            superpixel histogram is compared against for the match term.
        fgbg_superpixels: Pair ``(fg_ids, bg_ids)`` of superpixel indices
            that were explicitly marked; these get hard terminal weights.
        norm_hists: 2-D array with one histogram per superpixel (one graph
            node per row).
        neighbors: Pair ``(indptr, indices)``; the neighbours of node ``i``
            are ``indices[indptr[i]:indptr[i + 1]]``.

    Returns:
        The result of ``g.get_grid_segments(nodes)`` after running max-flow:
        one entry per node.  The calling script treats the non-zero entries
        as the selected (foreground) superpixels.
    """
    num_nodes = norm_hists.shape[0]
    # Create a graph of N nodes, and estimate of 5 edges per node
    g = maxflow.Graph[float](num_nodes, num_nodes * 5)
    # Add N nodes
    nodes = g.add_nodes(num_nodes)

    hist_comp_alg = cv2.HISTCMP_KL_DIV

    # Smoothness term: cost between neighbors
    indptr, indices = neighbors
    for i in range(len(indptr) - 1):
        hi = norm_hists[i]  # histogram for center
        for n in indices[indptr[i]:indptr[i + 1]]:  # neighbor superpixels
            # Valid node indices are 0..num_nodes-1 (">=" -- the previous
            # ">" let n == num_nodes through).
            if n < 0 or n >= num_nodes:
                continue
            # Create two edges (forwards and backwards) with capacities based
            # on histogram matching.  Clamp at zero: the divergence can exceed
            # 20, and edge capacities must not be negative.
            hn = norm_hists[n]  # histogram for neighbor
            forward = max(0.0, 20 - cv2.compareHist(hi, hn, hist_comp_alg))
            backward = max(0.0, 20 - cv2.compareHist(hn, hi, hist_comp_alg))
            g.add_edge(nodes[i], nodes[n], forward, backward)

    # Build the marked-ID sets once so the per-node membership test is O(1)
    # instead of a linear scan of an array.
    fg_ids = set(np.asarray(fgbg_superpixels[0]).ravel().tolist())
    bg_ids = set(np.asarray(fgbg_superpixels[1]).ravel().tolist())

    # Match term: cost to FG/BG
    for i, h in enumerate(norm_hists):
        if i in fg_ids:
            g.add_tedge(nodes[i], 0, 1000)  # FG - set high cost to BG
        elif i in bg_ids:
            g.add_tedge(nodes[i], 1000, 0)  # BG - set high cost to FG
        else:
            g.add_tedge(nodes[i],
                        cv2.compareHist(fgbg_hists[0], h, hist_comp_alg),
                        cv2.compareHist(fgbg_hists[1], h, hist_comp_alg))

    g.maxflow()
    return g.get_grid_segments(nodes)
def main():
    """Segment the skimage astronaut image from user strokes and save plots.

    Reads the stroke image ``astronaut_marking.png`` from the working
    directory, runs the superpixel graph cut, writes the binary mask to
    ``output_segmentation.png`` and a two-panel figure to
    ``segmentation.png``.

    Raises:
        FileNotFoundError: if the marking image cannot be read.
    """
    img = img_as_float(astronaut()[::2, ::2])
    img_marking = cv2.imread("astronaut_marking.png")
    if img_marking is None:
        # cv2.imread signals failure by returning None; fail here with a clear
        # message instead of a TypeError on the first subscript below.
        raise FileNotFoundError("could not read 'astronaut_marking.png'")

    centers, colors_hists, segments, neighbors = superpixels_histograms_neighbors(img)
    fg_segments, bg_segments = find_superpixels_under_marking(img_marking, segments)

    # get cumulative BG/FG histograms, before normalization
    fg_cumulative_hist = cumulative_histogram_for_superpixels(fg_segments, colors_hists)
    bg_cumulative_hist = cumulative_histogram_for_superpixels(bg_segments, colors_hists)

    norm_hists = normalize_histograms(colors_hists)

    graph_cut = do_graph_cut((fg_cumulative_hist, bg_cumulative_hist),
                             (fg_segments, bg_segments),
                             norm_hists,
                             neighbors)

    # Right panel: the resulting segmentation mask
    plt.subplot(1, 2, 2)
    plt.xticks([])
    plt.yticks([])
    plt.title('segmentation')
    segmask = pixels_for_segment_selection(segments, np.nonzero(graph_cut))
    cv2.imwrite("output_segmentation.png", np.uint8(segmask * 255))
    plt.imshow(segmask)

    # Left panel: superpixel boundaries with the strokes painted on top
    plt.subplot(1, 2, 1)
    plt.xticks([])
    plt.yticks([])
    img = mark_boundaries(img, segments)
    img[img_marking[:, :, 0] != 255] = (1, 0, 0)
    img[img_marking[:, :, 2] != 255] = (0, 0, 1)
    plt.imshow(img)
    plt.title("SLIC + markings")

    plt.savefig("segmentation.png", bbox_inches='tight', dpi=96)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment