# Source: GitHub gist kergoth/63f3eb7f000a4c4d2da105a03f3f1381
# Author: @kergoth -- created February 17, 2020 04:16
# TODO: prototype node execution memoization / acceleration via tracking
# function inputs and outputs.
from collections import Counter
from concurrent import futures
class NodeState(object):
    """Integer codes describing where a node is in its lifecycle.

    The values are plain ints (0-3), so they can be compared and stored
    directly.
    """

    blocked = 0   # not ready to run yet
    runnable = 1  # ready to be executed
    complete = 2  # executed without raising
    failed = 3    # execution raised an exception
class Event(object):
    """A minimal multicast callback list.

    Handlers are plain callables. They are invoked in registration order
    with whatever arguments are passed to ``fire`` (or to the event itself,
    since calling the instance is an alias for ``fire``).

    ``event += handler`` and ``event -= handler`` are shorthand for
    ``register`` and ``unregister``.
    """

    def __init__(self):
        self.handlers = []

    def register(self, handler):
        """Append *handler* to the list of callables to invoke."""
        self.handlers.append(handler)

    def unregister(self, handler):
        """Remove the first occurrence of *handler*.

        Raises:
            ValueError: if *handler* is not currently registered.
        """
        self.handlers.remove(handler)

    def fire(self, *args, **kwargs):
        """Call every registered handler with the given arguments.

        Dispatch runs over a snapshot of the handler list, so a handler may
        register or unregister handlers (including itself) while the event
        is firing without causing other handlers to be skipped. Changes made
        during dispatch take effect on the next ``fire``.
        """
        # Iterating the live list while a handler removes itself would shift
        # the remaining entries and silently skip the next handler.
        for handler in list(self.handlers):
            handler(*args, **kwargs)

    def __len__(self):
        return len(self.handlers)

    def __contains__(self, handler):
        return handler in self.handlers

    def __iadd__(self, other):
        self.register(other)
        return self

    def __isub__(self, other):
        self.unregister(other)
        return self

    # Calling the event object fires it.
    __call__ = fire
class Graph(dict):
    """Dependency graph: a dict mapping each node to the set of its children.

    A reverse index, ``self.parents``, maps every node (and every node that
    has been named as a child) to the set of nodes that list it as a child.
    """

    def __init__(self):
        self.parents = {}
        dict.__init__(self)

    def execute(self, node):
        """Perform the work for *node*; this implementation only prints it."""
        print("Executing %s" % node)

    def current_state(self, node):
        """Return ``NodeState.runnable`` for a node without children,
        otherwise ``NodeState.blocked``.

        Raises:
            KeyError: if *node* was never added.
        """
        if not self[node]:
            return NodeState.runnable
        else:
            return NodeState.blocked

    def add_node(self, node, children=None):
        """Register *node* with an optional iterable of *children*.

        Children are stored as a set so that ``add_child`` works on every
        node, including those created without children. Re-adding an
        existing node replaces its children.
        """
        # When replacing an existing node, drop the reverse links of the
        # children it is about to lose.
        for stale in self.get(node) or ():
            if stale in self.parents:
                self.parents[stale].discard(node)

        self[node] = set(children) if children else set()
        # Keep any parents recorded earlier: a node may be named as a child
        # before it is registered itself.
        self.parents.setdefault(node, set())
        for child in self[node]:
            self.parents.setdefault(child, set()).add(node)

    def remove_node(self, node):
        """Remove *node* and every edge that touches it.

        Raises:
            KeyError: if *node* was never added.
        """
        del self[node]
        # Detach the node from the child sets of its former parents.
        for parent in self.parents.pop(node, ()):
            if parent in self:
                self[parent].discard(node)
        # Detach the node from the parent sets of its former children.
        for linked in self.parents.values():
            linked.discard(node)

    def add_child(self, node, child):
        """Add *child* as a dependency of the already registered *node*.

        Raises:
            KeyError: if *node* was never added.
        """
        self[node].add(child)
        self.parents.setdefault(child, set()).add(node)

    def get_children(self, node):
        """Return the set of children of *node* (``KeyError`` if unknown)."""
        return self[node]

    def get_parents(self, node):
        """Return the set of parents of *node*, or ``None`` if unknown."""
        return self.parents.get(node)
class States(dict):
    """Mapping of key -> state that announces changes through an event.

    ``value_changed`` is called as ``value_changed(key, old, new)`` whenever
    an item assignment stores a value different from the current one. The
    call happens *before* the new value is written, and ``old`` is ``None``
    for a key that is not present yet.
    """

    def __init__(self, keys):
        self.value_changed = Event()
        # Every key starts out blocked.
        dict.__init__(self, dict.fromkeys(keys, NodeState.blocked))

    def __setitem__(self, key, value):
        previous = self.get(key)
        changed = previous != value
        if changed:
            # Listeners run while the mapping still holds the previous value.
            self.value_changed(key, previous, value)
        dict.__setitem__(self, key, value)
class Traverse(object):
    """Execute the nodes of a graph on a thread pool, children first.

    Attributes:
        graph: the graph being traversed; it must provide ``execute``,
            ``current_state``, ``get_parents`` and ``get_children``.
        state: a ``States`` mapping of node -> ``NodeState`` code.
        targeted: stored as given; not used by the code in this class.
        results: node -> return value of ``graph.execute`` for every node
            that completed without raising.
        state_changed: the event fired by ``state`` on every state change.
    """

    def __init__(self, graph, targeted=None):
        self.graph = graph
        self.state = States(graph)
        self.targeted = targeted
        self.results = {}
        self.state_changed = self.state.value_changed
        # Debug aid:
        # self.state_changed += lambda node, f, t: print('state changed', node, f, t)

    def prepare_states(self):
        """Set every node to the state the graph reports for it."""
        for node in self.graph:
            self.state[node] = self.graph.current_state(node)

    def start(self, max_workers=5):
        """Initiate graph traversal and block until no work is left.

        Nodes are submitted to a thread pool as soon as they become
        runnable. A node whose execution raises is marked failed; its
        parents are then never released and stay blocked.

        Args:
            max_workers: size of the thread pool (defaults to 5).
        """
        with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_node = {}

            def should_queue_node(node, f, t):
                # Only a transition *to* runnable schedules work.
                if t != NodeState.runnable:
                    return
                future = executor.submit(self.graph.execute, node)
                future_to_node[future] = node

            self.state_changed += should_queue_node
            try:
                self.prepare_states()

                while future_to_node:
                    done, _ = futures.wait(
                        future_to_node,
                        return_when=futures.FIRST_COMPLETED)

                    for future in done:
                        node = future_to_node.pop(future)
                        if future.exception() is not None:
                            self.state[node] = NodeState.failed
                            continue

                        self.state[node] = NodeState.complete
                        self.results[node] = future.result()
                        self._release_parents(node)
            finally:
                # Always detach the handler: it closes over this executor,
                # which is shut down once the ``with`` block exits.
                self.state_changed -= should_queue_node

    def _release_parents(self, node):
        """Mark blocked parents of *node* runnable once all their children
        are complete."""
        for parent in self.graph.get_parents(node) or ():
            if self.state[parent] != NodeState.blocked:
                continue
            children = self.graph.get_children(parent)
            if all(self.state[c] == NodeState.complete for c in children):
                self.state[parent] = NodeState.runnable
if __name__ == '__main__':
    # Demo: 'a' depends on 'b' and 'c', and 'c' depends on 'd'.
    # Nodes are registered in the same order as before: b, d, c, a.
    demo_nodes = (
        ('b', None),
        ('d', None),
        ('c', ['d']),
        ('a', ['b', 'c']),
    )
    demo_graph = Graph()
    for name, deps in demo_nodes:
        demo_graph.add_node(name, deps)

    Traverse(demo_graph).start()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.