Skip to content

Instantly share code, notes, and snippets.

@syaffers
Created December 4, 2020 09:18
Show Gist options
  • Save syaffers/36ad13402c0fe20ac7437f28c8907c9c to your computer and use it in GitHub Desktop.
Save syaffers/36ad13402c0fe20ac7437f28c8907c9c to your computer and use it in GitHub Desktop.
K-means reduced palette class
class KMeansReducedPalette:
    """Reduce an RGB image's palette to ``num_colors`` colors via k-means.

    ``fit`` clusters the pixels of a source image; ``recolor`` then replaces
    every pixel of an image with the center of the cluster it is assigned to.

    ``KMeans`` and ``np`` are taken from the enclosing module. The use of
    ``fit``/``predict``/``cluster_centers_``/``random_state`` suggests
    scikit-learn's ``sklearn.cluster.KMeans`` -- confirm against the imports.
    """

    class InvalidImageError(ValueError, AssertionError):
        """Raised when an image is not a uint8 array with 3 color channels.

        Inherits from ``AssertionError`` as well as ``ValueError`` so callers
        written against the earlier ``assert``-based checks still catch it.
        """

    def __init__(self, num_colors, random_state=0xfee1600d):
        """Create an unfitted palette reducer.

        Args:
            num_colors: Number of palette colors (k-means clusters).
            random_state: Seed forwarded to ``KMeans``; the fixed default
                keeps results reproducible across runs.
        """
        self.num_colors = num_colors
        self.kmeans = KMeans(num_colors, random_state=random_state)
        # Flattened (N, 3) copy of the pixels passed to the most recent ``fit``.
        self.source_pixels = None

    def _preprocess(self, image):
        """Validate *image* and return its pixels as an ``(N, 3)`` array.

        Accepts any array whose last axis has length 3, e.g. an ``(H, W, 3)``
        image, an ``(N, 3)`` pixel list, or a single ``(3,)`` pixel.

        Raises:
            KMeansReducedPalette.InvalidImageError: if the last axis is not
                of length 3 or the dtype is not ``uint8``.
        """
        # Explicit raises instead of ``assert``: asserts vanish under ``python -O``.
        # The ndim guard stops a 0-d array from failing with an IndexError.
        if image.ndim == 0 or image.shape[-1] != 3:
            raise self.InvalidImageError('image must have exactly 3 color channels')
        if image.dtype != 'uint8':
            raise self.InvalidImageError('image must be in np.uint8 type')
        # Always reshape: a no-op for (N, 3) input, and it turns a lone (3,)
        # pixel into a (1, 3) row (scikit-learn estimators reject 1-D input).
        return image.reshape(-1, 3)

    def fit(self, image):
        """Learn the reduced palette from the pixels of *image*.

        The pixels are copied, so later in-place edits to *image* do not
        alter ``source_pixels``.

        Returns:
            ``self``, allowing ``palette.fit(img).recolor(img)``.
        """
        # Validate before copying so an invalid image is never duplicated.
        self.source_pixels = self._preprocess(image).copy()
        self.kmeans.fit(self.source_pixels)
        return self

    def recolor(self, image):
        """Return *image* with every pixel replaced by its palette color.

        The result has the same shape as *image* and dtype ``uint8``.
        ``fit`` must have been called first, since this reads the fitted
        ``kmeans`` model.
        """
        pixels = self._preprocess(image)
        labels = self.kmeans.predict(pixels)
        # Cluster centers are floating point; round before casting to uint8.
        palette_pixels = self.kmeans.cluster_centers_[labels]
        return np.round(palette_pixels).reshape(image.shape).astype(np.uint8)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment