Skip to content

Instantly share code, notes, and snippets.

@benselme
Last active July 22, 2024 08:40
Show Gist options
  • Save benselme/7278872 to your computer and use it in GitHub Desktop.
Save benselme/7278872 to your computer and use it in GitHub Desktop.
SQLAlchemy ordered tree with a PostgreSQL recursive CTE
# -*- encoding: utf-8 -*-
from sqlalchemy import (Column, Integer, ForeignKey, String, create_engine,
literal, null, type_coerce)
from sqlalchemy.dialects.postgresql import array, ARRAY
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, aliased
# Declarative base shared by the mapped classes in this module.
Base = declarative_base()


class Division(Base):
    """A node of an ordered tree stored as an adjacency list.

    Each row points at its parent through ``parent_id`` (NULL for roots) and
    stores its position among its siblings in ``ord``.
    """

    __tablename__ = 'Division'

    id = Column(Integer, primary_key=True)
    # Self-referential foreign key: the parent row lives in this same table.
    parent_id = Column(Integer, ForeignKey('%s.id' % __tablename__))
    label = Column(String(length=32))
    # Sibling position; the recursive query collects these into breadcrumbs.
    ord = Column(Integer)
def _seed_divisions():
    """Return the sample Division rows used by the demo.

    The tree is::

        0 -> 00, 01, 02 -> 020, 021
        1 -> 10, 11
        2

    ``ord`` gives each node's position among its siblings.
    """
    return [
        Division(id=0, label='0', ord=0),
        Division(id=1, label='1', ord=1),
        Division(id=2, label='2', ord=2),
        Division(id=3, label='00', ord=0, parent_id=0),
        Division(id=5, label='01', ord=1, parent_id=0),
        Division(id=6, label='02', ord=2, parent_id=0),
        Division(id=7, label='020', ord=0, parent_id=6),
        Division(id=8, label='021', ord=1, parent_id=6),
        Division(id=9, label='10', ord=0, parent_id=1),
        Division(id=10, label='11', ord=1, parent_id=1),
    ]


def _hierarchy_cte(session):
    """Build the recursive CTE that walks the Division tree from its roots.

    Each CTE row carries every Division column plus:

    * ``level``: depth of the node (0 for roots);
    * ``breadcrumbs``: integer array of ``ord`` values from the root down to
      the node. PostgreSQL compares arrays element by element (a prefix sorts
      first), so ordering by it gives a depth-first, sibling-ordered listing.
    """
    # Anchor member: the roots, i.e. rows without a parent.
    hierarchy = (
        session.query(
            Division,
            literal(0).label('level'),
            array([Division.ord]).label('breadcrumbs'),
        )
        .filter(Division.parent_id == null())
        .cte(name='hierarchy', recursive=True)
    )
    # Recursive member: join the rows found so far (d1) to their children (d2).
    parent = aliased(hierarchy, name='d1')
    child = aliased(Division, name='d2')
    return hierarchy.union_all(
        session.query(
            child,
            (parent.c.level + 1).label('level'),
            # ``+`` on ARRAY expressions compiles to PostgreSQL's ``||``.
            (parent.c.breadcrumbs + array([child.ord])).label('breadcrumbs'),
        ).filter(child.parent_id == parent.c.id)
    )


def test(url='postgresql://postgres:postgres@localhost:5432/playground'):
    """Create the Division table, load a sample tree, and print it in order.

    Prints a list of ``(Division, level, breadcrumbs)`` rows sorted by
    breadcrumbs: depth-first, with siblings ordered by ``ord``. The inserted
    rows are rolled back, the table is dropped, and the session and engine
    are released before returning.

    :param url: SQLAlchemy database URL of a scratch PostgreSQL database.
        Defaults to a local ``playground`` database.
    """
    engine = create_engine(url)
    # Pass the engine explicitly: ``MetaData.bind`` (and bind-less
    # create_all/drop_all) is deprecated in SQLAlchemy 1.4, removed in 2.0.
    Base.metadata.create_all(engine)
    session = Session(bind=engine)
    try:
        session.add_all(_seed_divisions())
        hierarchy = _hierarchy_cte(session)
        # Load Division entities from the CTE's columns rather than from the
        # base table; replaces the removed ``Query.select_entity_from``.
        node = aliased(Division, hierarchy)
        # ARRAY values come back as Python lists, and the ORM hashes result
        # rows when de-duplicating multi-entity results, which failed with
        # "TypeError: unhashable type: 'list'". as_tuple=True returns tuples.
        breadcrumbs = type_coerce(
            hierarchy.c.breadcrumbs, ARRAY(Integer, as_tuple=True)
        ).label('breadcrumbs')
        # Order by the column object: bare-string order_by('breadcrumbs')
        # is rejected by SQLAlchemy 1.4+.
        result = (
            session.query(node, hierarchy.c.level, breadcrumbs)
            .order_by(hierarchy.c.breadcrumbs)
            .all()
        )
        print(result)
    finally:
        # End the transaction before drop_all so this session no longer holds
        # locks on the table that the DROP (on another connection) would wait on.
        session.rollback()
        session.close()
        Base.metadata.drop_all(engine)
        engine.dispose()
if __name__ == "__main__":
    # Run the database demo only when executed as a script, not on import.
    test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment