Last active
September 3, 2021 04:26
-
-
Save hsteinshiromoto/e2e25814104004a4516e65023da5c8e6 to your computer and use it in GitHub Desktop.
nx.plotgraph
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections.abc import Iterable | |
import matplotlib as mpl | |
import matplotlib.pyplot as plt | |
import networkx as nx | |
import numpy as np | |
def make_graph(nodes: Iterable, M: np.ndarray, G: "nx.classes.digraph.DiGraph | None" = None):
    """Build a graph from a sequence of nodes and a weight matrix.

    Every node is added with a ``label`` attribute equal to ``f"{node}"``.
    For every index pair ``(i, j)`` with ``M[i, j] != 0`` an edge from the
    ``i``-th node to the ``j``-th node is added, carrying ``weight=M[i, j]``
    and a ``label`` with the weight formatted to two decimals.

    Args:
        nodes (Iterable): Graph nodes. Any iterable (including a generator)
            is accepted; it is materialised once so every pass sees all nodes.
        M (np.ndarray): Weight matrix; row/column ``k`` corresponds to the
            ``k``-th node, so it must be at least ``len(nodes) x len(nodes)``.
        G (nx.classes.digraph.DiGraph, optional): Graph to populate in place.
            When omitted, a fresh ``nx.DiGraph()`` is created on every call.

    Returns:
        The populated graph ``G``.

    Example:
        >>> n_nodes = 4
        >>> M = np.random.rand(n_nodes, n_nodes)
        >>> nodes = range(M.shape[0])
        >>> G = make_graph(nodes, M)
    """
    # A default of ``nx.DiGraph()`` would be evaluated once at definition time
    # and shared by every call, so nodes/edges from earlier calls would leak
    # into later ones; build a fresh graph per call instead.
    if G is None:
        G = nx.DiGraph()

    # Materialise once: the node pass and the nested edge loops below each
    # iterate ``nodes``, which would silently exhaust a generator.
    nodes = list(nodes)

    for node in nodes:
        G.add_node(node, label=f"{node}")

    for i, origin_node in enumerate(nodes):
        for j, destination_node in enumerate(nodes):
            weight = M[i, j]
            # Zero entries mean "no edge".
            if weight != 0:
                G.add_edge(origin_node, destination_node, weight=weight,
                           label=f"{weight:0.02f}")
    return G
def graphplot(G: nx.classes.digraph.DiGraph, M: np.ndarray,
              min_weight_threshold: float = 0.0, bins: int = 4,
              graph_layout: str = "spring_layout",
              figsize: tuple = (20, 10),
              cmap=plt.cm.coolwarm,
              edge_kwargs=None, node_label_kwargs=None, node_kwargs=None):
    """Plot a weighted graph with edges coloured by weight, plus a colorbar.

    Args:
        G (nx.classes.digraph.DiGraph): Weighted graph. Edge colours are read
            from each edge's ``weight`` attribute and node labels from each
            node's ``label`` attribute (both are set by ``make_graph``).
        M (np.ndarray): Weight matrix used to compute the colour range.
        min_weight_threshold (float, optional): Entries of ``M`` exactly equal
            to this value are ignored when computing the colour range. It does
            not filter which edges are drawn. Defaults to 0.0.
        bins (int, optional): Number of histogram bins over the remaining
            weights; the colour scale spans the outermost bin edges.
            Defaults to 4.
        graph_layout (str, optional): Name of a networkx layout function,
            called as ``getattr(nx, graph_layout)(G)``.
            Defaults to "spring_layout".
        figsize (tuple, optional): Figure size passed to ``plt.subplots``.
            Defaults to (20, 10).
        cmap (optional): Matplotlib colormap used for the edges and the
            colorbar. Defaults to plt.cm.coolwarm.
        edge_kwargs (dict, optional): Kwargs for ``nx.draw_networkx_edges``;
            when given (non-empty) they replace the defaults entirely.
            Defaults to None.
        node_label_kwargs (dict, optional): Kwargs for
            ``nx.draw_networkx_labels``. Defaults to None.
        node_kwargs (dict, optional): Kwargs for ``nx.draw_networkx_nodes``.
            Defaults to None.

    Returns:
        ax: The Matplotlib Axes holding the plot (axis decorations hidden).

    Example:
        >>> n_nodes = 4
        >>> M = np.random.rand(n_nodes, n_nodes)
        >>> nodes = range(M.shape[0])
        >>> G = make_graph(nodes, M)
        >>> graphplot(G, M)

    References:
        [1] https://networkx.org/documentation/stable/auto_examples/drawing/plot_directed.html
    """
    # Colour range shared by the edges and the colorbar: the outer histogram
    # bin edges of M's entries, skipping entries equal to the threshold
    # (by default 0.0, the entries make_graph turns into "no edge").
    weights = np.asarray(M)
    _, bin_edges = np.histogram(weights[weights != min_weight_threshold], bins=bins)
    vmin, vmax = bin_edges[0], bin_edges[-1]

    node_kwargs = node_kwargs or {"node_color": "k", "node_size": 500}
    edge_kwargs = edge_kwargs or {
        # list(): hand networkx a plain sequence of numbers rather than a
        # dict view.
        "edge_color": list(nx.get_edge_attributes(G, 'weight').values()),
        "edge_cmap": cmap,
        # Pin the edge colour scale to the colorbar's range so both agree,
        # instead of letting networkx autoscale independently.
        "edge_vmin": vmin,
        "edge_vmax": vmax,
        "width": 4,
        "connectionstyle": 'arc3, rad=0.2',
    }
    node_label_kwargs = node_label_kwargs or {"font_color": "w", "font_size": 16,
                                              "font_weight": "bold"}

    pos = getattr(nx, graph_layout)(G)
    fig, ax = plt.subplots(figsize=figsize)
    nx.draw_networkx_nodes(G, pos, ax=ax, **node_kwargs)
    nx.draw_networkx_labels(G, pos, labels=nx.get_node_attributes(G, 'label'),
                            ax=ax, **node_label_kwargs)
    nx.draw_networkx_edges(G, pos, ax=ax, **edge_kwargs)

    # A standalone ScalarMappable carries cmap + norm for the colorbar, so it
    # does not depend on which artist type draw_networkx_edges returned.
    mappable = mpl.cm.ScalarMappable(norm=mpl.colors.Normalize(vmin=vmin, vmax=vmax),
                                     cmap=cmap)
    mappable.set_array(bin_edges)
    # The mappable is not attached to any Axes, so name the Axes the colorbar
    # should take space from; without ax= recent Matplotlib raises.
    cbar = fig.colorbar(mappable, ax=ax)
    cbar.set_label('weights', rotation=270, fontsize=16, labelpad=20)

    ax.set_axis_off()
    return ax
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment