-
-
Save lukasz-migas/19986c9be608f69c1f438c4dd15881ff to your computer and use it in GitHub Desktop.
Some visualization helper functions (especially useful for understanding feature spaces and for visualizing KMeans and GMM clustering) built on common plotting libraries such as matplotlib and seaborn.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import tempfile
from itertools import combinations, cycle, product

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
# Single-letter matplotlib colour codes.
mpl_palette = ['r', 'g', 'b', 'c', 'm', 'y']

# Colour-blind-friendly palette (the Okabe-Ito colours) as 0-255 RGB triples.
cblind_palette = [
    (230, 159, 0),
    (86, 180, 233),
    (0, 158, 115),
    (240, 228, 66),
    (0, 114, 178),
    (213, 94, 0),
    (204, 121, 167),
]

# The same seven colours as 0-1 RGB fractions, which is the form matplotlib
# accepts.  The values are rounded, so they are close to, but not exactly,
# ``cblind_palette`` divided by 255.
cblind_palette_norm = [
    (.9, .6, 0),
    (.35, .70, .90),
    (0, .60, .5),
    (.95, .90, .25),
    (0, .45, .70),
    (.80, .40, 0),
    (.80, .60, .70),
]
def histogram_thresholds(hist, bins, thresholds,
                         xlabel='Intensity Values',
                         ylabel='Histogram Counts'):
    """Plot an intensity histogram with a vertical line at each threshold.

    Parameters
    ----------
    hist: list of histogram counts, as returned from
        skimage.exposure.histogram; plotted on the y-axis
    bins: list of bin values, as returned from
        skimage.exposure.histogram; plotted on the x-axis
    thresholds: list of thresholds, as returned by
        segmentation.Thresholding.maximal_variance
    xlabel: str, optional
        label of the x-axis (the bin / intensity values)
    ylabel: str, optional
        label of the y-axis (the histogram counts)

    Returns
    -------
    ax: matplotlib axes
        the axes holding the histogram curve and the threshold lines
    """
    fig, ax = plt.subplots(1, 1)
    ax.plot(bins, hist, lw=2)
    # axvline draws a black line over the full axes height.  It also avoids
    # np.ones(<float>), which fails because ax.axis() returns float limits.
    for threshold in thresholds:
        ax.axvline(threshold, color='k')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    return ax
def scatter(X_small, centroids, dim_labels, silho_score):
    """ WARNING: DEPRECATED. Multiple scatter plots are better done
    using the library seaborn. See CentroidCluster.jointcentroid
    for better functionality.

    Plots a 2D hexbin visualization of clustering results, one panel per
    pair of feature dimensions, with the cluster centers overlaid as red
    crosses.  Supports 3 features (3 panels in one row) and 4 features
    (6 panels on a 2x3 grid); any other feature count prints a message
    and returns without plotting.

    Parameters
    ----------
    X_small: a numpy array of shape (number of subsamples, number of features),
        the subsampled Feature Vector.
    centroids: a numpy array of shape (number of clusters, number of features),
        the cluster centers
    dim_labels: a list of string elements, each denoting the (0,1,2...d)
        feature dimension labels
    silho_score: a float, the Silhouette Score measuring cluster goodness;
        shown (rounded to 3 decimals) as the figure title

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        if the number of columns in X_small differs from len(dim_labels)
    """
    _, n_features = X_small.shape
    assert n_features == len(dim_labels), "Feature and Label Dimensions don't match"
    dim_combo = list(combinations(range(n_features), 2))
    if len(dim_combo) == 3:
        fig, axes = plt.subplots(1, 3, figsize=(21, 7))
    elif len(dim_combo) == 6:
        fig, axes = plt.subplots(2, 3, figsize=(14, 14))
    else:
        print("This method only supports 3 and 4 feature dimensions")
        return
    fig.suptitle('Silhouette Score: {0}'.format(round(silho_score, 3)))
    # For the 2x3 grid, subplots returns a 2-D array of axes.  Flatten it so
    # each iteration gets a single Axes rather than a whole row.
    for ax, (x, y) in zip(np.ravel(axes), dim_combo):
        ax.hexbin(X_small[:, x], X_small[:, y], cmap=plt.cm.Purples)
        ax.scatter(centroids[:, x], centroids[:, y], marker='x', color='r')
        ax.set_xlabel(dim_labels[x])
        ax.set_ylabel(dim_labels[y])
        ax.set_xlim([np.min(X_small[:, x]), np.max(X_small[:, x])])
        ax.set_ylim([np.min(X_small[:, y]), np.max(X_small[:, y])])
def viz_final_segmentation(image, final_segmentation, p_of_p, fig_outfile):
    """
    Shows K corresponding images of the K classes, one per row, and saves
    the figure to ``fig_outfile``.

    Each row draws ``image`` at alpha 0.6 with segmentation i on top, using
    the RdYlGn colormap on a fixed [-1, 1] range.  The row is titled with
    p_of_p[i] rounded to 4 decimals.
    If K=3, the classes could be
        {'unhealthy','borderline','healthy'}
    Note that no assumption is made about the ordering and
    corresponding level of healthiness in the crop images.

    Parameters
    ----------
    image: array-like, shape (M, N) or (M, N, num_color_channels)
        the image to which the separate segmentations will
        be overlayed
    final_segmentation: a list of (M, N) masked arrays,
        where each masked array, final_segmentation[i],
        is the image of the segmented regions which
        belong to cluster/class i
    p_of_p: an array, length K clusters/classes,
        where each element of the array denotes the percentage
        of pixels which belong to cluster/class i
    fig_outfile: a file or file string to save the output to
    """
    K = len(final_segmentation)
    # squeeze=False always yields a (K, 1) array of axes, so K == 1 works
    # too.  By default plt.subplots(1, 1) returns a bare Axes, which cannot
    # be indexed.
    fig, axes = plt.subplots(K, 1, figsize=(15, 15 * K), squeeze=False)
    for i, ax in enumerate(axes[:, 0]):
        ax.imshow(image, alpha=0.6)
        ax.imshow(final_segmentation[i], cmap=plt.cm.RdYlGn, vmin=-1.0,
                  vmax=1.0)
        ax.set_title('% of Field: {0}'.format(round(p_of_p[i], 4)),
                     fontsize=30)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.tight_layout()
    plt.savefig(fig_outfile, bbox_inches='tight')
def multi_checkbox_overlay(image, final_segmentation, p_of_p, checkbox, res_dir,
                           description):
    """
    Overlay the selected segmentations on one image and save the figure.

    The image is drawn at alpha 0.6.  Every segmentation whose checkbox
    entry is truthy is drawn on top with the RdYlGn colormap on a fixed
    [-1, 1] range.  The title shows the summed p_of_p of the selected
    classes, rounded to 4 decimals.  The figure is saved as
    ``multi_checkbox_overlay_<description>`` inside ``res_dir``.

    Parameters
    ----------
    image: array-like, shape (M, N) or (M, N, num_color_channels)
        the image to which the separate segmentations will
        be overlayed
    final_segmentation: a list of (M, N) masked arrays,
        where each masked array, final_segmentation[i],
        is the image of the segmented regions which
        belong to cluster/class i
    p_of_p: an array, length K clusters/classes,
        where each element of the array denotes the percentage
        of pixels which belong to cluster/class i
    checkbox: a boolean array (Python or numpy booleans), of length K,
        denoting which segmented regions of final_segmentation
        should be overlayed onto the image
    res_dir: a string or path,
        path to the result directory to save the image
    description: a string,
        any add on to the file description for this final_segmentation
    """
    fig, ax = plt.subplots(figsize=(15, 15))
    ax.imshow(image, alpha=0.6)
    p_of_p_sum = 0
    for i in range(len(final_segmentation)):
        # Test truthiness, not ``is True``: a numpy boolean is not the
        # ``True`` singleton, so ``np.bool_(True) is True`` evaluates False.
        if checkbox[i]:
            ax.imshow(final_segmentation[i], cmap=plt.cm.RdYlGn, vmin=-1.0, vmax=1.0)
            p_of_p_sum += p_of_p[i]
    ax.set_title('% of Field: {0}'.format(round(p_of_p_sum, 4)),
                 fontsize=30)
    ax.set_xticks([])
    ax.set_yticks([])
    plt.tight_layout()
    # os.path.join works whether or not res_dir ends with a separator.
    out_path = os.path.join(res_dir, 'multi_checkbox_overlay_{0}'.format(description))
    plt.savefig(out_path, bbox_inches='tight')
def joint2D(X, labels):
    """
    Hex-bin seaborn joint plots for pairs of feature columns in ``X``.

    With exactly two features a single ``sns.jointplot`` is drawn.  With
    more than two, one jointplot is rendered per feature pair and saved to a
    temporary PNG.  The PNGs are then stacked vertically, one image per row,
    into a new figure.  Nothing is drawn for fewer than two features.

    Parameters
    ----------
    X: array-like with a ``.shape``, shape: (n_samples, n_features)
        data
    labels: list of strings, length n_features
        label of each column in X
    """
    _, n_features = X.shape
    if n_features < 2:
        return
    df_X = pd.DataFrame(X, columns=labels)
    # NOTE(review): seaborn >= 0.9 renamed ``size`` to ``height``.  ``size``
    # is kept for the seaborn version this module was written against.
    if n_features == 2:
        sns.jointplot(x=labels[0], y=labels[1], data=df_X, kind='hex', size=10)
        return
    combos = list(combinations(range(n_features), 2))
    n_ax = len(combos)
    # Render into a private temporary directory rather than a hard-coded
    # '/Temp' path, which does not exist on most systems.
    with tempfile.TemporaryDirectory() as tmp_dir:
        paths = []
        for i, (x, y) in enumerate(combos):
            sns.jointplot(x=labels[x], y=labels[y], data=df_X, kind='hex', size=10)
            path = os.path.join(tmp_dir, 'f%d.png' % i)
            plt.savefig(path, bbox_inches='tight')
            plt.close()
            paths.append(path)
        # Combine the saved pair plots into one tall figure.
        fig, ax = plt.subplots(nrows=n_ax, ncols=1, figsize=(10, 10 * n_ax))
        for axis, path in zip(ax, paths):
            axis.imshow(plt.imread(path))
            axis.axis('off')
def colored_scatter(X, Y_, x, y, ax, palette=cblind_palette_norm, alpha=0.025):
    """
    Plots a scatter plot with each point's color corresponding to the
    predicted cluster.

    Parameters
    ----------
    X: array-like, shape: (n_samples, n_features)
        data array (lists are accepted and converted with np.asarray)
    Y_: array-like, shape: (n_samples,)
        elements are integers, representing the predicted class labels
        outputted by a sklearn estimator, i.e. KMeans or GMM
    x: integer
        denoting the index of the feature we treat as the x-axis
    y: integer
        denoting the index of the feature we treat as the y-axis
    ax: matplotlib axis handle
        the axis on which the colored scatter shall be generated
    palette: sequence, optional
        each element is a string, rgb tuple (range 0-1), or hexadecimal
        denoting a specific color.  Integer label k is drawn with
        palette[k % len(palette)].  Any other label type cycles through the
        palette in sorted-label order.  Highly recommended to leave this
        argument as default, so that the same color palette is used for the
        ellipse plotting and color-blind-friendly colors are used.
    alpha: float, range (0-1)
        alpha transparency level for plotting

    Returns
    -------
    axis: matplotlib axis handle
        the changed axis handle now containing the colored scatter points
    """
    X = np.asarray(X)
    Y_ = np.asarray(Y_)
    colors = list(palette)
    if not colors:
        # Nothing to color with.  An empty palette plots nothing.
        return ax
    integer_labels = np.issubdtype(Y_.dtype, np.integer)
    for rank, label in enumerate(np.unique(Y_)):
        # Index integer labels by value.  A cluster then keeps the color of
        # its component index even when another cluster has no points here,
        # matching GaussianMM._draw_ellipses, which cycles the palette once
        # per mixture component.
        key = int(label) if integer_labels else rank
        members = Y_ == label
        ax.scatter(X[members, x], X[members, y],
                   color=colors[key % len(colors)], alpha=alpha)
    return ax
def viz_colored_segmentation(image, final_seg, fig_outfile):
    """Show each segmented class as a solid color under a translucent image.

    Class i is painted with the i-th color of ``cblind_palette_norm``
    (cycled when there are more classes than colors) wherever its mask is
    not set.  Masked pixels are filled with white.  ``image`` is then drawn
    on top at alpha 0.6.  There is one row of axes per class, and the figure
    is saved to ``fig_outfile``.

    Parameters
    ----------
    image: array-like, shape (M, N) or (M, N, num_color_channels)
        the image to which the separate segmentations will
        be overlayed
    final_seg: a list of (M, N) masked arrays,
        where each masked array, final_seg[i],
        is the image of the segmented regions which
        belong to cluster/class i
    fig_outfile: a file or file string to save the output to
    """
    colored_layers = []
    for layer, color in zip(final_seg, cycle(cblind_palette_norm)):
        # getmaskarray always returns a full (M, N) boolean array.
        # np.ma.getmask returns the scalar ``nomask`` for un-masked input,
        # which would yield a mis-shaped 3-channel mask.
        mask2d = np.ma.getmaskarray(layer)
        rgb = np.ones(np.shape(layer) + (3,)) * np.asarray(color[:3], dtype=float)
        colored_layers.append(np.ma.array(rgb, mask=np.dstack([mask2d] * 3)))
    K = len(colored_layers)
    # squeeze=False keeps a (K, 1) array of axes so K == 1 also works.
    fig, axes = plt.subplots(K, 1, figsize=(15, 15 * K), squeeze=False)
    for ax, colored in zip(axes[:, 0], colored_layers):
        # Fill masked pixels with 1.0 (white) explicitly.  The default float
        # fill value (1e20) is only clipped to white by imshow, with a warning.
        ax.imshow(np.ma.filled(colored, 1.0))
        ax.imshow(image, alpha=0.6)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.tight_layout()
    plt.savefig(fig_outfile, bbox_inches='tight')
class CentroidCluster:
    """Plotting helpers for fitted estimators exposing ``cluster_centers_``
    and ``predict`` (e.g. sklearn KMeans)."""

    @staticmethod
    def jointcentroid(X, clf, dim_labels, fig_outfile, alpha=0.33):
        """ Plots 2D scatter visualization of clustering results.

        Each feature pair gets a seaborn JointGrid with distplot marginals.
        Points are colored by ``clf.predict(X)`` and cluster centers are
        drawn as black dots.  With 2 features one grid is drawn.  With more,
        one grid per pair is saved to a temporary PNG and the PNGs are
        stacked vertically into a new figure.  In both cases the resulting
        current figure is saved to ``fig_outfile``.  Fewer than 2 features
        plots and saves nothing.

        Parameters
        ----------
        X: array-like, shape: (n_samples, n_features)
            data array
        clf: sklearn estimator
            Must be an estimator which has attribute "cluster_centers_" (i.e KMeans)
        dim_labels: a list of string elements, each denoting the (0,1,2...d)
            feature dimension labels.
        fig_outfile: a file or filestring to save the output to
        alpha: float, optional
            the transparency of the scatter points

        Raises
        ------
        AssertionError
            if the number of columns of X differs from len(dim_labels)
        """
        _, n_features = X.shape
        assert n_features == len(dim_labels), "Feature and Label Dimensions don't match"
        df_X = pd.DataFrame(X, columns=dim_labels)
        size = 15
        # Cluster assignments do not depend on the projection; predict once.
        Y_ = clf.predict(X)
        # NOTE(review): positional JointGrid args and ``size`` (renamed
        # ``height`` in seaborn 0.9) match the seaborn version this file
        # targets.
        if n_features == 2:
            g = sns.JointGrid(dim_labels[0], dim_labels[1], df_X, size=size)
            g.plot_marginals(sns.distplot)
            g.ax_joint = colored_scatter(X, Y_, 0, 1, g.ax_joint, alpha=alpha)
            g.ax_joint = CentroidCluster._colored_centers(clf, 0, 1, g.ax_joint)
        elif n_features > 2:
            dim_combo = list(combinations(range(n_features), 2))
            n_ax = len(dim_combo)
            # Use a private temporary directory instead of a hard-coded
            # '/Temp' path, which does not exist on most systems.
            with tempfile.TemporaryDirectory() as tmp_dir:
                paths = []
                for i, (x, y) in enumerate(dim_combo):
                    g = sns.JointGrid(dim_labels[x], dim_labels[y], df_X, size=size)
                    g.plot_marginals(sns.distplot)
                    g.ax_joint = colored_scatter(X, Y_, x, y, g.ax_joint,
                                                 alpha=alpha)
                    g.ax_joint = CentroidCluster._colored_centers(clf, x, y,
                                                                  g.ax_joint)
                    path = os.path.join(tmp_dir, 'f%d.png' % i)
                    plt.savefig(path, bbox_inches='tight')
                    plt.close()
                    paths.append(path)
                # Combine the pair plots into one tall figure.
                fig, ax = plt.subplots(n_ax, 1, figsize=(size, size * n_ax))
                for axis, path in zip(ax, paths):
                    axis.imshow(plt.imread(path))
                    axis.axis('off')
        else:
            return
        plt.tight_layout()
        plt.savefig(fig_outfile, bbox_inches='tight')  # dpi = 200

    @staticmethod
    def _colored_centers(clf, x, y, ax, palette=cblind_palette_norm, alpha=0.7):
        """
        Helper function that draws the cluster centers defined by the
        cluster_centers_ attribute of a centroid clustering estimator.

        Parameters
        ----------
        clf: an sklearn estimator
            a fitted estimator with a ``cluster_centers_`` attribute of
            shape (n_clusters, n_features), e.g. KMeans
        x: integer
            denoting the index of the feature we treat as the x-axis
        y: integer
            denoting the index of the feature we treat as the y-axis
        ax: matplotlib axes
            axes to draw the centers on
        palette: array, optional
            currently unused: the centers are always drawn black.  Kept so
            the signature mirrors the other palette-aware helpers.
        alpha: float, range (0-1)
            alpha transparency level for plotting

        Returns
        -------
        ax: matplotlib axes
            the axes which now contain the cluster centers
        """
        centers = np.asarray(clf.cluster_centers_)
        # 'o' format draws markers only, one per center, with no joining line.
        ax.plot(centers[:, x], centers[:, y], 'o', markerfacecolor='k',
                markeredgecolor='k', markersize=10, alpha=alpha, lw=2)
        return ax
class GaussianMM:
    """Plotting helpers for fitted Gaussian mixture models (legacy
    ``sklearn.mixture.GMM`` or ``sklearn.mixture.GaussianMixture``)."""

    @staticmethod
    def proba_dist_1D(vals, probas):
        """
        Plots the probability distribution to visualize the shapes
        of the Gaussians that determine membership in each of the
        classes/clusters.

        One curve per component is drawn against ``vals`` sorted in
        ascending order, with a legend in the upper-left corner.

        Parameters
        ----------
        vals: 1-D array, length: num_samples
            the values for each data point in the single feature dimension
        probas: array-like, shape: (n_samples, n_components)
            the probability that the sample data point belongs to
            each component, or so-called Gaussian

        Returns
        -------
        ax: matplotlib axes
            the axes holding the component curves
        """
        probas = np.asarray(probas)
        vals = np.asarray(vals)
        n_samples, n_components = probas.shape
        # Sort once by feature value so each curve is drawn left to right.
        # A stable sort keeps tied values in their input order.
        order = np.argsort(vals, kind='stable')
        sorted_vals = vals[order]
        fig, ax = plt.subplots(figsize=(8, 6))
        # plt.hold was removed in matplotlib 3.0.  Successive plot calls
        # already overlay on the same axes.
        for i in range(n_components):
            ax.plot(sorted_vals, probas[order, i], label='Component {0}'.format(i))
        ax.legend(loc=2)
        return ax

    @staticmethod
    def jointellipse(X, clf, dim_labels, fig_outfile, alpha):
        """ Plots 2D visualization of GMM clustering results.

        Each feature pair gets a seaborn JointGrid with distplot marginals.
        Points are colored by ``clf.predict(X)`` and one covariance ellipse
        is drawn per mixture component.  With 2 features one grid is drawn.
        With more, one grid per pair is saved to a temporary PNG and the PNGs
        are stacked vertically into a new figure.  In both cases the resulting
        current figure is saved to ``fig_outfile``.  Fewer than 2 features
        plots and saves nothing.

        Parameters
        ----------
        X: a numpy array of shape (number of subsamples, number of features),
            the subsampled Feature Vector.
        clf: sklearn estimator
            a mixture model estimator which has already trained on the data
        dim_labels: a list of strings,
            each denoting the (0,1,2...d) feature dimension labels.
        fig_outfile: file object or string
            the file or filestring to save the image to
        alpha: float
            the transparency of the scatter points

        Raises
        ------
        AssertionError
            if the number of columns of X differs from len(dim_labels)
        """
        _, n_features = X.shape
        assert n_features == len(dim_labels), "Feature and Label Dimensions don't match"
        df_X = pd.DataFrame(X, columns=dim_labels)
        size = 15
        # Component assignments do not depend on the projection; predict once.
        Y_ = clf.predict(X)
        # NOTE(review): positional JointGrid args and ``size`` (renamed
        # ``height`` in seaborn 0.9) match the seaborn version this file
        # targets.
        if n_features == 2:
            g = sns.JointGrid(dim_labels[0], dim_labels[1], df_X, size=size)
            g.plot_marginals(sns.distplot)
            g.ax_joint = colored_scatter(X, Y_, 0, 1, g.ax_joint, alpha=alpha)
            g.ax_joint = GaussianMM._draw_ellipses(clf, 0, 1, g.ax_joint)
        elif n_features > 2:
            dim_combo = list(combinations(range(n_features), 2))
            n_ax = len(dim_combo)
            # Use a private temporary directory instead of a hard-coded
            # '/Temp' path, which does not exist on most systems.
            with tempfile.TemporaryDirectory() as tmp_dir:
                paths = []
                for i, (x, y) in enumerate(dim_combo):
                    g = sns.JointGrid(dim_labels[x], dim_labels[y], df_X, size=size)
                    g.plot_marginals(sns.distplot)
                    g.ax_joint = colored_scatter(X, Y_, x, y, g.ax_joint,
                                                 alpha=alpha)
                    g.ax_joint = GaussianMM._draw_ellipses(clf, x, y, g.ax_joint)
                    path = os.path.join(tmp_dir, 'f%d.png' % i)
                    plt.savefig(path, bbox_inches='tight')
                    plt.close()
                    paths.append(path)
                # Combine the pair plots into one tall figure.
                fig, ax = plt.subplots(n_ax, 1, figsize=(size, size * n_ax))
                for axis, path in zip(ax, paths):
                    axis.imshow(plt.imread(path))
                    axis.axis('off')
        else:
            return
        plt.tight_layout()
        plt.savefig(fig_outfile, bbox_inches='tight')  # dpi = 200

    @staticmethod
    def _full_covariances(clf):
        """Return one full (n_features, n_features) covariance per component.

        Uses the legacy ``GMM._get_covars()`` when the estimator has it.
        Otherwise it expands ``GaussianMixture.covariances_`` according to
        ``covariance_type`` ('full', 'tied', 'diag' or 'spherical').

        Raises
        ------
        ValueError
            for an unrecognised ``covariance_type``
        """
        if hasattr(clf, '_get_covars'):
            return [np.asarray(c) for c in clf._get_covars()]
        n_components, n_features = np.shape(clf.means_)
        covs = np.asarray(clf.covariances_)
        kind = getattr(clf, 'covariance_type', 'full')
        if kind == 'full':
            return list(covs)
        if kind == 'tied':
            # One shared (n_features, n_features) matrix.
            return [covs] * n_components
        if kind == 'diag':
            # One variance vector per component.
            return [np.diag(c) for c in covs]
        if kind == 'spherical':
            # One scalar variance per component.
            return [c * np.eye(n_features) for c in covs]
        raise ValueError('Unsupported covariance_type: {0!r}'.format(kind))

    @staticmethod
    def _draw_ellipses(clf, x, y, ax, palette=cblind_palette_norm, alpha=0.8,
                       n_std=0.5):
        """
        Helper function that draws one ellipse per mixture component from
        the eigen-decomposition of its covariance projected onto (x, y).

        Parameters
        ----------
        clf: an sklearn estimator
            a fitted mixture model with ``means_`` and either the legacy
            ``_get_covars()`` or ``covariances_``/``covariance_type``
        x: integer
            denoting the index of the feature we treat as the x-axis
        y: integer
            denoting the index of the feature we treat as the y-axis
        ax: matplotlib axes
            axes to plot ellipses to
        palette: array, optional
            each element is a string, rgb tuple (range 0-1), or hexadecimal
            denoting a specific color.  Component i is filled with
            palette[i % len(palette)].  Highly recommended to leave this
            argument as default, so that the same color palette is used for
            the scatter plotting and color-blind-friendly colors are used.
        alpha: float, range (0-1)
            alpha transparency level for plotting
        n_std: float, optional
            semi-axis length in standard deviations along each principal
            axis.  The default of 0.5 (full width = 1 standard deviation)
            reproduces this helper's historical ellipse size.

        Returns
        -------
        ax: matplotlib axes
            the axes which now contain the ellipses
        """
        covariances = GaussianMM._full_covariances(clf)
        for mean, covar, color in zip(clf.means_, covariances, cycle(palette)):
            # 2x2 covariance of this (x, y) projection.
            Cxy = np.asarray(covar)[np.ix_([x, y], [x, y])]
            # eigh returns ascending eigenvalues w and eigenvectors as the
            # COLUMNS of v, so v[:, 0] is the direction of the w[0] axis.
            w, v = np.linalg.eigh(Cxy)
            u = v[:, 0]
            # arctan2 handles u[0] == 0 (axis-aligned covariance) without
            # dividing by zero.
            angle = np.degrees(np.arctan2(u[1], u[0]))
            # Standard deviations along the principal axes.  Clip tiny
            # negative eigenvalues caused by round-off.
            std = np.sqrt(np.clip(w, 0, None))
            # Ellipse takes full width/height (diameters), hence the factor 2.
            ell = mpl.patches.Ellipse((mean[x], mean[y]),
                                      2 * n_std * std[0], 2 * n_std * std[1],
                                      angle=angle)
            ell.set_clip_box(ax.bbox)
            ell.set_facecolor(color)
            ell.set_edgecolor('k')
            ell.set_lw(1)
            ell.set_alpha(alpha)
            ax.add_artist(ell)
        return ax
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment