Skip to content

Instantly share code, notes, and snippets.

@jcboyd
Created March 27, 2024 11:59
Show Gist options
  • Save jcboyd/083ea081ab113ba8c972ccb48bf76546 to your computer and use it in GitHub Desktop.
Save jcboyd/083ea081ab113ba8c972ccb48bf76546 to your computer and use it in GitHub Desktop.
The SLIC superpixel algorithm
import matplotlib.pyplot as plt
import numpy as np
from skimage import data
from skimage.measure import find_contours
from skimage.color import rgb2rgbcie
# Load the sample image and convert it to the CIE RGB colour space.
img = data.astronaut()
img_cie = rgb2rgbcie(img)
H, W = img.shape[:2]
# Per-pixel coordinate grids. With np.meshgrid's default 'xy' indexing the
# first argument is the x (column) range and the outputs have shape (H, W),
# matching the image. Passing (arange(H), arange(W)) only works for square
# images; for any other shape the grids come out as (W, H) and dstack fails.
xx, yy = np.meshgrid(np.arange(W), np.arange(H))
# 5-channel feature image: [c0, c1, c2, x, y] for every pixel.
features = np.dstack([img_cie, xx, yy])
# Grid interval: seeds are placed roughly S pixels apart, with a margin of S
# from the image border.
S = 25
center_x = np.linspace(S, W - S, W // S, dtype=int)
center_y = np.linspace(S, H - S, H // S, dtype=int)
# One (x, y) seed per grid node, in row-major order (y slowest, x fastest).
xx, yy = np.meshgrid(center_x, center_y)
center_coords = np.column_stack((xx.ravel(), yy.ravel()))
# Move each seed to the lowest-gradient pixel of its 3x3 neighbourhood so
# that seeds do not start on an edge or a noisy pixel.
# Edge-pad by one pixel so the central differences keep the image size.
padded_cie = np.pad(img_cie, pad_width=[(1, 1), (1, 1), (0, 0)], mode='edge')
# Central differences along columns (x) and rows (y), per colour channel.
G_x = padded_cie[1:-1, 2:] - padded_cie[1:-1, :-2]
G_y = padded_cie[2:, 1:-1] - padded_cie[:-2, 1:-1]
# Gradient magnitude, with squared differences summed over the channels.
G_xy = np.sqrt((G_x ** 2).sum(axis=2) + (G_y ** 2).sum(axis=2))
for seed_idx, (sx, sy) in enumerate(center_coords):
    patch = G_xy[sy - 1:sy + 2, sx - 1:sx + 2]
    # (row, col) of the patch minimum; subtracting 1 gives an offset in
    # {-1, 0, 1} relative to the seed.
    row, col = np.unravel_index(np.argmin(patch), patch.shape)
    center_coords[seed_idx] = (sx + col - 1, sy + row - 1)
# Initial cluster centres: the 5-D feature vector at each perturbed seed.
centers = np.array([features[sy, sx] for sx, sy in center_coords])
# Compactness: weight of spatial distance relative to colour distance. The
# features are unnormalised rgbcie values (not CIELAB as in the SLIC paper),
# so this sits well outside the typical [1, 20] range.
m = 0.1
# Stop when the mean absolute change of the centre features drops below tol,
# or after max_iter iterations so a non-converging run cannot loop forever.
E = float('inf')
tol = 5e-2
max_iter = 100
n_iter = 0
n_centers = len(centers)
n_channels = features.shape[2]
while E > tol and n_iter < max_iter:
    n_iter += 1
    # Assignment step: every centre claims the pixels in its 2S x 2S window
    # that are closer to it than to any centre seen so far.
    superpixels = np.zeros((H, W), dtype=int)
    distances = np.full((H, W), np.inf)
    for center_idx, center_feature in enumerate(centers):
        # Window around the (truncated) centre position, clipped to the image.
        x, y = int(center_feature[3]), int(center_feature[4])
        t, b = max(y - S, 0), min(y + S, H)
        l, r = max(x - S, 0), min(x + S, W)
        window = features[t:b, l:r]
        # Colour distance and spatial distance, combined as in SLIC.
        d_color = np.sqrt(np.sum((window[..., :3] - center_feature[:3]) ** 2, axis=2))
        d_xy = np.sqrt(np.sum((window[..., 3:] - center_feature[3:]) ** 2, axis=2))
        D_s = d_color + (m / S) * d_xy
        # Basic slices are views, so these masked writes update the full arrays.
        closer = D_s < distances[t:b, l:r]
        superpixels[t:b, l:r][closer] = center_idx
        distances[t:b, l:r][closer] = D_s[closer]
    # Update step: move each centre to the mean feature of its pixels.
    # Pixels no window reached still have infinite distance and only carry the
    # placeholder label 0, so they are left out instead of biasing cluster 0.
    assigned = np.isfinite(distances).ravel()
    labels = superpixels.ravel()[assigned]
    flat_features = features.reshape(-1, n_channels)[assigned]
    counts = np.bincount(labels, minlength=n_centers)
    sums = np.stack(
        [np.bincount(labels, weights=flat_features[:, c], minlength=n_centers)
         for c in range(n_channels)],
        axis=1,
    )
    # An empty cluster keeps its previous centre; averaging zero pixels would
    # give NaN and crash int() on the next iteration.
    new_centers = centers.copy()
    nonempty = counts > 0
    new_centers[nonempty] = sums[nonempty] / counts[nonempty, None]
    # Mean absolute (L1) change of the centres, used as the convergence test.
    E = np.mean(np.abs(new_centers - centers))
    centers = new_centers
segments = superpixels.astype('int')
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(img)
# Outline every superpixel that is present.
# - range(np.max(segments)) would skip the highest label.
# - A label may be absent, or split into several pieces, so every contour of
#   every present label is drawn.
# - ax.plot is used because ax.fill would close contours that end on the image
#   border with a spurious straight edge.
for label in np.unique(segments):
    mask = (segments == label).astype(float)
    for cont in find_contours(mask, 0.5):
        ax.plot(cont[:, 1], cont[:, 0], color='red', linewidth=1)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment