Skip to content

Instantly share code, notes, and snippets.

@davipatti
Created April 14, 2021 18:06
Show Gist options
  • Save davipatti/2c2b2376865cfc59b39a2ade85942bd1 to your computer and use it in GitHub Desktop.
Save davipatti/2c2b2376865cfc59b39a2ade85942bd1 to your computer and use it in GitHub Desktop.
import sys
import csv
import networkx as nx
import matplotlib.pyplot as plt
"""
Expects CSV on stdin, containing weights of edges
,a,b,c
a,,2,3
b,,,1
c,,,
"""
g = nx.Graph()
row_totals = []
with sys.stdin as fobj:
lines = csv.reader(fobj)
columns = next(lines)[1:]
for row, *weights in lines:
for column, weight in zip(columns, weights):
if weight:
g.add_edge(column, row, weight=float(weight) ** 4)
pos = nx.spring_layout(g, scale=0.5)
xmax = max(v[0] for v in pos.values())
xmin = min(v[0] for v in pos.values())
ymax = max(v[1] for v in pos.values())
ymin = min(v[1] for v in pos.values())
pad = 0.5
fig = plt.figure(figsize=(10, 10))
nx.draw_networkx_nodes(g, pos)
nx.draw_networkx_edges(
g, pos, width=[g[u][v]["weight"] ** 0.25 for u, v in g.edges()]
)
nx.draw_networkx_labels(g, pos)
plt.axis("off")
plt.xlim(xmin - pad, xmax + pad)
plt.ylim(ymin - pad, ymax + pad)
plt.savefig("graph.pdf", bbox_inches="tight")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment