Skip to content

Instantly share code, notes, and snippets.

@perimosocordiae
Created August 4, 2014 18:22
Show Gist options
  • Save perimosocordiae/9e63c304286a18a33d7d to your computer and use it in GitHub Desktop.
Save perimosocordiae/9e63c304286a18a33d7d to your computer and use it in GitHub Desktop.
MST graph construction
import numpy as np
from scipy.sparse.csgraph import minimum_spanning_tree
from sklearn.metrics import pairwise_distances
def perturbed_mst(X, num_perturbations=20, jitter=None):
    '''Builds a graph as the union of several MSTs on perturbed data.

    Reference: http://ecovision.mit.edu/~sloop/shao.pdf, page 8

    One minimum spanning tree (MST) is built on the Euclidean distances of
    X itself, and num_perturbations more are built on copies of X with
    Gaussian noise added. The noise is drawn from the global np.random
    state, so seed it for reproducible results. Every tree is binarized and
    the trees are averaged, so each edge weight is the fraction of the
    num_perturbations + 1 trees that contain that edge.

    Parameters
    ----------
    X : array of shape (n_samples, n_features)
        Input points.
    num_perturbations : int
        Number of noisy MSTs added to the MST of the unperturbed data.
    jitter : float or None
        Scale (standard deviation) of the Gaussian noise added for each
        perturbation. When None, it defaults to the 5th percentile of the
        positive interpoint distances, or 0.0 if there are none.

    Returns
    -------
    W : scipy sparse CSR matrix of shape (n_samples, n_samples)
        Averaged edge weights. Each undirected edge is stored once, at
        (i, j) with i < j.
    '''
    from scipy.sparse import triu

    def _canonical_binary_tree(tree):
        # scipy does not fix whether a tree edge is stored at (i, j) or
        # (j, i), and sklearn's l2 distances may not be exactly symmetric.
        # So the same undirected edge could land in different cells across
        # trees and be split when averaging. Mirror the tree, keep only the
        # upper triangle, then set every edge weight to 1.
        tree = triu(tree + tree.T, format='csr')
        tree.data[:] = 1.0
        return tree

    D = pairwise_distances(X, metric='l2')
    if jitter is None:
        positive = D[D > 0]
        # np.percentile fails on an empty array (a single point or all
        # duplicates). There is no distance scale then, so add no noise.
        jitter = np.percentile(positive, 5) if positive.size else 0.0
    W = _canonical_binary_tree(minimum_spanning_tree(D))
    for _ in range(num_perturbations):
        pX = X + np.random.normal(scale=jitter, size=X.shape)
        pD = pairwise_distances(pX, metric='l2')
        W = W + _canonical_binary_tree(minimum_spanning_tree(pD))
    # Final graph: the average over the original MST and all perturbed MSTs.
    W.data /= (num_perturbations + 1.0)
    return W
def disjoint_mst(X, num_spanning_trees=3):
    '''Builds a graph as the union of several spanning trees,
    each time removing any edges present in previously-built trees.

    Reference: http://ecovision.mit.edu/~sloop/shao.pdf, page 9.

    The first tree is the minimum spanning tree of the Euclidean distances
    of X. Each later tree is the minimum spanning tree of the same distances
    with every edge used by an earlier tree removed, so the trees are
    edge-disjoint. Once the remaining edges no longer connect all points,
    scipy returns a spanning forest instead. Once no edges remain at all,
    building more trees is skipped.

    Parameters
    ----------
    X : array of shape (n_samples, n_features)
        Input points.
    num_spanning_trees : int
        Maximum number of edge-disjoint trees to combine.

    Returns
    -------
    W : scipy sparse matrix of shape (n_samples, n_samples)
        Sum of the trees. Each edge keeps its Euclidean length as its weight.
    '''
    D = pairwise_distances(X, metric='l2')
    mst = minimum_spanning_tree(D)
    W = mst.copy()
    for _ in range(1, num_spanning_trees):
        ii, jj = mst.nonzero()
        # Remove the last tree's edges in both orientations. scipy's csgraph
        # treats infinite entries of a dense graph as missing edges. D is
        # mutated in place, so removals accumulate across iterations.
        D[ii, jj] = np.inf
        D[jj, ii] = np.inf
        mst = minimum_spanning_tree(D)
        if mst.nnz == 0:
            # No edges left: D stays unchanged, so every later tree would
            # also be empty and adding it would not change W.
            break
        W = W + mst
    return W
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment