Created
April 28, 2018 21:22
-
-
Save jakab922/1e257ed9923d603c6942295a09e05661 to your computer and use it in GitHub Desktop.
Centroid tree of a tree in O(n * log_{2}(n)) time.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict as dd | |
def _calc_size(tree, root, found, sizes): | |
"""Practically a dfs on the remaining nodes | |
where the size of the parent node is the sum of | |
sizes of the child nodes.""" | |
stack = [(root, 0)] | |
on_route = set([root]) | |
while stack: | |
curr, index = stack.pop() | |
if index < len(tree[curr]): | |
stack.append((curr, index + 1)) | |
nxt = tree[curr][index] | |
if not found[nxt] and nxt not in on_route: | |
stack.append((nxt, 0)) | |
on_route.add(nxt) | |
else: | |
sizes[curr] = 1 | |
for neigh in tree[curr]: | |
if found[neigh] and neigh not in on_route: | |
continue | |
sizes[curr] += sizes[neigh] | |
on_route.remove(curr) | |
def _check_condition(tree, cand, found, sizes, psize): | |
""" Checks if all the neighbors of cand have <= | |
size than psize. """ | |
for neigh in tree[cand]: | |
if found[neigh]: | |
continue | |
if sizes[neigh] > psize: | |
return False, neigh | |
return True, None | |
def make_centroid(tree):
    """Build the centroid tree of ``tree``.

    Repeatedly takes a not-yet-processed node, measures the component it
    belongs to, walks from it to a node none of whose remaining neighbours
    is heavier than half of that component, marks that node as processed
    and queues its remaining neighbours so that every leftover piece is
    decomposed in turn.

    The original author notes that this can be proven to run in
    O(n * log_2(n)), where n is the number of nodes in the tree.

    Args:
        tree: The tree as a list of lists, ``[[...], [...], ..., [...]]``;
            ``tree[v]`` holds the neighbours of node ``v`` and nodes are
            numbered ``0 .. len(tree) - 1``.

    Returns:
        A list of lists ``ret`` of the same length as ``tree``, where
        ``ret[c]`` holds the nodes whose parent in the centroid tree is
        ``c``. An empty ``tree`` yields ``[]``.
    """
    n = len(tree)
    if n == 0:
        # Nothing to decompose; node 0 does not exist.
        return []
    ret = [[] for _ in range(n)]
    found = [False] * n
    # Pending work items: (some node of an unprocessed component,
    # the centroid that component hangs under in the centroid tree).
    cands = [(0, None)]
    while cands:
        cand, par = cands.pop()

        # Temporarily flag ``cand`` so that the traversal started at each
        # neighbour cannot walk back through it.
        found[cand] = True
        sizes = dd(int)
        sizes[cand] = 1
        for neigh in tree[cand]:
            if found[neigh]:
                continue
            _calc_size(tree, neigh, found, sizes)
            sizes[cand] += sizes[neigh]
        found[cand] = False

        # Half of the component size, rounded down.
        psize = sizes[cand] // 2

        good, nxt = _check_condition(tree, cand, found, sizes, psize)
        while not good:
            # Before stepping to the too-heavy neighbour ``nxt``, re-express
            # the size of the node being left as everything on its side,
            # i.e. excluding the branch that contains ``nxt``.
            sizes[cand] = 1
            for neigh in tree[cand]:
                if found[neigh] or neigh == nxt:
                    continue
                sizes[cand] += sizes[neigh]
            cand = nxt
            good, nxt = _check_condition(tree, cand, found, sizes, psize)

        # ``cand`` is now the centroid of this component.
        found[cand] = True
        if par is not None:
            ret[par].append(cand)
        for neigh in tree[cand]:
            if not found[neigh]:
                cands.append((neigh, cand))
    return ret
if __name__ == "__main__":
    # example graph from https://www.geeksforgeeks.org/centroid-decomposition-of-tree/
    # Nodes are labelled from 1 here: entry ``i`` of ``pre_tree`` lists the
    # neighbours of node ``i + 1``.
    pre_tree = [
        [4],
        [4],
        [4],
        [1, 2, 3, 5],
        [4, 6],
        [5, 7, 10],
        [6, 8, 9],
        [7],
        [7],
        [6, 11],
        [10, 12, 13],
        [11, 14],
        [11, 15, 16],
        [12],
        [13],
        [13],
    ]
    # Shift to 0-based labels. Real lists are required because
    # ``make_centroid`` takes ``len()`` of and indexes each neighbour list.
    tree = [[node - 1 for node in el] for el in pre_tree]
    centroid_tree = make_centroid(tree)
    # Shift back to 1-based labels for display.
    post_centroid_tree = [[node + 1 for node in el] for el in centroid_tree]
    for i, row in enumerate(post_centroid_tree, 1):
        # Same output as the Python 2 statement ``print i, row``, but valid
        # on both Python 2 and Python 3.
        print("%d %s" % (i, row))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment