Last active
July 9, 2017 01:26
-
-
Save darkhipo/10574513a9c4456b4c53e91bbc9af389 to your computer and use it in GitHub Desktop.
Maximum Flow, Painfully Slow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import uuid | |
def fast_get_mfps_flow(mfps):
    """Return the s-t flow value recorded on a max-flow problem state.

    Looks up the source and terminal nodes by ``mfps.sourceNodeUid`` and
    ``mfps.terminalNodeUid`` among ``mfps.G.setOfNodes``. It then compares the
    source's ``datum.flowOut`` with the terminal's ``datum.flowIn``.

    Args:
        mfps: object exposing ``G.setOfNodes``, ``sourceNodeUid`` and
            ``terminalNodeUid``. Each node exposes ``uid`` and
            ``datum.flowIn`` / ``datum.flowOut``.

    Returns:
        The flow into the terminal node.

    Raises:
        KeyError: if no node carries the source or terminal uid.
        ValueError: ('Infeasible s-t flow') if more flow reaches the terminal
            than leaves the source. ValueError subclasses Exception, so
            existing ``except Exception`` handlers keep working.
    """
    def _node_with_uid(uid):
        # First node carrying ``uid``. An explicit KeyError is clearer than
        # the 'pop from an empty set' error a set-comprehension-and-pop gives.
        for node in mfps.G.setOfNodes:
            if node.uid == uid:
                return node
        raise KeyError('There is no Node {!s} in the problem graph'.format(uid))

    flow_from_s = _node_with_uid(mfps.sourceNodeUid).datum.flowOut
    flow_to_t = _node_with_uid(mfps.terminalNodeUid).datum.flowIn
    # Only the "terminal receives more than the source sends" direction is
    # treated as infeasible (unchanged from the original check).
    if flow_to_t - flow_from_s > 0.0:
        raise ValueError('Infeasible s-t flow')
    return flow_to_t
def fast_e_k_preprocess(G):
    """Build the uid-keyed lookup tables used by the fast Edmonds-Karp code.

    ``G`` is first passed through ``strip_flows`` (presumably resetting arc
    flows; confirm in ``strip_flows``). Every arc is then validated against
    ``G.setOfNodes`` and indexed by node uid.

    Args:
        G: graph exposing ``setOfNodes`` and ``setOfArcs``. Each arc exposes
            ``fromNode`` / ``toNode`` (nodes with a ``uid``) and
            ``datum.capacity`` / ``datum.flow``.

    Returns:
        dict with keys:

        - 'nodes': uid -> node, for every node that is an arc endpoint.
        - 'node_to_node_capacity' / 'node_to_node_flow': from_uid -> to_uid ->
          capacity / flow summed over all parallel arcs (0 when no arc).
        - 'arcs': from_uid -> to_uid -> arc_uid -> Arc copy. Its endpoints are
          uids, and its datum is a ``FlowArcDatumWithUid`` with a fresh
          ``uuid4``. Every endpoint uid has an entry, possibly empty.
        - 'residual_arcs': from_uid -> to_uid -> residual Arc holding a
          ``ResidualDatum`` with action 'push' (unused capacity) or 'pull'
          (existing flow). Amounts are rounded to ``TOL`` digits.

    Raises:
        KeyError: if an arc references a node missing from ``G.setOfNodes``.
    """
    G = strip_flows(G)
    get = {
        'nodes': {},
        'node_to_node_capacity': {},
        'node_to_node_flow': {},
        'arcs': {},
        'residual_arcs': {},
    }

    # Pass 1: validate endpoints, record nodes, index uid-keyed arc copies.
    for a in G.setOfArcs:
        from_uid, to_uid = a.fromNode.uid, a.toNode.uid
        if a.fromNode not in G.setOfNodes:
            raise KeyError('There is no Node {0!s} to match the Arc from {0!s} to {1!s}'.format(from_uid, to_uid))
        if a.toNode not in G.setOfNodes:
            raise KeyError('There is no Node {1!s} to match the Arc from {0!s} to {1!s}'.format(from_uid, to_uid))
        get['nodes'][from_uid] = a.fromNode
        get['nodes'][to_uid] = a.toNode
        lark = Arc(from_uid, to_uid, FlowArcDatumWithUid(a.datum.capacity, a.datum.flow, uuid.uuid4()))
        get['arcs'].setdefault(from_uid, {}).setdefault(to_uid, {})[lark.datum.uid] = lark

    # Pass 2: give sink-only nodes an (empty) outgoing-arc table, and group
    # capacities/flows per (from, to) pair. One pass replaces the O(V^2 * E)
    # per-pair rescans. Values are kept in set-iteration order and summed with
    # sum() below, so the totals match the original per-pair sums exactly.
    caps = {}
    flows = {}
    for a in G.setOfArcs:
        from_uid, to_uid = a.fromNode.uid, a.toNode.uid
        get['arcs'].setdefault(to_uid, {})
        key = (from_uid, to_uid)
        caps.setdefault(key, []).append(a.datum.capacity)
        flows.setdefault(key, []).append(a.datum.flow)

    nodes = get['nodes']
    # Initialise every per-node table up front. A 'pull' arc written into
    # residual_arcs[u] must not hit a KeyError just because u's table would
    # otherwise be created later in the loop.
    for n in nodes:
        get['residual_arcs'][n] = {}
        get['node_to_node_capacity'][n] = {}
        get['node_to_node_flow'][n] = {}

    for n in nodes:
        for u in nodes:
            cap_total = sum(caps.get((n, u), ()))
            flow_total = sum(flows.get((n, u), ()))
            if cap_total > flow_total:
                get['residual_arcs'][n][u] = Arc(n, u, ResidualDatum(round(cap_total - flow_total, TOL), 'push'))
            if flow_total > 0.0:
                # NOTE(review): for antiparallel arcs this can overwrite (or be
                # overwritten by) the 'push' arc stored for the pair (u, n).
                get['residual_arcs'][u][n] = Arc(u, n, ResidualDatum(round(flow_total, TOL), 'pull'))
            get['node_to_node_capacity'][n][u] = cap_total
            get['node_to_node_flow'][n][u] = flow_total
    return get
def fast_bfs(sid, tid, get):
    """Find a shortest augmenting path from ``sid`` to ``tid`` by BFS.

    Explores ``get['residual_arcs']`` (from_uid -> to_uid -> residual arc)
    breadth-first, so the returned path uses the fewest residual arcs.

    Args:
        sid: uid of the source node.
        tid: uid of the terminal node.
        get: lookup tables; only ``get['residual_arcs']`` is read.

    Returns:
        list of residual arcs ordered from ``sid`` to ``tid``. The list is
        empty when ``sid == tid``.

    Raises:
        StopIteration: if ``tid`` is unreachable. ``fast_edmonds_karp`` uses
            this as its termination signal.
        KeyError: if a dequeued node has no entry in ``get['residual_arcs']``.
    """
    residual = get['residual_arcs']
    parent_of = {}
    # A mutable set gives O(1) updates. Rebuilding a frozenset per discovery
    # made each search quadratic in the node count. The source is marked up
    # front so arcs leading back to it never re-enqueue it.
    visited = {sid}
    deq = coll.deque([sid])
    while deq:
        n = deq.popleft()
        if n == tid:
            break
        for u, arc in residual[n].items():
            if u not in visited:
                visited.add(u)
                parent_of[u] = arc
                deq.append(u)

    # Walk parent arcs back from the terminal, then reverse into s->t order.
    path = []
    curr = tid
    while curr != sid:
        if curr not in parent_of:
            err_msg = 'No augmenting path from {} to {}.'.format(sid, curr)
            raise StopIteration(err_msg)
        arc = parent_of[curr]
        path.append(arc)
        curr = arc.fromNode
    path.reverse()
    return path
def fast_edmonds_karp(mfps):
    """Solve a max-flow problem with Edmonds-Karp over precomputed tables.

    Builds lookup tables with ``fast_e_k_preprocess``. It then repeatedly
    finds a shortest augmenting path with ``fast_bfs`` and applies it with
    ``fast_augment``. It stops when ``fast_bfs`` raises ``StopIteration``,
    i.e. no augmenting path remains.

    Args:
        mfps: problem state exposing ``G``, ``sourceNodeUid``,
            ``terminalNodeUid`` and ``maxFlowProblemStateUid``.

    Returns:
        A new ``MaxFlowProblemState`` whose ``DiGraph`` is rebuilt from the
        final tables. It keeps the original source/terminal uids and
        problem-state uid.
    """
    sid, tid = mfps.sourceNodeUid, mfps.terminalNodeUid
    get = fast_e_k_preprocess(mfps.G)
    while True:
        # Guard only the search. A StopIteration escaping fast_augment is a
        # genuine error and must not be mistaken for "no more paths".
        try:
            apath = fast_bfs(sid, tid, get)
        except StopIteration:
            break
        get = fast_augment(apath, get)

    # NOTE(review): fast_e_k_preprocess fills get['nodes'] from arc endpoints
    # only. Nodes with no arcs are therefore absent from the rebuilt graph,
    # unless fast_augment adds them. Confirm that is intended.
    nodes = frozenset(get['nodes'].values())
    # One pass over the nested from -> to -> arc_uid -> arc tables. The
    # previous repeated frozenset unions were quadratic in the arc count.
    arcs = frozenset(
        arc
        for by_to_node in get['arcs'].values()
        for by_arc_uid in by_to_node.values()
        for arc in by_arc_uid.values()
    )
    G = DiGraph(nodes, arcs)
    return MaxFlowProblemState(G, sourceNodeUid=sid, terminalNodeUid=tid,
                               maxFlowProblemStateUid=mfps.maxFlowProblemStateUid)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment