Created
April 24, 2011 20:11
-
-
Save davidandrzej/939840 to your computer and use it in GitHub Desktop.
3-simplex triangular scatter plot
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Visualize points on the 3-simplex (e.g., the parameters of a
3-dimensional multinomial distribution) as a scatter plot
contained within a 2D triangle.
David Andrzejewski ([email protected])
""" | |
import numpy as NP | |
import matplotlib.pyplot as P | |
import matplotlib.ticker as MT | |
import matplotlib.lines as L | |
import matplotlib.cm as CM | |
import matplotlib.colors as C | |
import matplotlib.patches as PA | |
def plotSimplex(points, fig=None,
                vertexlabels=('1', '2', '3'),
                **kwargs):
    """
    Plot an Nx3 array of points on the 3-simplex, drawn as a 2D triangle.

    Vertex 1 is drawn at the lower left, vertex 2 at the lower right and
    vertex 3 at the top of the triangle.

    Parameters
    ----------
    points : Nx3 array(-like) of simplex points; the projection assumes
        each row sums to 1.
    fig : matplotlib Figure to draw on; a new pyplot figure is created
        when None.
    vertexlabels : sequence of three labels for the triangle vertices.
    **kwargs : passed along directly to ``Axes.scatter`` (e.g. s, c,
        cmap, norm).

    Returns
    -------
    The Figure that was drawn on; the caller must ``.show()`` it.
    """
    if fig is None:
        fig = P.figure()
    # Draw everything on this figure's axes.  pyplot.scatter would target
    # whichever figure is *current*, which need not be ``fig`` when the
    # caller passes one in explicitly.
    ax = fig.gca()

    # Draw the triangle outline (closed path through the three vertices).
    top = NP.sqrt(3) / 2
    ax.add_line(L.Line2D([0, 0.5, 1.0, 0],  # xcoords
                         [0, top, 0, 0],    # ycoords
                         color='k'))
    # Cartesian tick marks are meaningless for simplex coordinates.
    ax.xaxis.set_major_locator(MT.NullLocator())
    ax.yaxis.set_major_locator(MT.NullLocator())

    # Draw vertex labels just outside each corner.
    ax.text(-0.05, -0.05, vertexlabels[0])
    ax.text(1.05, -0.05, vertexlabels[1])
    ax.text(0.5, top + 0.05, vertexlabels[2])

    # Project and draw the actual points.
    projected = projectSimplex(points)
    collection = ax.scatter(projected[:, 0], projected[:, 1], **kwargs)
    # pyplot.scatter also registers its result as pyplot's "current image"
    # so a later pyplot.colorbar() works; keep that behavior whenever
    # ``fig`` is pyplot's current figure (get_fignums() guards gcf() from
    # creating a new figure when pyplot manages none).
    if P.get_fignums() and P.gcf() is fig:
        P.sci(collection)

    # Leave some buffer around the triangle for vertex labels.
    ax.set_xlim(-0.2, 1.2)
    ax.set_ylim(-0.2, 1.2)
    return fig
def projectSimplex(points):
    """
    Project probabilities on the 3-simplex to a 2D triangle.

    N points are given as an N x 3 array (any array-like, such as a nested
    list, is accepted).  Component 1 pulls a point toward the lower-left
    vertex (0, 0), component 2 toward the lower-right vertex (1, 0) and
    component 3 toward the top vertex (0.5, sqrt(3)/2).  The mapping onto
    the triangle assumes each row sums to 1.

    Returns an N x 2 float array of (x, y) coordinates.
    """
    # Convert once to float64 so every input dtype yields float64 output.
    pts = NP.asarray(points, dtype=float)
    p1, p2, p3 = pts[:, 0], pts[:, 1], pts[:, 2]

    # Every point starts at the triangle centroid and is pushed along the
    # three vertex bisectors, scaled by 1/sqrt(3) (the centroid-to-vertex
    # distance of a unit-side equilateral triangle).  All N points are
    # handled at once instead of in a Python loop.
    scale = 1.0 / NP.sqrt(3)
    cos30 = NP.cos(NP.pi / 6)
    sin30 = NP.sin(NP.pi / 6)

    # Vector 1 bisects out of the lower-left vertex, vector 2 out of the
    # lower-right vertex and vector 3 (purely vertical) out of the top vertex.
    x = 1.0 / 2 - scale * p1 * cos30 + scale * p2 * cos30
    y = (1.0 / (2 * NP.sqrt(3))
         - scale * p1 * sin30
         - scale * p2 * sin30
         + scale * p3)
    return NP.column_stack((x, y))
if __name__ == '__main__':
    # Define a synthetic test dataset: each row is a point on the 3-simplex.
    labels = ('[0.1 0.1 0.8]',
              '[0.8 0.1 0.1]',
              '[0.5 0.4 0.1]',
              '[0.33 0.34 0.33]')
    testpoints = NP.array([[0.1, 0.1, 0.8],
                           [0.8, 0.1, 0.1],
                           [0.5, 0.4, 0.1],
                           [0.33, 0.34, 0.33]])

    # Define a different color for each label.  The old 'spectral' colormap
    # was removed in matplotlib 2.2 (its replacement is 'nipy_spectral'),
    # and matplotlib.cm.get_cmap was removed in 3.9; pyplot.get_cmap works
    # on both old and new releases.
    cmap = P.get_cmap('nipy_spectral')
    norm = C.Normalize(vmin=0, vmax=len(labels))
    c = list(range(len(labels)))

    # Do scatter plot
    fig = plotSimplex(testpoints, s=100, c=c,
                      cmap=cmap, norm=norm)

    # Make color-label legend: one solid patch per point, colored exactly
    # as that point was scattered (same cmap and norm).
    fig.gca().legend([PA.Rectangle((0, 0), 1, 1,
                                   fc=cmap(norm(idx)))
                      for idx in range(len(labels))],
                     labels)
    P.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment