# The SLIC superpixel algorithm
#
# GitHub gist jcboyd/083ea081ab113ba8c972ccb48bf76546, created March 27, 2024 11:59.
import matplotlib.pyplot as plt
import numpy as np
from skimage import data
from skimage.measure import find_contours
from skimage.color import rgb2rgbcie

# Sample image bundled with scikit-image.
img = data.astronaut()
# Colour features: the image converted to the CIE RGB colour space.
# (The SLIC paper uses CIELAB; all colour distances below use this space.)
img_cie = rgb2rgbcie(img)
H, W = img.shape[:2]
# Pixel coordinate grids of shape (H, W): xx holds each pixel's column (x)
# and yy its row (y). meshgrid's default 'xy' indexing returns arrays shaped
# (len(second), len(first)), so the column range must come first; passing
# (H, W) in the other order only works by accident on square images.
xx, yy = np.meshgrid(np.arange(W), np.arange(H))
# Per-pixel feature vector [c0, c1, c2, x, y]: three colour channels followed
# by the pixel position, giving an (H, W, 5) array.
features = np.dstack([img_cie, xx, yy])
S = 25  # seed spacing (superpixel grid interval) in pixels

# Seed positions along each axis, evenly spaced from S to (size - S) so the
# outermost seeds stay S pixels away from the image border.
center_x = np.linspace(S, W - S, num=W // S).astype('int')
center_y = np.linspace(S, H - S, num=H // S).astype('int')

# Cartesian product of the two axes, flattened into an (N, 2) integer array
# of (x, y) pairs in row-major order (x varies fastest).
xx, yy = np.meshgrid(center_x, center_y)
center_coords = np.column_stack((xx.ravel(), yy.ravel()))
# Move every seed to the lowest-gradient position in its 3x3 neighbourhood,
# which keeps seeds from starting on an edge or an isolated noisy pixel.

# Edge-pad by one pixel on each side so the central differences below
# produce arrays with the same height and width as the image.
padded_cie = np.pad(img_cie, pad_width=[(1, 1), (1, 1), (0, 0)], mode='edge')
G_x = padded_cie[1:-1, 2:] - padded_cie[1:-1, :-2]  # horizontal central difference
G_y = padded_cie[2:, 1:-1] - padded_cie[:-2, 1:-1]  # vertical central difference
# Gradient magnitude: squared differences summed over the colour channels and
# both directions, then square-rooted.
G_xy = np.sqrt((G_x ** 2).sum(axis=2) + (G_y ** 2).sum(axis=2))

for seed in range(len(center_coords)):
    cx, cy = center_coords[seed]
    patch = G_xy[cy - 1:cy + 2, cx - 1:cx + 2]
    # Position of the minimum inside the 3x3 patch; subtracting 1 turns it
    # into an offset in {-1, 0, 1} relative to the current seed.
    row_off, col_off = np.unravel_index(patch.argmin(), patch.shape)
    center_coords[seed] = [cx + col_off - 1, cy + row_off - 1]

# Full feature vector (colour + position) sampled at each adjusted seed.
centers = np.array([features[py, px] for px, py in center_coords])
# Compactness m trades colour similarity against spatial proximity: a larger m
# gives more regular, compact superpixels. Our colour features are unnormalised
# (and are not CIELAB), so useful values lie outside the [1, 20] range usually
# quoted for SLIC.
m = 0.1
E = float('inf')  # mean absolute change of the centers between iterations
tol = 5e-2        # stop once E is no longer above this value
max_iter = 100    # safety cap in case the centers never settle below tol

n_centers, n_features = centers.shape
# (H*W, n_features) view of the per-pixel features for the vectorised center
# update; it does not change between iterations, so build it once.
flat_features = features.reshape(-1, n_features)

n_iter = 0
while E > tol and n_iter < max_iter:
    n_iter += 1
    # -1 marks pixels not covered by any center's search window. They are left
    # unassigned rather than silently attributed to center 0.
    superpixels = np.full((H, W), -1, dtype=int)
    distances = np.full((H, W), np.inf)
    for center_idx, center_feature in enumerate(centers):
        # Each center only searches a window of about 2S x 2S around itself
        # (clipped at the border). This keeps each iteration's cost
        # proportional to the number of pixels.
        x, y = int(center_feature[3]), int(center_feature[4])
        t, b = max(y - S, 0), min(y + S, H)
        l, r = max(x - S, 0), min(x + S, W)
        window = features[t:b, l:r]
        # SLIC distance: colour distance (d_lab, computed on the three colour
        # channels of `features`) plus spatial distance scaled by m / S.
        d_lab = np.sqrt(np.sum((window[..., :3] - center_feature[:3]) ** 2, axis=2))
        d_xy = np.sqrt(np.sum((window[..., 3:] - center_feature[3:]) ** 2, axis=2))
        D_s = d_lab + (m / S) * d_xy
        # Claim pixels strictly closer to this center than to any center seen
        # so far. The basic slices are views, so these writes land in place.
        closer = D_s < distances[t:b, l:r]
        superpixels[t:b, l:r][closer] = center_idx
        distances[t:b, l:r][closer] = D_s[closer]

    # Move each center to the mean feature vector of its pixels. bincount
    # makes this O(H*W) instead of scanning the whole label map once per
    # center.
    labels = superpixels.ravel()
    assigned = labels >= 0
    owner = labels[assigned]
    owned_features = flat_features[assigned]
    counts = np.bincount(owner, minlength=n_centers)
    sums = np.stack(
        [
            np.bincount(owner, weights=owned_features[:, k], minlength=n_centers)
            for k in range(n_features)
        ],
        axis=1,
    )
    # A center that won no pixels keeps its previous value. Averaging an empty
    # set would give NaN, which makes E NaN and silently ends the loop.
    new_centers = centers.copy()
    has_pixels = counts > 0
    new_centers[has_pixels] = sums[has_pixels] / counts[has_pixels][:, None]
    # Convergence measure: mean L1 change over all center features.
    E = np.mean(np.abs(new_centers - centers))
    centers = new_centers
segments = superpixels.astype('int')
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(img)
# Outline every label actually present in the map. np.unique avoids the
# off-by-one of range(max) and never visits labels with no pixels.
for label in np.unique(segments):
    # A superpixel may be split into several connected pieces, so draw all of
    # its contours. Level 0.5 separates the False/True pixels of the mask and
    # is passed explicitly for older scikit-image versions that require it.
    for contour in find_contours(segments == label, 0.5):
        # plot (not fill) so contours that end at the image border are not
        # closed with a spurious straight chord across the region.
        ax.plot(contour[:, 1], contour[:, 0], color='red')
plt.show()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.