Last active: January 23, 2020 01:11
-
-
Save SatishGodaPearl/ac316535c3c3575799c18f8d4011e303 to your computer and use it in GitHub Desktop.
Node Graph Architecture
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging

# Output format shared by every message from this module: "LEVEL: message".
formatter = logging.Formatter('%(levelname)s: %(message)s')
handler = logging.StreamHandler()
handler.setFormatter(formatter)

# getLogger() with no name returns the root logger, so this handler and the
# INFO threshold apply process-wide.
log = logging.getLogger()
log.addHandler(handler)
log.setLevel(logging.INFO)
class Number(object):
    """
    Minimal mutable container for one parameter value.

    Nodes in this file use it for their inputs and output; the value is
    read with ``getValue`` and replaced with ``setValue``.
    """

    def __init__(self, default=0):
        # Initial stored value.
        self._value = default

    def getValue(self):
        """Return the stored value."""
        current = self._value
        return current

    def setValue(self, value):
        """Overwrite the stored value with ``value``."""
        self._value = value

    def __repr__(self):
        kind = type(self).__name__
        return "<{0}({1})>".format(kind, self._value)
class NodeFactoryError(ValueError):
    """Raised when a graph is asked to build a node type that was never registered."""
class Graph(object):
    """
    A node graph: registered node types plus the nodes and connections of one graph.

    ``nodes`` maps each node to the list of downstream nodes it feeds,
    ``plugs`` maps each OutputPlug to the InputPlugs it drives, and
    ``connections`` records every Connection in creation order.
    """

    # Shared by every Graph (and subclass): node type name -> node class.
    node_factory = {}

    @classmethod
    def register_node(cls, name, node_cls):
        """Make ``node_cls`` constructible through ``add_node(name, ...)``."""
        cls.node_factory[name] = node_cls

    def __init__(self):
        # When a connection into this node executes, the node is re-executed
        # and its output value is logged (see execute_connection).
        self.activeNode = None
        self.nodes = {}
        self.plugs = {}
        self.connections = []

    def add_node(self, node_type, values=()):
        """
        Instantiate a registered node type and add it to the graph.

        :param node_type: name previously passed to ``register_node``
        :param values: positional arguments for the node constructor
            (defaults to none, i.e. the node's own defaults)
        :return: the new node
        :raises NodeFactoryError: if ``node_type`` was never registered
        """
        node_cls = self.node_factory.get(node_type)
        if node_cls is None:
            msg = "Invalid node type: {0}".format(node_type)
            log.debug(msg)
            raise NodeFactoryError(msg)
        node = node_cls(*values)
        self.nodes.setdefault(node, [])
        return node

    def add_connection(self, unode, dnode, param):
        """
        Connect the output of ``unode`` to parameter ``param`` of ``dnode``.

        ``unode`` must already have been added with ``add_node``; otherwise
        the lookup in ``self.nodes`` raises KeyError.

        :return: the new Connection
        """
        connection = Connection()
        oplug, iplug = connection.create_from_nodes(unode, dnode, param)
        inputs = self.plugs.setdefault(oplug, [])
        # Connection reuses one plug object per (node, param), so connecting
        # the same pair twice yields the same iplug; don't record it twice.
        if iplug not in inputs:
            inputs.append(iplug)
        to_nodes = self.nodes[oplug.node]
        if iplug.node not in to_nodes:
            to_nodes.append(iplug.node)
        self.connections.append(connection)
        return connection

    def execute_connection(self, connection):
        """
        Push the upstream value through ``connection``.

        If the downstream node is this graph's ``activeNode``, it is then
        executed and its output value is logged at INFO level.

        :param connection: Connection
        :return: None
        """
        log.debug("Executing {0}".format(connection))
        connection.execute()
        # Use this graph's own active node (previously this read a module
        # global ``g``, which broke every graph not bound to that name).
        active = self.activeNode
        if active is not None and active == connection.iplug.node:
            active.execute()
            log.info(active.output.getValue())
        log.debug("Executed {0}\n".format(connection))
class Node(object):
    """
    A node in the NodeGraph.

    Each instance gets a readable name ``<ClassName><N>``, where ``N`` counts
    the instances previously created of that exact class.
    """

    def __init__(self):
        cls = self.__class__
        # Read the counter from the concrete class's own namespace so that
        # Node itself, or a subclass that does not declare ``instance``,
        # starts at 0. Reading ``cls.instance`` would raise AttributeError
        # or pick up a parent class's count.
        count = cls.__dict__.get('instance', 0)
        self._name = "{0}{1}".format(cls.__name__, count)
        cls.instance = count + 1

    @property
    def name(self):
        """The auto-generated instance name, e.g. ``Add0``."""
        return self._name
class Add(Node):
    """
    Node that sums its ``number1`` and ``number2`` inputs into ``output``.

    Sums are memoised in a class-level dict keyed by the ``(number1, number2)``
    pair, so all ``Add`` instances share one cache.
    """

    instance = 0
    _cache = {}

    def __init__(self, number1=0, number2=0):
        super(Add, self).__init__()
        # Parameter holders; plugs address them by attribute name.
        self.number1, self.number2, self.output = Number(), Number(), Number()
        self.number1.setValue(number1)
        self.number2.setValue(number2)

    @classmethod
    def cache(cls, inputs):
        """Return the memoised result for ``inputs``, or ``None`` on a miss."""
        try:
            return cls._cache[inputs]
        except KeyError:
            log.debug("Cache miss: input {0}".format(inputs))
            return None

    @classmethod
    def function(cls, number1, number2):
        """The pure computation performed by this node."""
        total = number1 + number2
        return total

    def execute(self):
        """Compute (or fetch from the cache) the sum and store it in ``output``."""
        key = (self.number1.getValue(), self.number2.getValue())
        result = self.cache(key)
        if result is None:
            result = self.function(*key)
            self._cache[key] = result
        self.output.setValue(result)

    def __repr__(self):
        lhs = self.number1.getValue()
        rhs = self.number2.getValue()
        out = self.output.getValue()
        return "<{0}({1}, {2}) -> {3}>".format(type(self).__name__, lhs, rhs, out)
# Register Add under the type name 'Add' so Graph.add_node('Add', values) can build it.
Graph.register_node('Add', Add)
class Plug(object):
    """
    A Plug is something that can be connected to/from.

    A Plug pairs a node with the name of one of that node's parameters.
    """

    __slots__ = ('_node', '_param')

    def __init__(self):
        self._node = None
        self._param = None

    @property
    def node(self):
        """The node this plug is attached to (``None`` until set)."""
        return self._node

    def setNodeParam(self, node, param):
        """
        Attach this plug to ``node`` and its parameter named ``param``.

        A Connection object will be calling this method.
        """
        self._node = node
        self._param = param

    def __repr__(self):
        # Close the parenthesis so the output is balanced, e.g.
        # "<OutputPlug(<node>, output)>".
        return "<{0}({1}, {2})>".format(self.__class__.__name__,
                                        self._node,
                                        self._param)
class InputPlug(Plug):
    """
    A Plug that is connected from an OutputPlug (the downstream end).
    """

    # An empty __slots__ keeps Plug's slots effective; without it every
    # InputPlug instance would still get a per-instance __dict__.
    __slots__ = ()

    def setValue(self, value):
        """
        Store ``value`` in the node parameter this plug wraps.

        Called when the connection this plug belongs to is executed. The
        parameter is looked up on the node by name and must provide
        ``setValue``.
        """
        attr = getattr(self._node, self._param)
        attr.setValue(value)
class OutputPlug(Plug):
    """
    A Plug that connects to another node's Input Plug (the upstream end).
    """

    # An empty __slots__ keeps Plug's slots effective; without it every
    # OutputPlug instance would still get a per-instance __dict__.
    __slots__ = ()

    def getValue(self):
        """
        Execute the wrapped node and return its output value.

        Called when the connection this plug belongs to is executed. The node
        must provide ``execute()`` and an ``output`` object with ``getValue()``.
        """
        self._node.execute()
        return self._node.output.getValue()
class Connection(object):
    """
    A connection between the Output Plug of a node and the
    Input Plug of another node.
    """

    # Class-level cache of plugs keyed by (node, param): every Connection
    # that touches the same node parameter reuses the same plug object.
    # Because it lives on the class, it keeps strong references to those
    # nodes for as long as the class exists.
    _cache = {
        'output': {},
        'input': {},
    }

    __slots__ = ('_oplug', '_iplug')

    def __init__(self):
        self._oplug = None
        self._iplug = None

    @property
    def oplug(self):
        """The upstream OutputPlug (``None`` until set)."""
        return self._oplug

    @property
    def iplug(self):
        """The downstream InputPlug (``None`` until set)."""
        return self._iplug

    def create_from_plugs(self, oplug, iplug):
        """
        :param oplug: OutputPlug instance
        :param iplug: InputPlug instance
        :return: None
        """
        self._oplug = oplug
        self._iplug = iplug

    def create_from_nodes(self, upstream_node, downstream_node, downstream_node_param):
        """
        Build (or reuse cached) plugs for the two nodes and wire them up.

        :param upstream_node: Node whose ``output`` feeds the connection
        :param downstream_node: Node receiving the value
        :param downstream_node_param: str, attribute name on ``downstream_node``
        :return: tuple (OutputPlug, InputPlug)
        """
        self.setOutputPlug(upstream_node)
        self.setInputPlug(downstream_node, downstream_node_param)
        return self._oplug, self._iplug

    @classmethod
    def _cached_plug(cls, kind, plug_cls, node, param):
        """Return the shared plug for ``(node, param)``, creating it on a cache miss."""
        cache_plugs = cls._cache[kind]
        args = (node, param)
        plug = cache_plugs.get(args)
        if plug is None:
            log.debug("Connection cache miss {0}".format(args))
            plug = plug_cls()
            plug.setNodeParam(node, param)
            cache_plugs[args] = plug
            log.debug("Added {0} to connection cache".format(plug))
        return plug

    def setOutputPlug(self, node):
        """Use the (cached) OutputPlug for ``node.output`` as the upstream end."""
        self._oplug = self._cached_plug('output', OutputPlug, node, 'output')

    def setInputPlug(self, node, param):
        """Use the (cached) InputPlug for ``node.<param>`` as the downstream end."""
        self._iplug = self._cached_plug('input', InputPlug, node, param)

    def execute(self):
        """
        Evaluate the upstream plug and then set the value of
        the downstream plug.
        """
        value = self._oplug.getValue()
        self._iplug.setValue(value)

    def __repr__(self):
        return "<{0}({1} -> {2})>".format(self.__class__.__name__,
                                          self._oplug,
                                          self._iplug)
if __name__ == "__main__":
    # Demo: Add0's output feeds both inputs of Add1 (the active node). The
    # guard keeps importing this module free of side effects; ``g`` is still
    # a module-level name when run as a script.
    log.setLevel(logging.DEBUG)
    g = Graph()
    Add0 = g.add_node('Add', (100, 200))
    Add1 = g.add_node('Add', (400, 500))
    g.activeNode = Add1
    connection1 = g.add_connection(Add0, Add1, 'number1')
    g.execute_connection(connection1)
    connection2 = g.add_connection(Add0, Add1, 'number2')
    g.execute_connection(connection2)
    # Change the upstream inputs and propagate through both connections.
    Add0.number1.setValue(50)
    Add0.number2.setValue(-50)
    g.execute_connection(connection1)
    g.execute_connection(connection2)
    # Repeating the propagation hits Add's result cache: the sample output
    # below shows no further "Cache miss" lines for these calls.
    g.execute_connection(connection1)
    g.execute_connection(connection2)

# Sample log output recorded from a run of the demo above at DEBUG level.
"""
DEBUG: Connection cache miss (<Add(100, 200) -> 0>, 'output')
DEBUG: Added <OutputPlug(<Add(100, 200) -> 0>, output> to connection cache
DEBUG: Connection cache miss (<Add(400, 500) -> 0>, 'number1')
DEBUG: Added <InputPlug(<Add(400, 500) -> 0>, number1> to connection cache
DEBUG: Executing <Connection(<OutputPlug(<Add(100, 200) -> 0>, output> -> <InputPlug(<Add(400, 500) -> 0>, number1>)>
DEBUG: Cache miss: input (100, 200)
DEBUG: Cache miss: input (300, 500)
INFO: 800
DEBUG: Executed <Connection(<OutputPlug(<Add(100, 200) -> 300>, output> -> <InputPlug(<Add(300, 500) -> 800>, number1>)>
DEBUG: Connection cache miss (<Add(300, 500) -> 800>, 'number2')
DEBUG: Added <InputPlug(<Add(300, 500) -> 800>, number2> to connection cache
DEBUG: Executing <Connection(<OutputPlug(<Add(100, 200) -> 300>, output> -> <InputPlug(<Add(300, 500) -> 800>, number2>)>
DEBUG: Cache miss: input (300, 300)
INFO: 600
DEBUG: Executed <Connection(<OutputPlug(<Add(100, 200) -> 300>, output> -> <InputPlug(<Add(300, 300) -> 600>, number2>)>
DEBUG: Executing <Connection(<OutputPlug(<Add(50, -50) -> 300>, output> -> <InputPlug(<Add(300, 300) -> 600>, number1>)>
DEBUG: Cache miss: input (50, -50)
DEBUG: Cache miss: input (0, 300)
INFO: 300
DEBUG: Executed <Connection(<OutputPlug(<Add(50, -50) -> 0>, output> -> <InputPlug(<Add(0, 300) -> 300>, number1>)>
DEBUG: Executing <Connection(<OutputPlug(<Add(50, -50) -> 0>, output> -> <InputPlug(<Add(0, 300) -> 300>, number2>)>
DEBUG: Cache miss: input (0, 0)
INFO: 0
DEBUG: Executed <Connection(<OutputPlug(<Add(50, -50) -> 0>, output> -> <InputPlug(<Add(0, 0) -> 0>, number2>)>
DEBUG: Executing <Connection(<OutputPlug(<Add(50, -50) -> 0>, output> -> <InputPlug(<Add(0, 0) -> 0>, number1>)>
INFO: 0
DEBUG: Executed <Connection(<OutputPlug(<Add(50, -50) -> 0>, output> -> <InputPlug(<Add(0, 0) -> 0>, number1>)>
DEBUG: Executing <Connection(<OutputPlug(<Add(50, -50) -> 0>, output> -> <InputPlug(<Add(0, 0) -> 0>, number2>)>
INFO: 0
DEBUG: Executed <Connection(<OutputPlug(<Add(50, -50) -> 0>, output> -> <InputPlug(<Add(0, 0) -> 0>, number2>)>
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.