Skip to content

Instantly share code, notes, and snippets.

@jakab922
Created April 28, 2018 21:22
Show Gist options
  • Save jakab922/1e257ed9923d603c6942295a09e05661 to your computer and use it in GitHub Desktop.
Save jakab922/1e257ed9923d603c6942295a09e05661 to your computer and use it in GitHub Desktop.
Builds the centroid (decomposition) tree of a tree in O(n * log_2(n)) time.
from collections import defaultdict as dd
def _calc_size(tree, root, found, sizes):
"""Practically a dfs on the remaining nodes
where the size of the parent node is the sum of
sizes of the child nodes."""
stack = [(root, 0)]
on_route = set([root])
while stack:
curr, index = stack.pop()
if index < len(tree[curr]):
stack.append((curr, index + 1))
nxt = tree[curr][index]
if not found[nxt] and nxt not in on_route:
stack.append((nxt, 0))
on_route.add(nxt)
else:
sizes[curr] = 1
for neigh in tree[curr]:
if found[neigh] and neigh not in on_route:
continue
sizes[curr] += sizes[neigh]
on_route.remove(curr)
def _check_condition(tree, cand, found, sizes, psize):
""" Checks if all the neighbors of cand have <=
size than psize. """
for neigh in tree[cand]:
if found[neigh]:
continue
if sizes[neigh] > psize:
return False, neigh
return True, None
def make_centroid(tree):
    """Build the centroid decomposition tree of ``tree``.

    ``tree`` is a tree given as a list of lists: ``[[...], [...], ..., [...]]``,
    where ``tree[v]`` lists the neighbours of node ``v`` (nodes are
    ``0 .. n-1``; each edge is expected to appear in both endpoints' lists).

    Returns a list ``ret`` of length ``n`` where ``ret[c]`` holds the
    children of centroid ``c`` in the centroid tree. The centroid of the
    whole tree is the root and is the only node not listed in any ``ret[c]``.
    An empty tree yields ``[]``.

    It can be proven that this runs in O(n * log_2(n)),
    where n is the number of nodes in the tree.
    """
    n = len(tree)
    if n == 0:
        return []
    # range (not xrange) so the function also runs on Python 3.
    ret = [[] for _ in range(n)]
    # found[v] becomes True once v has been chosen as a centroid; such nodes
    # split the tree into the components still to be processed.
    found = [False] * n
    # Stack of (some node of an unprocessed component, centroid of the
    # enclosing component, or None for the whole tree).
    cands = [(0, None)]
    while cands:
        cand, par = cands.pop()

        # Subtree sizes of cand's component rooted at cand. cand is flagged
        # temporarily so the DFS from each neighbour does not walk back in.
        found[cand] = True
        sizes = dd(int)
        sizes[cand] = 1
        for neigh in tree[cand]:
            if found[neigh]:
                continue
            _calc_size(tree, neigh, found, sizes)
            sizes[cand] += sizes[neigh]
        found[cand] = False

        # A centroid has no remaining neighbour whose side of the component
        # holds more than half of the component's nodes.
        psize = sizes[cand] // 2
        good, nxt = _check_condition(tree, cand, found, sizes, psize)
        while not good:
            # Move towards the heavy neighbour nxt. From nxt's point of view,
            # cand's side is everything except nxt's subtree.
            sizes[cand] = 1
            for neigh in tree[cand]:
                if found[neigh] or neigh == nxt:
                    continue
                sizes[cand] += sizes[neigh]
            cand = nxt
            good, nxt = _check_condition(tree, cand, found, sizes, psize)

        found[cand] = True
        if par is not None:
            ret[par].append(cand)
        # Each still-unflagged neighbour starts a separate sub-component.
        for neigh in tree[cand]:
            if not found[neigh]:
                cands.append((neigh, cand))
    return ret
if __name__ == "__main__":
    # Example graph from https://www.geeksforgeeks.org/centroid-decomposition-of-tree/
    # given as 1-indexed adjacency lists: pre_tree[i] lists the neighbours
    # of node i + 1.
    pre_tree = [
        [4],
        [4],
        [4],
        [1, 2, 3, 5],
        [4, 6],
        [5, 7, 10],
        [6, 8, 9],
        [7],
        [7],
        [6, 11],
        [10, 12, 13],
        [11, 14],
        [11, 15, 16],
        [12],
        [13],
        [13],
    ]
    # make_centroid expects 0-indexed lists. Comprehensions (rather than
    # map) keep the rows indexable and sized on Python 3 as well.
    tree = [[x - 1 for x in el] for el in pre_tree]
    centroid_tree = make_centroid(tree)
    # Convert back to 1-indexed node labels for display.
    post_centroid_tree = [[x + 1 for x in el] for el in centroid_tree]
    for i, row in enumerate(post_centroid_tree, 1):
        # Same "i row" output as the Python 2 print statement, portable to 3.
        print("{} {}".format(i, row))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment