# Last active: April 27, 2016 15:24
# Gist amueller/d5d3de7630a25ae61cff7af4b29b5970 (can be saved locally or used in GitHub Desktop).
# extract colormap from an image
# Note: this file contains hidden or bidirectional Unicode text that may be interpreted or
# compiled differently than what appears below. To review, open the file in an editor that
# reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
import numpy as np
import scipy.sparse as sp
from colorspacious import cspace_convert
from scipy.sparse.csgraph import depth_first_order, minimum_spanning_tree
from sklearn.metrics import euclidean_distances
def is_heatmap(image, threshold=1, random_seed=None, n_samples=2000,
               min_chain_length=100):
    """Decide whether an image looks like a heatmap and extract its colormap.

    A random subsample of pixels is converted to CAM02-UCS and a minimum
    spanning tree (MST) is built over the sampled colors. Leaves are then
    pruned repeatedly until only a chain remains. The image counts as a
    heatmap when the chain is long enough and the sampled colors lie close
    to it on average. In that case the chain, ordered from one end to the
    other, is the extracted colormap.

    Diagnostics (chain length, mean distance to the chain) are printed to
    stdout.

    Parameters
    ----------
    image : ndarray, shape (height, width, 3 or 4) or (height, width)
        sRGB image with values in 0..255 (colorspacious' "sRGB255"
        convention). An alpha channel is dropped. A 2-D grayscale image is
        treated as equal R, G and B.
    threshold : float, default=1
        The image is a heatmap only if the mean CAM02-UCS distance from the
        sampled colors to their nearest chain color is below this value.
    random_seed : int or None, default=None
        Seed for the pixel subsampling.
    n_samples : int, default=2000
        Number of pixels to subsample, capped at the number of pixels.
    min_chain_length : int, default=100
        Chains with fewer nodes are rejected (most likely an illustration or
        vector graphic rather than a heatmap).

    Returns
    -------
    decision : bool
        Whether the image looks like a heatmap.
    ordered_chain : ndarray of int
        Indices into ``subsample`` / ``subsample_rgb``, ordered along the
        chain.
    subsample : ndarray, shape (n, 3)
        The sampled colors in CAM02-UCS.
    subsample_rgb : ndarray, shape (n, 3)
        The same sampled colors as taken from the input image.
    """
    image = np.asarray(image)
    if image.ndim == 2:
        # grayscale: replicate the single channel into R, G and B
        image = np.stack([image] * 3, axis=-1)
    # drop alpha channel
    if image.shape[2] == 4:
        image = image[:, :, :3]
    flat = image.reshape(-1, 3)

    # subsample for speed; cap the size so that choice(..., replace=False)
    # does not raise on images with fewer than n_samples pixels
    state = np.random.RandomState(random_seed)
    size = min(n_samples, len(flat))
    indices = state.choice(len(flat), size=size, replace=False)
    subsample_rgb = flat[indices]
    # CAM02-UCS is designed so that euclidean distance approximates
    # perceived color difference
    subsample = cspace_convert(subsample_rgb, "sRGB255", "CAM02-UCS")

    # MST of the complete graph over the sampled colors
    distances = euclidean_distances(subsample)
    # minimum_spanning_tree treats zero entries of a dense matrix as missing
    # edges, so duplicate colors (distance 0) would be disconnected and the
    # "tree" would fall apart into a forest. Give them a tiny positive
    # weight instead; the diagonal stays zero (no self-loops).
    distances[distances == 0] = 1e-10
    np.fill_diagonal(distances, 0)
    mst = minimum_spanning_tree(distances)
    # symmetric 0/1 adjacency matrix of the undirected tree
    # (np.int was removed in NumPy 1.24; the builtin int is equivalent)
    connectivity = (mst + mst.T != 0).astype(int)

    # prune leaves until only a chain is left: once a pass removes at most
    # two nodes, the tree had at most two leaves, i.e. it was already a path
    last_n_nodes = connectivity.shape[0]
    central_nodes = np.arange(last_n_nodes)
    while True:
        not_leaf, = np.where(np.asarray(connectivity.sum(axis=0)).ravel() > 1)
        central_nodes = central_nodes[not_leaf]
        connectivity = connectivity[not_leaf, :][:, not_leaf]
        if connectivity.shape[0] >= last_n_nodes - 2:
            break
        last_n_nodes = connectivity.shape[0]

    # central nodes contain the chain (as indices into subsample)
    chain = subsample[central_nodes]
    print("length of chain: %d" % len(central_nodes))
    long_enough = len(central_nodes) >= min_chain_length
    if not long_enough:
        print("probably an illustration / vectorgraphic")

    # order the nodes along the chain with a depth-first walk that starts
    # at one of its end points (nodes of degree 1)
    ends, = np.where(np.asarray(connectivity.sum(axis=0)).ravel() == 1)
    if len(ends):
        chain_order, _ = depth_first_order(connectivity, ends[0])
        # chain_order indexes central_nodes; map back to subsample indices
        ordered_chain = central_nodes[chain_order]
    else:
        # zero or one node left: there is nothing to order
        ordered_chain = central_nodes

    # check whether the chain covers most of the colors: mean distance of
    # every sample to its nearest chain color (the median was too lenient,
    # a sunset photo fooled it)
    if len(chain):
        distance_to_chain = np.min(euclidean_distances(subsample, chain), axis=1).mean()
    else:
        distance_to_chain = np.inf
    print(distance_to_chain)
    decision = bool(long_enough and distance_to_chain < threshold)
    return decision, ordered_chain, subsample, subsample_rgb
if __name__ == "__main__":
    # Demo: classify one image and plot the extracted colormap.
    import sys

    import matplotlib.pyplot as plt
    import numpy as np
    # importing mplot3d registers the "3d" projection on older matplotlib
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    # scipy.misc.imread was removed in SciPy 1.2. matplotlib's reader returns
    # uint8 for JPEG but floats in [0, 1] for PNG, so rescale floats to the
    # 0..255 range is_heatmap expects.
    image_path = sys.argv[1] if len(sys.argv) > 1 else "./UytvhhN.jpg"
    image = plt.imread(image_path)
    if np.issubdtype(image.dtype, np.floating):
        image = np.round(image * 255).astype(np.uint8)
    decision, chain_order, subsample, subsample_rgb = is_heatmap(image, random_seed=5)
    print("is heatmap: %s" % decision)

    # 3d plot in CAM02-UCS space: chain colors drawn large, every sample
    # colored with its own RGB value
    X_ = subsample[chain_order]
    fig = plt.figure(figsize=(8, 6))
    # Axes3D(fig, ...) no longer adds itself to the figure (deprecated in
    # matplotlib 3.4, removed in 3.6); create the 3d axes via the figure.
    ax = fig.add_subplot(projection="3d")
    ax.view_init(elev=-150, azim=110)
    ax.scatter(X_[:, 0], X_[:, 1], X_[:, 2], s=100)
    ax.scatter(subsample[:, 0], subsample[:, 1], subsample[:, 2], c=subsample_rgb / 255.)

    # extracted colormap, repeated into a 10-pixel-high strip
    plt.figure()
    plt.imshow(np.repeat(subsample_rgb[chain_order].reshape(1, -1, 3), repeats=10, axis=0))
    plt.show()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.