Created
May 3, 2011 17:56
-
-
Save davidandrzej/953839 to your computer and use it in GitHub Desktop.
Visualize matrix block structure with spectral clustering
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Visualization of matrix block structure | |
(eg, pairwise similarity or co-occurrence) | |
Requires scaledimage.py for intensity plots | |
David Andrzejewski | |
""" | |
import numpy as NP | |
import matplotlib.pyplot as P | |
import matplotlib.lines as L | |
import scikits.learn.cluster as SKLC | |
from scaledimage import scaledimage | |
def blockviz(affinity, nclust, ax=None, pixwidth=3):
    """
    Visualize block-structure of affinity matrix

    Runs spectral clustering on the affinity matrix and permutes its rows
    and columns so that the members of each recovered cluster are
    contiguous. It then draws the permuted matrix as a grayscale intensity
    plot and outlines each cluster's diagonal block.

    affinity = NxN non-negative affinity matrix
    nclust = number of clusters to use
    ax = matplotlib Axes to draw on (a new figure is created if None)
    pixwidth = plot units per matrix entry passed to scaledimage and
               used to scale the cluster boxes (default 3)

    Returns the fitted cluster labels (c.labels_), one per row of affinity.
    Rely on caller to .show()
    """
    # Create a fresh figure/Axes when the caller did not supply one
    if ax is None:
        ax = P.figure().gca()
    # Do spectral clustering
    ndata = affinity.shape[0]
    c = SKLC.SpectralClustering(k=nclust, mode='amg')
    c.fit(affinity)
    # Group row indices by cluster label once; the same grouping drives
    # both the permutation and the boundary boxes below.
    # NOTE(review): rows whose label falls outside range(nclust) would be
    # dropped; presumably the clusterer only emits 0..nclust-1 -- confirm.
    clusters = [getlabeled(c.labels_, ki) for ki in range(nclust)]
    sortidx = [idx for clust in clusters for idx in clust]
    # Fancy (list) indexing already returns a copy; the input is untouched
    sorted_affinity = affinity[sortidx, :][:, sortidx]
    # Intensity plot of affinity
    scaledimage(sorted_affinity,
                pixwidth=pixwidth, grayscale=True, ax=ax)
    # Draw recovered cluster boundaries: each cluster occupies the
    # contiguous index range [kstart, kend) of the permuted matrix
    kstart = 0
    for clust in clusters:
        kend = kstart + len(clust)
        drawClust(ax, kstart, kend, ndata, scale=pixwidth)
        kstart = kend
    return c.labels_
def logistic(val):
    """
    Logistic function 1 / (1 + exp(-val)), computed without overflow

    val = scalar or array-like of real numbers
    Returns a NumPy scalar for scalar input, otherwise an array of the
    same shape as val.
    """
    val = NP.asarray(val)
    # exp(-|val|) lies in (0, 1], so it can never overflow, unlike the
    # naive exp(-val) for large negative val
    expneg = NP.exp(-NP.abs(val))
    # For val >= 0 use 1 / (1 + e^-val); for val < 0 use the algebraically
    # equal e^val / (1 + e^val), which stays accurate in the lower tail
    result = NP.where(val >= 0,
                      1.0 / (1.0 + expneg),
                      expneg / (1.0 + expneg))
    # Unwrap 0-d results so scalar input yields a scalar, as before
    return result[()]
def getlabeled(labels, ki):
    """
    Return, in increasing order, every position in labels whose entry
    compares equal to ki
    """
    matches = []
    for position, label in enumerate(labels):
        if label == ki:
            matches.append(position)
    return matches
def drawH(ax, y, xstart, xend, color='r'):
    """
    Draw horizontal line

    Adds a Line2D segment at height y, from x = xstart to x = xend, to ax.
    color = matplotlib color spec for the line (red by default)
    """
    ax.add_line(L.Line2D([xstart, xend],
                         [y, y],
                         color=color))
def drawV(ax, x, ystart, yend, color='r'):
    """
    Draw vertical line

    Adds a Line2D segment at horizontal position x, from y = ystart to
    y = yend, to ax.
    color = matplotlib color spec for the line (red by default)
    """
    ax.add_line(L.Line2D([x, x],
                         [ystart, yend],
                         color=color))
def drawClust(ax, kstart, kend, kmax, scale=1):
    """
    Draw bounding box for cluster

    The cluster occupies indices [kstart, kend) of a kmax x kmax matrix
    plotted at `scale` units per entry. Row indices are flipped
    (y = scale*kmax - scale*row), so index 0 maps to the largest y value.
    Edges that coincide with the border of the plot are skipped, so a
    cluster spanning the whole matrix draws nothing. Empty clusters
    (kend <= kstart) are skipped as well.

    ax = matplotlib Axes to draw on
    kstart = first index of the cluster
    kend = one past the last index of the cluster
    kmax = total number of rows/columns in the matrix
    scale = plot units per matrix entry
    """
    if kend <= kstart:
        # Empty cluster: nothing to outline
        return
    skstart = scale * kstart
    skend = scale * kend
    skmax = scale * kmax
    # y coordinates of the box's two horizontal edges (flipped axis)
    ylow = skmax - skend
    yhigh = skmax - skstart
    if skend != skmax:
        # Bottom and right edges are interior to the plot
        drawH(ax, ylow, skstart, skend)
        drawV(ax, skend, yhigh, ylow)
    if skstart != 0:
        # Top and left edges are interior to the plot
        drawH(ax, yhigh, skstart, skend)
        drawV(ax, skstart, yhigh, ylow)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment