Last active
July 22, 2024 08:40
-
-
Save benselme/7278872 to your computer and use it in GitHub Desktop.
SQLAlchemy ordered tree with PostgreSQL recursive CTE
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- encoding: utf-8 -*- | |
from sqlalchemy import (Column, Integer, ForeignKey, String, create_engine, | |
literal, null, type_coerce) | |
from sqlalchemy.dialects.postgresql import array, ARRAY | |
from sqlalchemy.ext.declarative import declarative_base | |
from sqlalchemy.orm import Session, aliased | |
# Declarative base shared by every mapped class in this script; its
# ``metadata`` collects the tables to create and drop.
Base = declarative_base(name='Base')


class Division(Base):
    """A node of an ordered tree stored as an adjacency list.

    ``parent_id`` points at the parent row (``NULL`` for a root) and
    ``ord`` is the node's position among its siblings.
    """

    __tablename__ = 'Division'

    id = Column(Integer(), primary_key=True)
    # Self-referential foreign key: the parent lives in this same table.
    parent_id = Column(Integer(), ForeignKey('Division.id'))
    label = Column(String(length=32))
    ord = Column(Integer())
def _seed_divisions(session):
    """Add a small fixed forest of ``Division`` rows to *session* (uncommitted)."""
    session.add_all(
        (
            Division(id=0, label='0', ord=0),
            Division(id=1, label='1', ord=1),
            Division(id=2, label='2', ord=2),
            Division(id=3, label='00', ord=0, parent_id=0),
            Division(id=5, label='01', ord=1, parent_id=0),
            Division(id=6, label='02', ord=2, parent_id=0),
            Division(id=7, label='020', ord=0, parent_id=6),
            Division(id=8, label='021', ord=1, parent_id=6),
            Division(id=9, label='10', ord=0, parent_id=1),
            Division(id=10, label='11', ord=1, parent_id=1),
        )
    )


def _build_hierarchy_cte(session):
    """Return a recursive CTE that walks the ``Division`` tree from its roots.

    Each CTE row carries every ``Division`` column plus:

    * ``level`` -- depth of the node (0 for roots);
    * ``breadcrumbs`` -- integer array of ``ord`` values from the root down
      to the node. PostgreSQL compares arrays element by element, so ordering
      by this column gives a depth-first listing with siblings in ``ord``
      order.
    """
    # Anchor member: the root nodes (no parent).
    anchor = session.query(
        Division,
        literal(0).label('level'),
        array([Division.ord]).label('breadcrumbs'),
    ).filter(Division.parent_id.is_(None))
    hierarchy = anchor.cte(name='hierarchy', recursive=True)

    # Recursive member: children of rows already in the CTE, one level deeper,
    # with their own ``ord`` appended to the parent's breadcrumbs.
    parent = aliased(hierarchy, name='d1')
    child = aliased(Division, name='d2')
    return hierarchy.union_all(
        session.query(
            child,
            (parent.c.level + 1).label('level'),
            (parent.c.breadcrumbs + array([child.ord])).label('breadcrumbs'),
        ).filter(child.parent_id == parent.c.id)
    )


def test(url='postgresql://postgres:postgres@localhost:5432/playground'):
    """Create the table, load sample rows, print the ordered tree, clean up.

    :param url: SQLAlchemy database URL. It must point at PostgreSQL, because
        the query uses a recursive CTE and array columns. The default is the
        original hard-coded playground database.
    """
    engine = create_engine(url)
    # Pass the engine explicitly: bound MetaData (``metadata.bind``) is
    # deprecated in SQLAlchemy 1.4 and removed in 2.0.
    Base.metadata.create_all(engine)
    session = Session(bind=engine)
    try:
        _seed_divisions(session)
        hierarchy = _build_hierarchy_cte(session)

        # Load Division entities from the CTE's columns. This replaces the
        # deprecated/removed Query.select_entity_from(hierarchy).
        tree_division = aliased(Division, hierarchy)
        result = (
            session.query(
                tree_division,
                hierarchy.c.level,
                # Query de-duplicates rows that contain ORM entities by
                # hashing them; PostgreSQL arrays arrive as lists, which
                # raised "TypeError: unhashable type: 'list'". Coercing to a
                # tuple-returning ARRAY type makes the rows hashable.
                type_coerce(hierarchy.c.breadcrumbs,
                            ARRAY(Integer, as_tuple=True)),
            )
            # Order by the column object, not the bare string 'breadcrumbs',
            # which matches no label in the SELECT.
            .order_by(hierarchy.c.breadcrumbs)
            .all()
        )
        print(result)
    finally:
        # Discard the demo rows, release the connection before dropping the
        # table, then close pooled connections.
        session.rollback()
        session.close()
        Base.metadata.drop_all(engine)
        engine.dispose()


if __name__ == '__main__':
    test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment