Created
February 8, 2015 16:14
-
-
Save schocco/1d519e6554fbb0066c5a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# coding=utf-8 | |
import sys, getopt | |
def update_progress(progress, barLength=12, status=""):
    """Display a progress bar in the console and update it in place.

    Each call writes one line that starts with a carriage return, so repeated
    calls redraw the same console line.

    :param progress: completion fraction. Ints are converted to float.
        Values >= 1 are clamped to 1 and the status becomes "Done..." plus a
        newline. Negative values are clamped to 0.
    :param barLength: number of characters drawn between the brackets.
    :param status: text appended after the percentage.
    :raises ValueError: if progress is neither an int nor a float.
    """
    if isinstance(progress, int):
        progress = float(progress)
    if not isinstance(progress, float):
        raise ValueError("progress bar must be float\n")
    if progress >= 1:
        progress = 1
        status = "Done...\n"
    elif progress < 0:
        # Without clamping, a negative fraction yields a bar longer than
        # barLength and a negative percentage.
        progress = 0.0
    block = int(round(barLength * progress))
    bar = "#" * block + "-" * (barLength - block)
    text = "\rProgress: [{0}] {1:.1f}% {2}".format(bar, progress * 100, status)
    sys.stdout.write(text)
    sys.stdout.flush()
def fill_p(nodes):
    """Build the table of summed probabilities for every key range.

    Cell [i][j] with i <= j holds nodes[i] + ... + nodes[j]. Each cell is
    built from its own diagonal entry plus the already-filled cell
    [i + 1][j]. Cells below the diagonal stay 0. A 0% progress line is
    written before the table is built.
    """
    update_progress(0, status="pre-filling probabilities table...")
    size = len(nodes)
    table = [[0] * size for _ in range(size)]
    for pos, weight in enumerate(nodes):
        table[pos][pos] = weight
    # Walk the diagonals outward; diagonal `span` depends only on span - 1.
    for span in range(1, size):
        for start in range(size - span):
            end = start + span
            table[start][end] = table[start][start] + table[start + 1][end]
    return table
def obst(nodes, mode="qubic"):
    """Compute the weight and root tables of an optimal binary search tree.

    Uses Knuth's dynamic programming algorithm. For the key range i..j
    (0-based, inclusive) and a candidate root r:

        obst(i, j) = min over i <= r <= j of [obst(i, r-1) + obst(r+1, j)]
                     + sum_{k=i}^{j} p[k]

    An empty range (i > j) costs 0. The probability sum is read from the
    table built by fill_p. Trying every r is the O(n^3) form. Knuth showed
    it suffices to try roots[i][j-1] <= r <= roots[i+1][j], which lowers
    the total cost to O(n^2).

    Progress is reported on stdout through update_progress.

    :param nodes: sequence of per-key weights/probabilities.
    :param mode: "qubic" (the default) or its alias "cubic" tries every root.
        Any other value, such as "quadratic", uses Knuth's restricted range.
    :returns: (weights, roots), two n x n tables. weights[i][j] is the
        minimal cost from the recurrence above for keys i..j. roots[i][j] is
        the 0-based index of the root that achieves it. Cells with i > j
        are None.
    """
    n = len(nodes)
    roots = [[None] * n for _ in range(n)]
    weights = [[None] * n for _ in range(n)]
    probabilities = fill_p(nodes)
    exhaustive = mode in ("qubic", "cubic")

    # Single-key subtrees: the key is its own root.
    for i in range(n):
        roots[i][i] = i
        weights[i][i] = nodes[i]

    # Redraw the progress bar roughly every 0.1% of diagonals. The step must
    # be an integer: a float step (n / 1000) makes `i % step == 0` true
    # only by floating-point coincidence.
    step = max(1, n // 1000)
    for i in range(1, n):  # i = distance of the cell from the main diagonal
        if i % step == 0:
            update_progress(i / n)
        for j in range(n - i):
            end = j + i
            if exhaustive:
                candidates = range(j, end + 1)  # +1: the root range is inclusive
            else:
                candidates = range(roots[j][end - 1], roots[j + 1][end] + 1)
            best = None
            winning_root = None
            for r in candidates:
                left = weights[j][r - 1] if r > j else 0        # keys j..r-1
                right = weights[r + 1][end] if r < end else 0   # keys r+1..end
                cost = left + right
                # Strict "<" keeps the leftmost root among equal-cost choices.
                if best is None or cost < best:
                    best = cost
                    winning_root = r
            roots[j][end] = winning_root
            weights[j][end] = best + probabilities[j][end]
    update_progress(1.0)
    return weights, roots
def read_nodes_from_file(path):
    """Read comma-separated integers from a text file.

    Whitespace around each field is ignored, including a trailing newline.
    Fields that are empty after stripping are skipped, such as the one left
    by a trailing comma.

    :param path: path of the file to read.
    :returns: tuple of the parsed integers, in file order.
    :raises OSError: e.g. FileNotFoundError, if the file cannot be opened.
    :raises ValueError: if a field is not an integer, or if the file
        contains no numbers at all.
    """
    # `with` guarantees the handle is closed, even if reading fails.
    with open(path, "r") as handle:
        content = handle.read()
    nums = tuple(int(field) for field in content.split(",") if field.strip())
    if not nums:
        raise ValueError("no numbers found in %r" % (path,))
    return nums
def main(argv=None):
    """Command-line entry point: read key weights from a file and print the OBST.

    Usage: obst.py -i <inputfile>

    Prints the full weight table, one tab-separated row per line, with None
    in the unused cells below the diagonal. Then prints the root of the
    optimal tree (1-based) and its weighted inner path length. Exits with
    status 2 on bad options or an unreadable or invalid input file.

    :param argv: command-line arguments without the program name. None means
        no arguments.
    """
    if argv is None:
        argv = []
    inputfile = ""
    helptext = 'obst.py -i <inputfile>'
    try:
        opts, _ = getopt.getopt(argv, "hi:", ["ifile="])
    except getopt.GetoptError:
        print(helptext)
        sys.exit(2)
    for opt, arg in opts:
        if opt == '-h':
            print(helptext)
            sys.exit()
        elif opt in ("-i", "--ifile"):
            inputfile = arg
    try:
        nodes = read_nodes_from_file(path=inputfile)
    except (OSError, ValueError) as err:
        # OSError covers FileNotFoundError plus directories and permission
        # problems. ValueError covers non-integer or empty input.
        print(err)
        print(helptext)
        sys.exit(2)
    print("calculating optimal binary search tree with n=%d" % len(nodes))
    weights, roots = obst(nodes, mode="quadratic")
    # The cell [0][n-1] describes the tree over all keys.
    root = roots[0][-1]
    weight = weights[0][-1]
    for row in weights:
        print("".join(str(cell) + "\t" for cell in row))
    print("\nRoot node is {0} (p({0})={1}, p_max = {2}), weighted inner path length is {3}".format(root+1, nodes[root], max(nodes), weight))
if __name__ == '__main__':
    # With no arguments, sys.argv[1:] is an empty list, and main([]) behaves
    # exactly like main() with its default argv.
    main(sys.argv[1:])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment