Last active
August 29, 2015 14:22
-
-
Save darthsuogles/b4a6f14f92e06c3d9075 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Recursively reload dependencies | |
| # Ref: https://ipython.org/ipython-doc/dev/interactive/reference.html#dreload | |
| from IPython.lib.deepreload import reload as dreload | |
| # %load_ext autoreload | |
| # %autoreload 2 | |
import json # export data for d3 visualization
from collections import Counter

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import scipy.io as sio
from matplotlib import gridspec
from matplotlib.ticker import NullFormatter
# NOTE(review): itemfreq was removed from recent SciPy releases; kept only
# because existing code may still reference it.
from scipy.stats import itemfreq
from sklearn import manifold
#from sklearn.cluster import AgglomerativeClustering
from sklearn.cluster import spectral_clustering
# Alias that survives the local `spectral_clustering` definition below,
# which rebinds the unaliased name.
from sklearn.cluster import spectral_clustering as sklearn_spectral_clustering

from mlearn_hclust import hierarchical_clustering
| # Compute K-nearest neighbor based on the distance matrix | |
def classify_Wasserstein_KNN(wasserstein, kls, K = 2):
    r"""
    Leave-one-out K-nearest-neighbour classification on a distance matrix.

    Every sample is assigned the majority label among its K nearest *other*
    samples. Vote ties go to the label whose first supporter is nearest.
    The success count, the confusion matrix and a per-class breakdown are
    printed.

    \param wasserstein: an NxN matrix of pairwise distances
    \param kls: a sequence of N items, each the class label of one sample
    \param K: number of neighbours that vote (at least 1)
    \return the confusion matrix as an integer array; rows are true classes,
            columns are predicted classes, both ordered as np.unique(kls)
    \raise ValueError: if the matrix is not square, if kls does not have one
            label per row, or if K < 1
    """
    wasserstein = np.asarray(wasserstein)
    kls = np.asarray(kls)
    if wasserstein.ndim != 2 or wasserstein.shape[0] != wasserstein.shape[1]:
        raise ValueError("wasserstein must be a square NxN matrix")
    N = wasserstein.shape[0]
    if len(kls) != N:
        raise ValueError("kls must contain one label per row of wasserstein")
    if K < 1:
        raise ValueError("K must be at least 1")

    # Work with plain Python labels so dict lookups and printing are uniform.
    labels = kls.tolist()
    kls_uniq = np.unique(kls)
    kls_inds = {label: idx for idx, label in enumerate(kls_uniq.tolist())}
    num_kls = len(kls_inds)
    confusion = np.zeros((num_kls, num_kls), dtype=int)

    cnt_tot = 0
    cnt_correct = 0
    for i in range(N):
        # Stable sort keeps index order among equal distances.
        order = np.argsort(wasserstein[i, :], kind='stable')
        # Drop the sample itself explicitly: with duplicate points it is not
        # guaranteed to be the first entry of the sorted order.
        neighbors = order[order != i][:K]
        if neighbors.size == 0:
            continue  # a lone sample has nobody to vote for it

        # Counter keeps insertion order and max() returns the first maximal
        # key, so ties are resolved in favour of the nearest label.
        votes = Counter(labels[j] for j in neighbors.tolist())
        predict_kls = max(votes, key=votes.get)

        cnt_tot += 1
        if predict_kls == labels[i]:
            cnt_correct += 1
        confusion[kls_inds[labels[i]], kls_inds[predict_kls]] += 1

    print('success', cnt_correct, 'total', cnt_tot)
    print('confusion matrix')
    print(confusion)
    for i, label in enumerate(kls_uniq.tolist()):
        # Every predicted class that received at least one sample of `label`.
        hit_cols = confusion[i] > 0
        print(label, '\t->',
              list(zip(kls_uniq[hit_cols].tolist(), confusion[i][hit_cols].tolist())))
    return confusion
| ## Perform spectral clustering | |
| # http://scikit-learn.org/stable/modules/generated/sklearn.cluster.spectral_clustering.html | |
def spectral_clustering(affinity, num_clusters = 5):
    r"""
    Cluster samples spectrally, draw the result and export it for D3.js.

    Samples that fall in the same cluster are linked in a graph whose edge
    weights are the affinities. Because only same-cluster pairs are linked,
    the minimum spanning "tree" of that graph is one tree per cluster. It is
    drawn with nodes coloured by class label and written to
    "kimia_spectral_clustering.json".

    NOTE: the class label of each sample is read from the module-level
    `kls`, which must hold one label per row of `affinity`.
    NOTE: this definition rebinds the module-level name
    `spectral_clustering`, so scikit-learn's routine is called through the
    alias `sklearn_spectral_clustering`.

    \param affinity: an NxN affinity matrix accepted by
                     sklearn.cluster.spectral_clustering
    \param num_clusters: number of clusters to extract
    """
    #affinity = exp(-Dst / np.median(Dst) )
    num_samples = affinity.shape[0]
    clustering_labels = sklearn_spectral_clustering(affinity, n_clusters=num_clusters)

    # Store the result in a networkx structure. Plain Python ints are used
    # as node ids so they can be serialised to JSON later on.
    G = nx.Graph()
    G.add_nodes_from(range(num_samples))
    for u in range(num_samples):
        for v in range(u + 1, num_samples):
            if clustering_labels[u] == clustering_labels[v]:
                G.add_edge(u, v, weight=float(affinity[u, v]))

    # Plot the minimum spanning tree
    T = nx.minimum_spanning_tree(G)
    kls_idx = {label: idx for idx, label in enumerate(np.unique(kls).tolist())}
    node_colors = [kls_idx[label] for label in np.asarray(kls).tolist()]
    # graphviz_layout lives in the nx_agraph submodule of current networkx.
    embedding_coords = nx.nx_agraph.graphviz_layout(T)
    nx.draw_networkx_nodes(T, embedding_coords, node_size=3,
                           node_color=node_colors)
    nx.draw_networkx_edges(T, embedding_coords, style='solid', alpha=0.2)
    plt.axis('off')

    # Export the tree to json for D3js
    json_data = {
        "nodes": [
            {"name": u, "group": int(clustering_labels[u]),
             "class_id": u_grp, "class_name": kls[u]}
            for u, u_grp in enumerate(node_colors)],
        "links": [
            {"source": int(u), "target": int(v), "value": float(affinity[u, v])}
            for u, v in T.edges()],
    }
    with open("kimia_spectral_clustering.json", "w") as fout:
        json.dump(json_data, fout)
| # Scale and visualize the embedding vectors | |
def plot_embedding(X, kls, title=None):
    r"""
    Draw a 2-D embedding in a new figure, rendering each point as its label.

    The coordinates are rescaled column-wise to [0, 1] before plotting and
    every class gets its own colour from the Set1 colour map.

    \param X: array with one row per sample; columns 0 and 1 are plotted
    \param kls: sequence of class labels, one per row of X
    \param title: optional figure title
    """
    X = np.asarray(X, dtype=float)
    x_min, x_max = np.min(X, 0), np.max(X, 0)
    span = x_max - x_min
    # A constant column would give 0/0; leave such a column at 0 instead.
    span = np.where(span == 0, 1.0, span)
    X = (X - x_min) / span

    # Colour index per class, derived from the labels actually passed in.
    kls_unique = np.unique(kls)
    num_kls = len(kls_unique)
    kls_idx = {label: float(idx) for idx, label in enumerate(kls_unique.tolist())}

    plt.figure()
    plt.subplot(111)
    for i in range(X.shape[0]):
        plt.text(X[i, 0], X[i, 1], kls[i],
                 color=plt.cm.Set1(kls_idx[kls[i]] / num_kls),
                 fontdict={'weight': 'bold', 'size': 7})
    plt.xticks([])
    plt.yticks([])
    if title is not None:
        plt.title(title)
| ## Compare a bunch of dimension reduction algorithms | |
def dimension_reduction(Dst, kls_inds, num_neighbors = 10):
    r"""
    Embed the data in 2-D with eight manifold-learning methods and plot them.

    The figure has two rows of four panels: the four LocallyLinearEmbedding
    variants on top, then Isomap, MDS, SpectralEmbedding and t-SNE.

    NOTE(review): only t-SNE is told that Dst is a precomputed distance
    matrix; every other estimator is fitted on the rows of Dst as if they
    were feature vectors. Confirm that this is intended.

    \param Dst: square matrix of pairwise distances (t-SNE requires this)
    \param kls_inds: one numeric class index per sample, used for colouring
    \param num_neighbors: neighbourhood size for the neighbour-based methods
    """
    fig, ax_list = plt.subplots(2, 4, sharex='none', sharey='none', figsize=(15, 8))
    plt.suptitle("Dimension Reduction and Euclidean Embedding")

    ## First row: the LLE family
    lle_variants = [('standard', 'LLE'), ('ltsa', 'LTSA'),
                    ('hessian', 'Hessian LLE'), ('modified', 'Modified LLE')]
    for col, (method, label) in enumerate(lle_variants):
        lle = manifold.LocallyLinearEmbedding(n_neighbors=num_neighbors, n_components=2,
                                              eigen_solver='auto', method=method)
        _scatter_embedding(ax_list[0, col], lle.fit_transform(Dst), kls_inds, label)

    ## Second row
    others = [
        ('Isomap', manifold.Isomap(n_neighbors=num_neighbors, n_components=2)),
        ('MDS', manifold.MDS(n_components=2, max_iter=100, n_init=1)),
        ('Spectral', manifold.SpectralEmbedding(n_components=2, n_neighbors=num_neighbors)),
        # A PCA initialisation cannot be computed from a distance matrix, so
        # the random initialisation is requested explicitly.
        ('t-SNE', manifold.TSNE(n_components=2, metric='precomputed',
                                init='random', perplexity=5)),
    ]
    for col, (label, algo) in enumerate(others):
        _scatter_embedding(ax_list[1, col], algo.fit_transform(Dst), kls_inds, label)


def _scatter_embedding(ax, Y, kls_inds, title):
    """Scatter the first two columns of Y on ax, coloured by class, without tick labels."""
    ax.scatter(Y[:, 0], Y[:, 1], c=kls_inds, cmap=plt.cm.Spectral)
    ax.set_title(title)
    ax.xaxis.set_major_formatter(NullFormatter())
    ax.yaxis.set_major_formatter(NullFormatter())
    ax.axis('tight')
def show_dataset():
    """
    Plot one example per class and save the figure as a PDF.

    Each row of the figure shows, for the first sample of a class, its
    contour, its torque sequence and a 30-bin histogram of that sequence.
    The figure is written to 'kls_contour_torseq_distr.pdf'.

    Reads the module-level `N`, `kls`, `contours` and `torseq`.
    """
    ## Pick the first sample of every class
    labels = np.unique(kls)
    sample_ids = np.arange(N)
    inds = np.array([sample_ids[label == kls][0] for label in labels], dtype=int)
    num_rows = len(inds)

    plt.close('all')
    fig = plt.figure()
    gs = gridspec.GridSpec(num_rows, 5, wspace=0.4, hspace=0.2)

    # The first axes created in each of the two right-hand columns; later
    # rows share their x axis with it.
    seq_ref = None
    hist_ref = None
    for row, idx in enumerate(inds):
        hide_x = row + 1 < num_rows  # only the bottom row keeps its x axis

        # Column 0: the contour, labelled with the class name
        outline = contours[idx]
        contour_ax = plt.subplot(gs[row, 0])
        contour_ax.plot(outline[:, 0], outline[:, 1])
        contour_ax.set_xticks([])
        contour_ax.set_yticks([])
        contour_ax.set_ylabel(labels[row],
                              rotation='horizontal', horizontalalignment='right')

        # Columns 1-2: the torque sequence over a normalised [0, 1] axis
        if seq_ref is None:
            seq_ax = seq_ref = plt.subplot(gs[row, 1:3])
        else:
            seq_ax = plt.subplot(gs[row, 1:3], sharex=seq_ref)
        tsq = torseq[idx].ravel()
        seq_ax.plot(np.linspace(0, 1, len(tsq)), tsq)
        seq_ax.get_yaxis().set_visible(False)
        if hide_x:
            seq_ax.get_xaxis().set_visible(False)

        # Columns 3-4: the torque histogram with a red marker at zero
        if hist_ref is None:
            hist_ax = hist_ref = plt.subplot(gs[row, 3:])
        else:
            hist_ax = plt.subplot(gs[row, 3:], sharex=hist_ref)
        hist_ax.hist(tsq, bins=30)
        hist_ax.axvline(x=0, linewidth=2, color='r')
        hist_ax.get_yaxis().set_visible(False)
        if hide_x:
            hist_ax.get_xaxis().set_visible(False)

    fig.savefig('kls_contour_torseq_distr.pdf')
## Perform classification with KNN
#classify_Wasserstein_KNN(wasserstein, kls, K = 2)

## Load the data
# NOTE(review): ShapeDB_Kimia is neither imported nor defined in the visible
# part of this file; it has to be in scope before this line runs.
spdb_kimia = ShapeDB_Kimia()
(contours, torseq, kls,
 wasserstein, riemannian,
 N, M) = spdb_kimia.load_data()

## Numeric class index of every sample, following the order of np.unique
kls_unique = np.unique(kls)
map_kls2idx = dict(zip(kls_unique, np.arange(len(kls_unique))))
kls_inds = np.asarray([map_kls2idx[label] for label in kls])

#dimension_reduction(riemannian, kls_inds, 10)
hierarchical_clustering(riemannian, kls, title='riemannian')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment