Skip to content

Instantly share code, notes, and snippets.

@rkuykendall
Created May 9, 2014 00:09
Show Gist options
  • Save rkuykendall/f35e7a37eb1a502257e7 to your computer and use it in GitHub Desktop.
Save rkuykendall/f35e7a37eb1a502257e7 to your computer and use it in GitHub Desktop.
Undirected Graph: Many-to-Many, Reflexive, Non-Directional Relationships in SQLAlchemy
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
# Declarative base class that every mapped model in this file derives from.
Base = declarative_base()
# SQLite engine with no database path; echo=False disables SQL statement logging.
engine = create_engine('sqlite://', echo=False)
# A single session bound to that engine, shared by the rest of the script.
session = sessionmaker(bind=engine)()
"""An undirected graph example Modified from SQLAlchemy directed
graph example."""
from sqlalchemy import MetaData, Table, Column, Integer, ForeignKey
from sqlalchemy.orm import mapper, relationship
class Node(Base):
    """A vertex of the undirected graph.

    Connections are stored as ``Edge`` rows.  The ``lower_edges`` and
    ``higher_edges`` collections used below are the backrefs declared on
    ``Edge.lower_node`` and ``Edge.higher_node`` respectively.
    """
    __tablename__ = 'node'

    id = Column(Integer, primary_key=True)

    def __init__(self):
        # A node carries no data of its own besides its primary key.
        pass

    def add_adjacencies(self, *nodes):
        """Connect this node to every node in *nodes*.

        Nodes that are already adjacent (including ones repeated inside
        *nodes*) are skipped: a second ``Edge`` for the same pair would
        share the first one's composite primary key and fail at flush.

        Returns:
            This node, so calls can be chained.
        """
        known = self.adjacencies()
        for node in nodes:
            if node in known:
                continue
            # Constructing the Edge is enough: its relationship backrefs
            # attach it to both endpoints' edge collections.
            Edge(self, node)
            known.append(node)
        return self

    def add_adjacent(self, node):
        """Connect this node to *node* unless the two are already adjacent."""
        # The duplicate check lives in add_adjacencies.
        self.add_adjacencies(node)

    def adjacencies(self):
        """Return a new list of the nodes sharing an edge with this one.

        An edge may reference this node from either side, so both edge
        collections are consulted and the opposite endpoint is taken.
        """
        neighbours = [edge.lower_node for edge in self.higher_edges]
        neighbours.extend(edge.higher_node for edge in self.lower_edges)
        return neighbours
class Edge(Base):
    """An undirected connection between two ``Node`` rows.

    Each connection is stored once, as a (lower_id, higher_id) pair that
    forms the composite primary key.  The backrefs add ``lower_edges`` and
    ``higher_edges`` collections to ``Node``.
    """
    __tablename__ = 'edge'

    lower_id = Column(Integer, ForeignKey('node.id'), primary_key=True)
    higher_id = Column(Integer, ForeignKey('node.id'), primary_key=True)

    lower_node = relationship(Node,
                              primaryjoin=lower_id == Node.id,
                              backref='lower_edges')
    higher_node = relationship(Node,
                               primaryjoin=higher_id == Node.id,
                               backref='higher_edges')

    # Intended invariant: lower.id <= higher.id.  It can only be checked
    # here for nodes whose id has already been assigned.
    def __init__(self, n1, n2):
        """Link *n1* and *n2*, placing the smaller id on the ``lower`` side.

        A node that has not been flushed yet has ``id`` set to ``None``;
        comparing ``None`` with ``<`` raises ``TypeError`` on Python 3, so
        missing ids are handled explicitly instead of being compared.
        """
        id1, id2 = n1.id, n2.id
        # A missing id sorts after an assigned one.
        # NOTE(review): this assumes a pending node later receives a larger
        # id than already-stored nodes -- confirm for the target database.
        # When both ids are missing the order cannot be known yet, so the
        # original fall-through (n2 lower, n1 higher) is kept.
        n1_is_lower = id1 is not None and (id2 is None or id1 < id2)
        if n1_is_lower:
            self.lower_node = n1
            self.higher_node = n2
        else:
            self.lower_node = n2
            self.higher_node = n1
# Create the tables for all mapped classes on the engine.
Base.metadata.create_all(engine)

# Seven nodes, initially unconnected.
n1, n2, n3, n4, n5, n6, n7 = (Node() for _ in range(7))

# Wire up the graph.  Edges are created before the nodes are added to the
# session, so this order must be kept.
n2.add_adjacencies(n5, n1)
n3.add_adjacent(n6)
# Second identical call: add_adjacent skips nodes that are already adjacent.
n3.add_adjacent(n6)
n7.add_adjacent(n2)
n1.add_adjacent(n3)

session.add_all([n1, n2, n3, n4, n5, n6, n7])
session.commit()

# Adjacency is symmetric: n3 sees n1 even though the edge was added from n1.
assert {x.id for x in n3.adjacencies()} == {n1.id, n6.id}
assert {x.id for x in n2.adjacencies()} == {n1.id, n5.id, n7.id}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment