Skip to content

Instantly share code, notes, and snippets.

@deusebio
Last active December 14, 2021 13:37
Show Gist options
Save deusebio/397610ac1f12dfd77ee6412987deb963 to your computer and use it in GitHub Desktop.
Wishing a merry Christmas with networkx
from typing import Any, Tuple, Dict
import networkx as nx
import numpy as np
def generate_Xmas_tree(depth: int) -> nx.Graph:
    """Build a triangular "Christmas tree" graph with a stump underneath.

    The canopy has ``depth`` rows. Row ``i`` (0 = top) holds ``i + 1`` nodes,
    labelled with consecutive integers left-to-right, top-to-bottom, starting
    at 0. Every node above the bottom row is linked to the two nodes directly
    below it. Consecutive nodes within each row are also linked to each other.
    A stump is attached below the bottom row:

    * odd ``depth``: one stump node under the middle node;
    * even ``depth``: two stump nodes under the two middle nodes.

    Stump nodes take the next integer labels after the canopy.

    Args:
        depth: Number of canopy rows; must be at least 1.

    Returns:
        An undirected ``networkx.Graph`` whose node labels are all ``int``.

    Raises:
        ValueError: If ``depth`` is less than 1.
    """
    if depth < 1:
        raise ValueError(f"depth must be at least 1, got {depth}")

    # The canopy size is the triangular number T(depth).
    n_canopy = depth * (depth + 1) // 2
    nodes = set(range(n_canopy))

    # Create the tree
    edges = []
    start = 0  # label of the first node in row i
    for i in range(depth - 1):
        start += i
        stride = i + 1  # row i has `stride` nodes; row i + 1 starts at start + stride
        for n in range(stride):
            base = start + n
            # Link to the two nodes directly below, in row i + 1.
            edges.extend([(base, base + stride), (base, base + stride + 1)])
        for n in range(stride):
            # Horizontal links between neighbours in row i + 1.
            edges.append((start + n + stride, start + n + stride + 1))

    # Create the tree stump. Integer division keeps every label an int;
    # `/` would put float labels like 7.0 into the adjacency structure.
    top = n_canopy - 1  # last (right-most) node of the bottom row
    half = depth // 2
    if depth % 2 == 0:
        # Two stump nodes under the two middle nodes of the bottom row.
        edges.extend([(top - half, top + 1), (top - half + 1, top + 2)])
    else:
        # One stump node under the single middle node (depth // 2 == (depth - 1) // 2 here).
        edges.append((top - half, top + 1))

    G = nx.Graph()
    G.add_nodes_from(nodes)
    G.add_edges_from(edges)
    return G
def get_Xmas_layout(G: nx.Graph, height=100, width=100) -> Dict[Any, Tuple[float, float]]:
    """Compute node positions that draw ``G`` as a Christmas tree.

    Intended for graphs built by ``generate_Xmas_tree``:

    * The canopy is assumed to be the integer labels ``0 .. T(depth) - 1``,
      laid out row by row. Here ``T`` is the triangular number and ``depth``
      is inferred from the node count.
    * Positions are produced for every canopy label, whether or not ``G``
      actually contains it.
    * Every other node of ``G`` is treated as the stump. Stump nodes are
      sorted, so they must be mutually comparable.

    Args:
        G: The tree graph; only ``G.nodes()`` is consulted.
        height: Vertical extent of the drawing. The top canopy row is at
            ``y == height`` and the stump is at ``y == 0``.
        width: Horizontal scale. Adjacent nodes in a row are
            ``width / depth`` apart, and each row is centred on ``x == 0``.

    Returns:
        A dict mapping each node to its ``(x, y)`` position. The dict is
        empty if ``G`` has no nodes.
    """
    def getDepth(n: int) -> int:
        # Find the smallest i with T(i) + 2 >= n, where T(i) is the canopy
        # size of an i-row tree. The "+ 2" allows for the 1- or 2-node stump.
        i = 0
        tot = 0
        while tot + 2 < n:
            i += 1
            tot += i
        # For a depth-1 tree (n == 2: one canopy node plus one stump node) the
        # search stops at 0. That would divide by zero below, so clamp to 1.
        return max(i, 1)

    n_nodes = len(G.nodes())
    if n_nodes == 0:
        return {}

    depth = getDepth(n_nodes)
    spacing_height = height / depth
    spacing_width = width / depth

    start = 0  # label of the first node in row i
    pos = []
    for i in range(depth):
        row = range(start, start + i + 1)
        start += i + 1
        # Row i has i + 1 nodes spanning i * spacing_width, so it starts at
        # -i/2 * spacing_width to be centred on x == 0. The same expression
        # is used for both even and odd i.
        starting_point = -i / 2.0 * spacing_width
        pos.extend(
            (node, (starting_point + ith * spacing_width, i * spacing_height))
            for ith, node in enumerate(row)
        )

    # The stump goes one row below the canopy. Sorting makes the left/right
    # order deterministic, so the left stump node (the lower label, linked to
    # the left middle node) is drawn on the left and the edges do not cross.
    stump = sorted(set(G.nodes()).difference(range(start)))
    stump_x0 = 0 if depth % 2 == 1 else -0.5 * spacing_width
    pos.extend(
        (node, (stump_x0 + nth * spacing_width, depth * spacing_height))
        for nth, node in enumerate(stump)
    )

    # Flip the y-axis so the canopy top is at y == height.
    return {node: (posx, height - posy) for node, (posx, posy) in pos}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment