Created
August 4, 2012 18:18
-
-
Save kflu/3259142 to your computer and use it in GitHub Desktop.
Find all paths in a binary tree whose node values sum to a specified value.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Find all downward paths in a binary tree whose node values sum to a specified value.
"""
from collections import deque | |
def find_path(T, X):
    """Return every downward path in tree *T* whose node values sum to *X*.

    A path runs from some start node down through its descendants to an end
    node; it need not begin at the root or end at a leaf. The tree is walked
    depth-first (left child before right) with an explicit stack. Each child
    is annotated with a ``par`` attribute pointing at its parent (the root gets
    ``par = None``). From every visited node, the code walks upwards via
    ``par`` and checks each partial sum against *X*.

    Args:
        T: root node exposing ``data``, ``L`` and ``R``; may be None.
        X: target sum.

    Returns:
        A list of ``(start, end)`` tuples, where *start* is *end* itself or
        one of its ancestors. Empty when *T* is None or nothing matches.

    Side effects:
        Sets ``par`` on every node of the tree.
    """
    res = []
    if T is None:
        # An empty tree has no paths; previously this crashed on ``T.par``.
        return res
    T.par = None
    stack = deque([T])
    while stack:
        cur = stack.pop()
        # Walk from cur up to the root; every point where the running
        # remainder hits zero ends a path (node .. cur) that sums to X.
        remaining = X
        node = cur
        while node is not None:
            remaining -= node.data
            if remaining == 0:
                res.append((node, cur))
            node = node.par
        # Push right first so the left child is popped (visited) first.
        if cur.R:
            cur.R.par = cur
            stack.append(cur.R)
        if cur.L:
            cur.L.par = cur
            stack.append(cur.L)
    return res
def find_path_2(T, X):
    """Return every downward path in tree *T* whose node values sum to *X*.

    Unlike ``find_path``, this does not annotate nodes with parent links. The
    deque ``path`` always holds the nodes from the root down to the node
    currently being visited. A ``None`` sentinel is pushed beneath each node's
    children on the traversal stack. Popping the sentinel means that node's
    subtree is finished, so the node is removed from ``path``.

    Args:
        T: root node exposing ``data``, ``L`` and ``R``; may be None.
        X: target sum.

    Returns:
        A list of ``[start, end]`` two-element lists, where *start* is *end*
        itself or one of its ancestors. Empty when *T* is None or nothing
        matches.
    """
    res = []
    if T is None:
        # Without this guard the sentinel branch would pop an empty path.
        return res
    stack = deque([T])
    path = deque()
    while stack:
        cur = stack.pop()
        if cur is None:
            # Sentinel: the subtree rooted at path[-1] is fully explored.
            path.pop()
            continue
        path.append(cur)
        stack.append(None)
        # Push right first so the left child is visited first.
        if cur.R:
            stack.append(cur.R)
        if cur.L:
            stack.append(cur.L)
        # Accumulate from cur back towards the root (``reversed`` works on
        # both Python 2 and 3, unlike ``xrange``).
        remaining = X
        for node in reversed(path):
            remaining -= node.data
            if remaining == 0:
                res.append([node, cur])
    return res
class Node:
    """Binary tree node.

    Attributes (all set in ``__init__``):
        data: value stored in the node; the path finders sum these.
        L: left child, or None.
        R: right child, or None.
        par: parent link. It starts as None and is overwritten when a
            traversal (``get_tree`` or ``find_path``) links children to parents.
    """

    def __init__(self, data=None, L=None, R=None):
        self.data = data
        self.L = L
        self.R = R
        # Declared up front so reading ``par`` on an unlinked node yields
        # None instead of raising AttributeError.
        self.par = None

    def __repr__(self):
        return "Node(%r)" % (self.data,)
def get_tree():
    """Build the sample tree used by the tests and link each node to its parent.

    Shape: root 1 has children 2 (left) and 3 (right); 2 has children 4
    (left) and 5 (right); 5 has a single right child -5.

    Returns:
        The root node. Every node has a ``par`` attribute pointing at its
        parent; the root's ``par`` is None.
    """
    T = Node(1,
             Node(2,
                  Node(4),
                  Node(5,
                       None,
                       Node(-5))),
             Node(3))
    # Mark the root explicitly so walking ``par`` upwards terminates cleanly.
    T.par = None
    stack = deque([T])
    while stack:
        cur = stack.pop()
        # Right before left, matching the push order used elsewhere.
        for child in (cur.R, cur.L):
            if child:
                child.par = cur
                stack.append(child)
    return T
def find_path_test(algo):
    """Return a zero-argument check of path-finder *algo* on ``get_tree()``.

    The returned function asserts two things. First, no path sums to 100.
    Second, exactly three paths sum to 3. For each reported ``(start, end)``
    pair, it re-adds the node values by following ``par`` links from *end*
    up to *start*.
    """
    def run():
        tree = get_tree()
        assert algo(tree, 100) == []
        found = algo(tree, 3)
        assert len(found) == 3
        for pair in found:
            node, top = pair[::-1]
            total = 0
            # Climb from the end node until reaching the start node.
            while node != top:
                total += node.data
                node = node.par
            total += top.data
            assert total == 3
    return run
# One generated check per implementation; the test_* names let a runner such
# as pytest pick them up.
test_1, test_2 = find_path_test(find_path), find_path_test(find_path_2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment