Last active
July 9, 2019 19:34
-
-
Save laurencee9/5be78ae14e49ae80b9e6b433cb097be4 to your computer and use it in GitHub Desktop.
Draw curved network edges in python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.patches as mpatches | |
import networkx as nx | |
from matplotlib.patches import FancyArrowPatch as ArrowPath | |
from matplotlib.font_manager import FontProperties | |
def _edge_curve(p1, p2, mu, loopsize, is_loop, n_points=100):
    """Sample the curve drawn for one edge.

    Returns ``(X, Y, flipped)`` where ``X``/``Y`` are arrays of ``n_points``
    coordinates.  For a self loop the curve is a circle of radius
    ``loopsize`` through the node, centred ``loopsize`` to its left.
    Otherwise it is a parabolic arc between the two points whose apex is
    offset from the chord by ``mu`` times the chord length.  ``flipped`` is
    True when the samples run from ``p2`` back to ``p1``.
    """
    x1, y1 = p1
    x2, y2 = p2
    if is_loop:
        t = np.linspace(0.0, 2.0 * np.pi, n_points)
        return (np.cos(t) * loopsize + x1 - loopsize,
                np.sin(t) * loopsize + y1,
                False)

    # Build every arc left-to-right.  When the endpoints are swapped, negate
    # the bend so u->v and v->u curve to opposite sides of the chord, and
    # report the swap so the caller can reverse the arrow.
    flipped = x1 > x2
    if flipped:
        x1, y1, x2, y2 = x2, y2, x1, y1
    dv = -mu if flipped else mu

    # Work relative to the right endpoint (x2, y2).
    rx, ry = x1 - x2, y1 - y2
    r = np.hypot(rx, ry)
    theta = np.arctan2(ry, rx)

    # In the rotated frame the chord runs from (-r, 0) to (0, 0) and the
    # parabola's vertex is at (-r/2, dv*r).  Coincident points (r == 0)
    # would divide by zero, so they collapse to a single point instead.
    h = -r / 2.0
    k = dv * r
    a = -k / (h * h) if r > 0 else 0.0
    X = np.linspace(-r, 0.0, n_points)
    Y = a * (X - h) ** 2 + k

    # Rotate by theta + pi so (-r, 0) lands back on (x1, y1), then translate.
    phi = theta + np.pi
    c, s = np.cos(phi), np.sin(phi)
    return c * X - s * Y + x2, s * X + c * Y + y2, flipped


def draw_curved_edges(edges, pos, ax, mu=0.05, edge_color="black", edge_width=1.0, alpha=1.0, arrow_scale=20.0, loopsize=0):
    """
    Draw directed edges as curved arcs with an arrowhead at their midpoint.

    Params
    ---------------------------
    edges: Edge list (iterable of (source, target) pairs)
    pos : Dict mapping each node to its (x, y) position
    ax : matplotlib Axes to draw on (uses ax.add_patch and ax.plot)
    mu : Control the curvature (apex offset as a fraction of edge length);
         its sign selects the side the arcs bend to
    edge_color : Color, or list of colors with one entry per edge
    edge_width : Line width, or list of widths with one entry per edge
    alpha : Opacity of the edge lines and their arrowheads
    arrow_scale : mutation_scale passed to FancyArrowPatch (arrowhead size)
    loopsize: Radius of the circle drawn for self loops
    """
    import warnings

    for v, edge in enumerate(edges):
        source, target = edge[0], edge[1]
        X, Y, flipped = _edge_curve(pos[source], pos[target], mu, loopsize,
                                    source == target)

        color = edge_color[v] if isinstance(edge_color, list) else edge_color
        width = edge_width[v] if isinstance(edge_width, list) else edge_width

        # Short arrow around the middle sample.  The samples run source->target
        # unless the endpoints were swapped, so reverse on `flipped`.  The sign
        # of the curvature is irrelevant here; checking it reversed arrows for
        # mu <= 0.
        mid = len(X) // 2
        posA = (X[mid - 5], Y[mid - 5])
        posB = (X[mid + 5], Y[mid + 5])
        if flipped:
            posA, posB = posB, posA

        # Arrow drawing stays best-effort, but failures are reported instead of
        # being silently swallowed.
        try:
            arrow = ArrowPath(posA=posA, posB=posB, mutation_scale=arrow_scale,
                              color=color, alpha=alpha)
            ax.add_patch(arrow)
        except Exception as err:
            warnings.warn("could not draw arrow for edge {!r}: {}".format(edge, err),
                          stacklevel=2)

        ax.plot(X, Y, "-", linewidth=width, color=color, zorder=0, alpha=alpha)
def draw_networks(G, pos, ax, mu=0.08,
                  edge_color="black",
                  edge_width=1.0,
                  edge_alpha=1.0,
                  use_edge_weigth=True,
                  node_width=1.0,
                  node_size=80.0,
                  node_border_color="#404040",
                  node_color="#EDEDED",
                  node_alpha=1.0,
                  arrow_scale=20.0,
                  loop_radius=0.0,
                  letter="",
                  letter_fontsize=13,
                  letter_pos=(0.87, 0.02),
                  letter_color="black"):
    """
    Draw a networkx graph with curved edges, styled nodes and a corner label.

    Params
    ---------------------------
    G : networkx graph whose edges are drawn (G[u][v] must be an attribute dict)
    pos : Dict mapping each node to its (x, y) position
    ax : matplotlib Axes to draw on (e.g. from plt.subplots()); its axis
         frame is switched off at the end
    mu : Curvature passed to draw_curved_edges
    edge_color : Color, or list of colors with one entry per edge
    edge_width : When use_edge_weigth is True, a scalar multiplier applied to
                 every edge weight; otherwise the line width (scalar, or list
                 with one entry per edge)
    edge_alpha : Opacity of edges and arrowheads
    use_edge_weigth : Use each edge's 'weight' attribute (1.0 if missing)
                      as its line width
    node_width : Line width of node borders
    node_size, node_color, node_alpha : Node marker size, fill color, opacity
    node_border_color : Color of node borders
    arrow_scale : Arrowhead size passed to draw_curved_edges
    loop_radius : Radius of self loops
    letter, letter_fontsize, letter_pos, letter_color : Bold text drawn at
        letter_pos in axes coordinates (e.g. a panel label)
    """
    edges = list(G.edges())

    # Edges
    if use_edge_weigth:
        widths = [G[u][v].get('weight', 1.0) * edge_width for u, v in edges]
    else:
        widths = edge_width
    draw_curved_edges(edges, pos, ax,
                      mu=mu,
                      edge_color=edge_color,
                      alpha=edge_alpha,
                      arrow_scale=arrow_scale,
                      loopsize=loop_radius,
                      edge_width=widths)

    # Nodes
    nodes = nx.draw_networkx_nodes(G, pos,
                                   ax=ax,
                                   node_size=node_size,
                                   node_color=node_color,
                                   alpha=node_alpha,
                                   linewidths=node_width)
    # Guard against a None return for an empty node list.
    if nodes is not None:
        nodes.set_edgecolor(node_border_color)

    # Letter
    font = FontProperties()
    font.set_weight('bold')
    ax.text(letter_pos[0], letter_pos[1], letter,
            verticalalignment='bottom',
            horizontalalignment='left',
            transform=ax.transAxes,
            fontproperties=font,
            color=letter_color,
            fontsize=letter_fontsize)

    # Axis
    ax.axis('off')
Hey, thank you for sharing this code — the generated graph looks really nice. Could you clarify what the `ax` argument is? I don't quite understand it: it doesn't look like a classic axis, yet it has methods for plotting (`ax.plot(...)`) and so on.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Now included in DynamicaLab