Created
May 9, 2014 00:09
-
-
Save rkuykendall/f35e7a37eb1a502257e7 to your computer and use it in GitHub Desktop.
Undirected Graph: Many-to-Many, Reflexive, Non-Directional Relationships in SQLAlchemy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""An undirected graph example, modified from the SQLAlchemy directed
graph example."""
from sqlalchemy import create_engine
from sqlalchemy import MetaData, Table, Column, Integer, ForeignKey
from sqlalchemy.orm import sessionmaker, relationship

try:
    # Preferred location since SQLAlchemy 1.4; the ext.declarative
    # variant is deprecated there.
    from sqlalchemy.orm import declarative_base
except ImportError:
    # Older SQLAlchemy releases only provide it here.
    from sqlalchemy.ext.declarative import declarative_base

try:
    # Not used by this example; the legacy mapper() function is no longer
    # importable on SQLAlchemy 2.0, so tolerate its absence.
    from sqlalchemy.orm import mapper
except ImportError:
    mapper = None

# 'sqlite://' with no path selects an in-memory SQLite database.
engine = create_engine('sqlite://', echo=False)
session = sessionmaker(engine)()
Base = declarative_base()
class Node(Base):
    """A vertex of an undirected graph.

    Connections live in ``Edge`` rows, where each edge names one node as
    ``lower_node`` and the other as ``higher_node``. Because the graph is
    undirected, a node's neighbours are found on both sides: via the
    ``higher_edges`` backref (edges where this node is the higher end) and
    via the ``lower_edges`` backref (edges where it is the lower end).
    """

    __tablename__ = 'node'

    id = Column(Integer, primary_key=True)

    # No custom __init__: the declarative base's default constructor
    # already supports ``Node()`` and also accepts column keyword
    # arguments such as ``Node(id=5)``.

    def add_adjacencies(self, *nodes):
        """Connect this node to each of *nodes*; return ``self``.

        Nodes that are already adjacent, including repeats within *nodes*,
        are skipped. This way no second ``Edge`` with the same endpoints,
        and therefore the same composite primary key, is created.
        """
        for node in nodes:
            self.add_adjacent(node)
        return self

    def add_adjacent(self, node):
        """Connect this node to *node* unless they are already adjacent.

        Returns ``self`` so calls can be chained, matching
        ``add_adjacencies``.
        """
        if node not in self.adjacencies():
            # Building the Edge is sufficient: assigning its relationship
            # attributes fires the backrefs, which append the edge to the
            # endpoints' ``lower_edges`` / ``higher_edges`` collections.
            Edge(self, node)
        return self

    def adjacencies(self):
        """Return a list of the nodes that share an edge with this node.

        A self-loop edge sits in both backref collections of the same
        node; it is reported only once.
        """
        neighbours = [edge.lower_node for edge in self.higher_edges]
        neighbours.extend(
            edge.higher_node
            for edge in self.lower_edges
            if edge.higher_node is not edge.lower_node
        )
        return neighbours
class Edge(Base):
    """An undirected edge between two ``Node`` rows.

    The endpoints are stored as (``lower_id``, ``higher_id``), which together
    form the composite primary key. When both endpoint ids are known at
    construction time, the endpoints are ordered so that
    ``lower_id <= higher_id``. That canonical order maps (a, b) and (b, a)
    to the same key.
    """

    __tablename__ = 'edge'

    lower_id = Column(Integer,
                      ForeignKey('node.id'),
                      primary_key=True)
    higher_id = Column(Integer,
                       ForeignKey('node.id'),
                       primary_key=True)

    # Both foreign keys point at the same table, so each relationship needs
    # an explicit join condition. Each backref gives Node the collection of
    # edges in which it plays that role.
    lower_node = relationship(Node,
                              primaryjoin=lower_id == Node.id,
                              backref='lower_edges')
    higher_node = relationship(Node,
                               primaryjoin=higher_id == Node.id,
                               backref='higher_edges')

    def __init__(self, n1, n2):
        """Create an edge joining *n1* and *n2*; argument order is irrelevant.

        If both nodes already have ids, the node with the smaller id becomes
        ``lower_node``. Nodes that have not been flushed yet have
        ``id is None`` and cannot be ordered, because comparing ``None``
        raises ``TypeError`` on Python 3. In that case the arguments are
        kept in the order given, and ``lower_id <= higher_id`` is not
        guaranteed.
        """
        if n1.id is not None and n2.id is not None and n2.id < n1.id:
            n1, n2 = n2, n1
        self.lower_node = n1
        self.higher_node = n2
Base.metadata.create_all(engine)

# Demo: build a small graph, persist it, and check adjacency from both ends.
n1 = Node()
n2 = Node()
n3 = Node()
n4 = Node()
n5 = Node()
n6 = Node()
n7 = Node()

n2.add_adjacencies(n5, n1)
# Adding the same neighbour twice must not create a second edge.
n3.add_adjacent(n6)
n3.add_adjacent(n6)
# Edges are undirected: these are also visible from n2 and n3.
n7.add_adjacent(n2)
n1.add_adjacent(n3)

session.add_all([n1, n2, n3, n4, n5, n6, n7])
session.commit()

# Compare sorted lists rather than sets, so a duplicated neighbour (e.g.
# from the repeated add_adjacent call above) makes the check fail.
assert sorted(x.id for x in n3.adjacencies()) == sorted([n1.id, n6.id])
assert sorted(x.id for x in n2.adjacencies()) == sorted([n1.id, n5.id, n7.id])
# n4 was never connected to anything.
assert n4.adjacencies() == []
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment