Created December 4, 2020 09:18
Save syaffers/36ad13402c0fe20ac7437f28c8907c9c to your computer and use it in GitHub Desktop.
K-means reduced palette class
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class KMeansReducedPalette:
    """Reduce an RGB image to a fixed-size color palette learned with k-means.

    ``fit`` clusters the pixels of a source image; the cluster centers form
    the palette. ``recolor`` then replaces every pixel of a (possibly
    different) image with the center of the cluster ``self.kmeans.predict``
    assigns it to.

    NOTE(review): ``KMeans`` (presumably ``sklearn.cluster.KMeans``) and ``np``
    (NumPy) must be imported at module scope; those imports are not part of
    this class.
    """

    def __init__(self, num_colors, *, random_state=0xfee1600d):
        """Create an unfitted palette reducer.

        Args:
            num_colors: Number of palette colors, i.e. k-means clusters.
            random_state: Seed forwarded to ``KMeans`` so fits are
                reproducible. The default is the class's original fixed seed.
        """
        self.num_colors = num_colors
        self.kmeans = KMeans(n_clusters=num_colors, random_state=random_state)
        # Flattened (N, 3) copy of the pixels passed to the last ``fit`` call.
        self.source_pixels = None

    def _preprocess(self, image):
        """Validate ``image`` and return its pixels as an (N, 3) array.

        Args:
            image: NumPy array whose last axis holds the 3 color channels.

        Returns:
            A (N, 3) view or reshaped array of ``image``'s pixels.

        Raises:
            ValueError: if ``image`` is 0-d or its last axis is not length 3.
            TypeError: if ``image`` is not of dtype uint8.
        """
        # Explicit raises instead of ``assert``: asserts vanish under -O.
        if image.ndim == 0 or image.shape[-1] != 3:
            raise ValueError('image must have exactly 3 color channels')
        if image.dtype != np.uint8:
            raise TypeError('image must be in np.uint8 type')
        # reshape(-1, 3) flattens (H, W, 3) images, is a no-op view for
        # (N, 3) pixel lists, and lifts a lone (3,) pixel to (1, 3) so the
        # estimator always receives 2-D input.
        return image.reshape(-1, 3)

    def fit(self, image):
        """Learn the palette from the pixels of ``image``.

        Args:
            image: uint8 array whose last axis has length 3.

        Returns:
            ``self``, so calls can be chained: ``palette.fit(img).recolor(img)``.
        """
        # Copy first so ``source_pixels`` never aliases the caller's array.
        self.source_pixels = self._preprocess(image.copy())
        self.kmeans.fit(self.source_pixels)
        return self

    def recolor(self, image):
        """Return ``image`` with every pixel replaced by its palette color.

        ``fit`` must have been called first.

        Args:
            image: uint8 array whose last axis has length 3, with any
                leading shape.

        Returns:
            A new uint8 array with the same shape as ``image``.
        """
        original_shape = image.shape
        pixels = self._preprocess(image)
        labels = self.kmeans.predict(pixels)
        recolored = self.kmeans.cluster_centers_[labels].reshape(original_shape)
        # Centers are float-valued; round rather than truncate so that,
        # e.g., 254.6 maps to 255 instead of 254.
        return np.round(recolored).astype(np.uint8)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.