Skip to content

Instantly share code, notes, and snippets.

@laurencee9
Last active July 9, 2019 19:34
Show Gist options
  • Save laurencee9/5be78ae14e49ae80b9e6b433cb097be4 to your computer and use it in GitHub Desktop.
Save laurencee9/5be78ae14e49ae80b9e6b433cb097be4 to your computer and use it in GitHub Desktop.
Draw curved network edges in Python
import warnings

import numpy as np
import matplotlib.patches as mpatches
import networkx as nx
from matplotlib.patches import FancyArrowPatch as ArrowPath
from matplotlib.font_manager import FontProperties
def _per_edge(value, index):
    """Return ``value[index]`` if *value* is a per-edge list, otherwise *value*."""
    return value[index] if isinstance(value, list) else value


def _self_loop_points(x, y, loopsize, n_points=100):
    """Sample a circle of radius *loopsize* whose rightmost point is (x, y)."""
    theta = np.linspace(0.0, 2.0 * np.pi, n_points)
    X = np.cos(theta) * loopsize + x - loopsize
    Y = np.sin(theta) * loopsize + y
    return X, Y


def _parabola_points(start, end, bend, n_points=100):
    """Sample a parabolic arc from *start* to *end*.

    The arc's vertex is offset sideways from the chord midpoint by
    ``bend * |end - start|``; ``bend == 0`` gives a straight segment.
    Samples run from *start* (index 0) to *end* (index -1).
    """
    x1, y1 = start
    x2, y2 = end
    # Work in a frame centred on `end`.
    rx, ry = x1 - x2, y1 - y2
    r = np.hypot(rx, ry)
    if r == 0:
        # Distinct nodes placed at the same position: nothing to curve
        # (the parabola formula below would divide by zero).
        return np.full(n_points, float(x2)), np.full(n_points, float(y2))
    # In a rotated frame the chord lies on the x-axis from (-r, 0) to (0, 0);
    # the parabola passes through both ends with its vertex at (-r/2, bend*r).
    h = -r / 2.0
    k = bend * r
    a = -k / h ** 2
    X = np.linspace(-r, 0.0, n_points)
    Y = a * (X - h) ** 2 + k
    # Rotate so that (-r, 0) lands on (rx, ry), then translate back.
    phi = np.arctan2(ry, rx) + np.pi
    c, s = np.cos(phi), np.sin(phi)
    return c * X - s * Y + x2, s * X + c * Y + y2


def draw_curved_edges(edges, pos, ax, mu=0.05, edge_color="black", edge_width=1.0, alpha=1.0, arrow_scale=20.0, loopsize=0):
    """Draw each edge as a curved line with an arrowhead at its midpoint.

    Params
    ---------------------------
    edges : iterable of edges; only ``edge[0]`` (source) and ``edge[1]``
        (target) are used.
    pos : dict mapping node -> (x, y) position.
    ax : matplotlib Axes to draw on.
    mu : curvature; the arc bulges sideways by ``mu`` times the edge length.
        Reciprocal edges u->v and v->u are bent to opposite sides so they
        do not overlap.
    edge_color : color, or list with one color per edge.
    edge_width : line width, or list with one width per edge.
    alpha : opacity of the edge lines and arrowheads.
    arrow_scale : ``mutation_scale`` (size) of the arrowheads.
    loopsize : radius of the circle drawn for self-loops.
    """
    for v, edge in enumerate(edges):
        source, target = edge[0], edge[1]
        x1, y1 = pos[source]
        x2, y2 = pos[target]
        # True when the curve is sampled from target to source.
        swapped = False
        if source == target:
            X, Y = _self_loop_points(x1, y1, loopsize)
        else:
            bend = mu
            if x1 > x2:
                # Always sample left-to-right and flip the bend, so that
                # reciprocal edges u->v and v->u curve to opposite sides.
                (x1, y1), (x2, y2) = (x2, y2), (x1, y1)
                bend = -mu
                swapped = True
            X, Y = _parabola_points((x1, y1), (x2, y2), bend)

        color = _per_edge(edge_color, v)
        width = _per_edge(edge_width, v)

        # The arrowhead is a short segment around the middle sample. It is
        # oriented from source to target. The explicit `swapped` flag is
        # used instead of the bend's sign, which is wrong for mu <= 0.
        mid = len(X) // 2
        posA = (X[mid - 5], Y[mid - 5])
        posB = (X[mid + 5], Y[mid + 5])
        if swapped:
            posA, posB = posB, posA
        # Zero-length arrows (e.g. self-loops with loopsize=0) are skipped.
        if posA != posB:
            try:
                ax.add_patch(ArrowPath(posA=posA, posB=posB,
                                       mutation_scale=arrow_scale,
                                       color=color, alpha=alpha))
            except Exception as err:
                # Best effort: a failing arrowhead must not abort the drawing,
                # but it should not disappear silently either.
                warnings.warn("could not draw arrow for edge {!r}: {}".format(edge, err))
        ax.plot(X, Y, "-", linewidth=width, color=color, zorder=0, alpha=alpha)
def draw_networks(G, pos, ax, mu=0.08,
                  edge_color="black",
                  edge_width=1.0,
                  edge_alpha=1.0,
                  use_edge_weigth=True,
                  node_width=1.0,
                  node_size=80.0,
                  node_border_color="#404040",
                  node_color="#EDEDED",
                  node_alpha=1.0,
                  arrow_scale=20.0,
                  loop_radius=0.0,
                  letter="",
                  letter_fontsize=13,
                  letter_pos=(0.87, 0.02),
                  letter_color="black"):
    """Draw graph *G* on *ax* with curved edges, bordered nodes and a bold label.

    Params
    ---------------------------
    G : networkx graph.
    pos : dict mapping node -> (x, y) position.
    ax : matplotlib Axes to draw on (for example from ``plt.subplots()``).
    mu : edge curvature, forwarded to ``draw_curved_edges``.
    edge_color : color, or list with one color per edge (in ``G.edges()`` order).
    edge_width : base line width, or list of per-edge base widths. When
        ``use_edge_weigth`` is true it is multiplied by each edge's
        ``'weight'`` attribute (edges without one count as weight 1.0).
    edge_alpha : opacity of the edges.
    use_edge_weigth : scale edge widths by the ``'weight'`` attribute
        (the name keeps its original spelling for compatibility).
    node_width : line width of the node borders.
    node_size, node_color, node_alpha : node marker size, fill color and opacity.
    node_border_color : color of the node borders.
    arrow_scale : arrowhead size.
    loop_radius : radius of the self-loop circles.
    letter, letter_fontsize, letter_pos, letter_color : bold text drawn at
        ``letter_pos`` in axes coordinates (0-1), e.g. a panel letter.
    """
    # Edges: collect (u, v) pairs and widths in one pass so they stay aligned.
    edges = []
    widths = []
    for i, (u, v, weight) in enumerate(G.edges(data="weight", default=1.0)):
        base = edge_width[i] if isinstance(edge_width, list) else edge_width
        edges.append((u, v))
        widths.append(base * weight if use_edge_weigth else base)
    draw_curved_edges(edges, pos, ax,
                      mu=mu,
                      edge_color=edge_color,
                      alpha=edge_alpha,
                      arrow_scale=arrow_scale,
                      loopsize=loop_radius,
                      edge_width=widths)

    # Nodes
    nodes = nx.draw_networkx_nodes(G, pos,
                                   ax=ax,
                                   node_size=node_size,
                                   node_color=node_color,
                                   alpha=node_alpha,
                                   linewidths=node_width)
    # NOTE(review): guard in case no collection is returned (e.g. an empty
    # graph); whether this happens depends on the networkx version.
    if nodes is not None:
        nodes.set_edgecolor(node_border_color)

    # Letter, positioned in axes coordinates.
    font = FontProperties()
    font.set_weight('bold')
    ax.text(letter_pos[0], letter_pos[1], letter,
            verticalalignment='bottom',
            horizontalalignment='left',
            transform=ax.transAxes,
            fontproperties=font,
            color=letter_color,
            fontsize=letter_fontsize)

    # Axis
    ax.axis('off')
@laurencee9
Copy link
Author

example

@laurencee9
Copy link
Author

With arrows

image

@laurencee9
Copy link
Author

Now included in DynamicaLab

@HadjAhmed
Copy link

Hey, thank you for sharing this code — the generated graph looks very nice. Could you explain what the `ax` argument is? I don't understand it: it doesn't seem to be an ordinary axis, since it has methods such as plotting (`ax.plot(...)`).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment