-
-
Save kergoth/63f3eb7f000a4c4d2da105a03f3f1381 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# TODO: prototype node execution memoization / acceleration via tracking
# function inputs and outputs.
from collections import Counter | |
from concurrent import futures | |
class NodeState(object):
    """Integer codes describing where a node is in its execution lifecycle."""

    blocked = 0   # not yet eligible to run
    runnable = 1  # eligible to be executed
    complete = 2  # executed without raising
    failed = 3    # execution raised an exception
class Event(object):
    """A minimal multicast callback list.

    Handlers are plain callables kept in registration order.  Firing the
    event (via ``fire()`` or by calling the event object) invokes each of
    them with the supplied arguments.  ``event += handler`` and
    ``event -= handler`` are shorthand for ``register`` / ``unregister``.
    """

    def __init__(self):
        self.handlers = []

    def register(self, handler):
        """Append ``handler`` to the callables invoked when the event fires."""
        self.handlers.append(handler)

    def unregister(self, handler):
        """Remove the first registration of ``handler``.

        Raises ``ValueError`` if ``handler`` is not registered.
        """
        self.handlers.remove(handler)

    def fire(self, *args, **kwargs):
        """Invoke every registered handler with ``*args`` / ``**kwargs``.

        Dispatch walks a snapshot of the handler list, so a handler may
        register or unregister handlers (including itself) while the event
        is firing without causing other handlers to be skipped.  Handlers
        added during dispatch are first called on the next ``fire``; handlers
        removed during dispatch may still be called in the current round.
        An exception raised by a handler propagates and stops dispatch.
        """
        for handler in list(self.handlers):
            handler(*args, **kwargs)

    def __len__(self):
        # Note: because of this, an Event with no handlers is falsy.
        return len(self.handlers)

    def __contains__(self, handler):
        return handler in self.handlers

    def __iadd__(self, other):
        self.register(other)
        # Must return self so ``obj.event += h`` rebinds to the same Event.
        return self

    def __isub__(self, other):
        self.unregister(other)
        return self

    __call__ = fire
class Graph(dict):
    """Dependency graph stored as a mapping of node -> set of child nodes.

    ``parents`` holds the reverse edges: for each known node, the set of
    nodes that list it as a child.  Nodes must be hashable.
    """

    def __init__(self):
        self.parents = {}
        dict.__init__(self)

    def execute(self, node):
        """Run ``node``.  This implementation only prints its name."""
        # Wrap in a 1-tuple so tuple-valued nodes are formatted rather than
        # being unpacked as multiple format arguments.
        print("Executing %s" % (node,))

    def current_state(self, node):
        """Return ``runnable`` for a node without children, else ``blocked``."""
        if not self[node]:
            return NodeState.runnable
        return NodeState.blocked

    def add_node(self, node, children=None):
        """Add (or redefine) ``node`` with the given iterable of children.

        The children are copied into a fresh set, so the caller's container
        is not aliased and ``add_child`` can extend it later.  Reverse edges
        already recorded for ``node`` (because a parent was added first) are
        preserved.
        """
        # When redefining a node, drop the reverse edges of its old children.
        for old_child in self.get(node) or ():
            old_parents = self.parents.get(old_child)
            if old_parents is not None:
                old_parents.discard(node)

        child_set = set(children) if children else set()
        self[node] = child_set
        # setdefault: do not wipe parents registered before this node was added.
        self.parents.setdefault(node, set())
        for child in child_set:
            self.parents.setdefault(child, set()).add(node)

    def remove_node(self, node):
        """Remove ``node`` and every edge that touches it.

        Raises ``KeyError`` if ``node`` is not in the graph.
        """
        del self[node]
        # Detach the node from the child collections of its former parents.
        for parent in self.parents.pop(node, ()):
            siblings = self.get(parent)
            if siblings and node in siblings:
                siblings.remove(node)
        # Detach the node from the parent sets of its former children.
        for parent_set in self.parents.values():
            parent_set.discard(node)

    def add_child(self, node, child):
        """Record ``child`` as a child of the existing node ``node``.

        Raises ``KeyError`` if ``node`` is not in the graph.
        """
        children = self[node]
        if not isinstance(children, set):
            # Entry was assigned directly (None or another container).
            children = set(children or ())
            self[node] = children
        children.add(child)
        self.parents.setdefault(child, set()).add(node)

    def get_children(self, node):
        """Return the stored children of ``node`` (``KeyError`` if unknown)."""
        return self[node]

    def get_parents(self, node):
        """Return the set of parents of ``node``, or ``None`` if unknown."""
        return self.parents.get(node)
class States(dict):
    """Mapping of node -> state that announces changes through an event.

    Every key starts out as ``NodeState.blocked``.  Assigning a different
    value through ``states[key] = value`` fires ``value_changed`` with
    ``(key, old_value, new_value)``.  Only item assignment is observed;
    other ``dict`` mutators such as ``update`` are not intercepted.
    """

    def __init__(self, keys):
        self.value_changed = Event()
        # dict.__init__ does not go through __setitem__, so no events fire
        # for the initial population.
        dict.__init__(self, ((k, NodeState.blocked) for k in keys))

    def __setitem__(self, key, value):
        """Store ``value`` and, if it differs from the old one, notify.

        The value is stored before handlers run, so a handler that reads
        the mapping sees the new state and a handler that re-assigns the
        same key is not overwritten afterwards.  A missing key is treated
        as having the old value ``None``.
        """
        old_value = self.get(key)
        dict.__setitem__(self, key, value)
        if old_value != value:
            self.value_changed(key, old_value, value)
class Traverse(object):
    """Execute the nodes of a graph on a thread pool, children first.

    Attributes:
        graph: the graph being traversed.
        state: ``States`` mapping of node -> ``NodeState`` value.
        targeted: stored as given; not consulted by this class.
        results: node -> return value of ``graph.execute`` for nodes that
            completed.
        errors: node -> exception raised by ``graph.execute`` for nodes
            that failed.
        state_changed: the ``value_changed`` event of ``state``.
    """

    def __init__(self, graph, targeted=None):
        self.graph = graph
        self.state = States(graph)
        self.targeted = targeted
        self.results = {}
        self.errors = {}
        self.state_changed = self.state.value_changed
        # Debugging aid:
        # self.state_changed += lambda node, f, t: print('state changed', node, f, t)

    def prepare_states(self):
        """Set every node's state to what the graph currently reports."""
        for node in self.graph:
            self.state[node] = self.graph.current_state(node)

    def start(self, max_workers=5):
        """Initiate graph traversal and block until no work is left.

        Nodes are submitted to a thread pool whenever their state changes
        to ``runnable``.  When a node completes, each blocked parent whose
        children are all complete becomes runnable.  A node whose execution
        raises is marked ``failed`` and its exception is kept in
        ``self.errors``; its parents are never unblocked.

        Args:
            max_workers: size of the thread pool (default 5).
        """
        with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_node = {}

            def should_queue_node(node, f, t):
                # Only a transition *to* runnable schedules work.
                if t != NodeState.runnable:
                    return
                future = executor.submit(self.graph.execute, node)
                future_to_node[future] = node

            self.state_changed += should_queue_node
            try:
                self.prepare_states()
                while future_to_node:
                    done, _ = futures.wait(
                        future_to_node,
                        return_when=futures.FIRST_COMPLETED)
                    for future in done:
                        node = future_to_node.pop(future)
                        self._record_outcome(node, future)
            finally:
                # Always detach the handler: it closes over this executor,
                # which is shut down when the ``with`` block exits.
                self.state_changed -= should_queue_node

    def _record_outcome(self, node, future):
        """Store the result or error of a finished node and update states."""
        error = future.exception()
        if error is not None:
            self.errors[node] = error
            self.state[node] = NodeState.failed
            return
        self.errors.pop(node, None)
        self.state[node] = NodeState.complete
        self.results[node] = future.result()
        self._unblock_parents(node)

    def _unblock_parents(self, node):
        """Mark blocked parents of ``node`` runnable once all children are done."""
        for parent in self.graph.get_parents(node) or ():
            if self.state[parent] != NodeState.blocked:
                continue
            children = self.graph.get_children(parent)
            if all(self.state[c] == NodeState.complete for c in children):
                # This assignment fires state_changed, which queues the parent.
                self.state[parent] = NodeState.runnable
if __name__ == '__main__':
    # Demo graph: 'a' depends on 'b' and 'c'; 'c' depends on 'd'.
    demo_nodes = (
        ('b', None),
        ('d', None),
        ('c', ['d']),
        ('a', ['b', 'c']),
    )
    demo_graph = Graph()
    for name, deps in demo_nodes:
        demo_graph.add_node(name, deps)
    Traverse(demo_graph).start()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment