Created
December 10, 2015 13:42
-
-
Save bagrow/7a99ecf6891099f47e01 to your computer and use it in GitHub Desktop.
Improve the placement of labels on a busy scatterplot
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
# annotate_with_graphlayout.py | |
# Jim Bagrow | |
# Last Modified: 2015-12-09 | |
""" | |
Rough improvement of scatter point label placement using | |
force-directed graph layout. | |
See also Stack Overflow: http://bit.ly/1NXgB3o | |
""" | |
import sys, os | |
import matplotlib.pyplot as plt | |
import networkx as nx | |
from itertools import combinations | |
import numpy as np | |
# Tuning knobs -- expect to re-tune all of these for a different data set.

# starting displacement of each label from the point it annotates
dx = dy = 0.05

# margin added on every side of the data range so that labels which have
# been pushed outwards still fall inside the axes
dX = dY = 0.8

# weight of the link tying a label to its point (how far the text may stray)
WEIGHT = 10.0

# handed to the spring layout as its preferred distance between nodes
K = 0.5

# fraction of WEIGHT used, with the sign flipped, on label-to-label links
# so that labels push away from each other
label_repulsive_weight = 0.001
# some fake data:
# N points whose y value is loosely correlated with x, plus one random
# value per point in [0, 1) that the plot uses for both colour and size.
N = 30
X = np.random.randn(N)
Y = X * 0.5 + np.random.randn(N)
C = np.random.rand(N)
# build the network:
# every data point i contributes two nodes: "d<i>" placed on the point itself
# and "t<i>" for its text label, which starts offset by (dx, dy). The two are
# joined by an edge of weight WEIGHT.
d_nodes, t_nodes = [], []
G = nx.Graph()
node2coord = {}
for i, (x, y) in enumerate(zip(X, Y)):
    d_str = "d%i" % i
    t_str = "t%i" % i
    d_nodes.append(d_str)
    t_nodes.append(t_str)
    G.add_edge(d_str, t_str, weight=WEIGHT)
    node2coord[d_str] = (x, y)
    node2coord[t_str] = (x + dx, y + dy)

# "t" nodes are self-repulsive:
# every pair of labels gets an edge with a small *negative* weight, i.e. the
# opposite sign of the point-to-label links above.
for ni, nj in combinations(t_nodes, 2):
    G.add_edge(ni, nj, weight=-label_repulsive_weight * WEIGHT)

# compute new layout, only "t" nodes can move:
# the "d" nodes are passed as `fixed`, and the initial positions of all nodes
# come from node2coord.
node2coord_springs = nx.spring_layout(G, k=K, pos=node2coord, fixed=d_nodes)
def draw_and_label(coords, arrows=False, ax=None):
    """Scatter the data points and annotate each one with its label.

    The points come from the module-level arrays X and Y; C supplies both
    the colour and (scaled by 200) the marker size. The axis limits are set
    to the data range widened by dX / dY on each side.

    coords -- mapping holding, for every i in range(N), an (x, y) position
              under the key "d<i>" (the location being annotated) and under
              "t<i>" (where the label text is placed).
    arrows -- if true, connect each label to its point with a thin black
              arrow; otherwise only the text is drawn.
    ax     -- axes to draw on; defaults to the current axes (plt.gca()).
    """
    if ax is None:
        ax = plt.gca()
    ax.scatter(X, Y, c=C, s=C * 200)

    arrowprops = None
    if arrows:
        arrowprops = dict(facecolor='black', shrink=0.05,
                          width=0.5, headwidth=0.5)

    for i in range(N):
        d_str = "d%i" % i
        t_str = "t%i" % i
        # the label text is the "t" node name itself
        ax.annotate(t_str,
                    xy=coords[d_str],
                    xytext=coords[t_str],
                    arrowprops=arrowprops,
                    )

    ax.set_xlim(X.min() - dX, X.max() + dX)
    ax.set_ylim(Y.min() - dY, Y.max() + dY)
# side-by-side comparison:
# left panel uses the initial label positions (node2coord), right panel uses
# the positions returned by the spring layout and draws arrows to the points.
plt.figure(figsize=(12, 6))
plt.subplot(121)
draw_and_label(node2coord)
plt.subplot(122)
draw_and_label(node2coord_springs, arrows=True)
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This is pretty cool! thank-you!