Last active
December 28, 2017 15:19
-
-
Save pawl/df5ba8923d9929dd1f4fc4e683eced40 to your computer and use it in GitHub Desktop.
Example of a custom "IN()" relationship loading strategy in SQLAlchemy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict | |
from sqlalchemy import create_engine, Column, ForeignKey, Integer | |
from sqlalchemy.orm import relationship, scoped_session, sessionmaker | |
from sqlalchemy.orm.attributes import set_committed_value | |
from sqlalchemy.ext.declarative import declarative_base | |
# Demo database: a local MySQL ``test`` schema using the utf8mb4 charset.
# ``echo=True`` logs every emitted SQL statement, so the queries issued by the
# loader functions below are visible on stdout.
# ``convert_unicode=True`` was dropped: it is deprecated since SQLAlchemy 1.3,
# removed in 2.0, and had no effect here since the models define no String
# columns.
engine = create_engine('mysql://root@localhost/test?charset=utf8mb4',
                       echo=True)
# Scoped session registry (thread-local by default) bound to ``engine``; it is
# used both by the models' ``query`` property and by the loader functions.
session = scoped_session(
    sessionmaker(bind=engine, autoflush=False, autocommit=False)
)

# Declarative base for every model; ``Model.query`` is bound to ``session``.
Base = declarative_base()
Base.query = session.query_property()
class Post(Base):
    """A post that owns a collection of products.

    ``products`` is declared with ``lazy='raise_on_sql'``: when accessing it
    would require emitting SQL, it raises instead of issuing a per-post query.
    The ``backref`` also creates ``Product.post``.
    """

    __tablename__ = 'posts'

    id = Column(Integer, primary_key=True)
    products = relationship(
        'Product',
        backref='post',
        lazy='raise_on_sql',
    )
class Product(Base):
    """A product belonging to a post, with its own collection of links.

    ``links`` uses ``lazy='raise_on_sql'``, so an access that would emit SQL
    raises; the collection is populated by ``links_loader`` instead.
    """

    __tablename__ = 'products'

    id = Column(Integer, primary_key=True)
    # Indexed: products are fetched with ``post_id IN (...)``.
    post_id = Column(Integer, ForeignKey('posts.id'), index=True)
    links = relationship('ProductLink', lazy='raise_on_sql')
class ProductLink(Base):
    """A link row attached to a single product."""

    __tablename__ = 'links'

    id = Column(Integer, primary_key=True)
    # Indexed: links are fetched with ``product_id IN (...)``.
    product_id = Column(Integer, ForeignKey('products.id'), index=True)
# Uncomment to rebuild the schema from scratch on every run:
# Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)

# Seed demo data only when the posts table is empty:
# 50 posts x 12 products x 8 links each, committed in a single commit.
if Post.query.first() is None:
    for _ in range(50):
        session.add(Post(products=[
            Product(links=[ProductLink() for _ in range(8)])
            for _ in range(12)
        ]))
    session.commit()
def products_loader(posts):
    """Bulk-load ``Post.products`` for many posts with a single IN() query.

    Instead of one lazy ``SELECT`` per post, every product whose ``post_id``
    belongs to one of the given posts is fetched at once. The products are
    grouped by ``post_id`` and installed with ``set_committed_value``. That
    marks each collection as loaded without history events, so later access
    does not hit ``lazy='raise_on_sql'``.

    NOTE: SQLAlchemy 1.2+ ships this strategy built in as ``selectinload()``.

    :param posts: iterable of ``Post`` instances. Lists and tuples are used
        as-is. Any other iterable (a generator, a ``Query``) is materialized
        into a list first, because it has to be traversed twice.
    :return: the posts (the same list/tuple object when one was passed,
        otherwise the materialized list), to allow chaining.
    """
    # A one-shot iterator would be exhausted by the id comprehension below,
    # leaving nothing for the assignment loop.
    if not isinstance(posts, (list, tuple)):
        posts = list(posts)

    post_ids = {post.id for post in posts}
    if not post_ids:
        return posts

    products = session.query(Product).filter(
        Product.post_id.in_(post_ids)
    ).all()

    products_by_post_id = defaultdict(list)
    for product in products:
        products_by_post_id[product.post_id].append(product)

    # Posts without matching rows still get an empty, but loaded, collection.
    for post in posts:
        set_committed_value(
            post, 'products', products_by_post_id.get(post.id, [])
        )
    return posts
def links_loader(posts):
    """Bulk-load ``Product.links`` for every product of the given posts.

    Requires ``Post.products`` to be loaded already (e.g. by
    ``products_loader``). Because that relationship is
    ``lazy='raise_on_sql'``, an unloaded collection raises here rather than
    issuing one query per post.

    Every link whose ``product_id`` matches one of those products is fetched
    with a single IN() query. The links are grouped by ``product_id`` and
    installed with ``set_committed_value``, so accessing ``product.links``
    afterwards emits no SQL.

    :param posts: iterable of ``Post`` instances. Lists and tuples are used
        as-is. Any other iterable (a generator, a ``Query``) is materialized
        into a list first, because it has to be traversed more than once.
    :return: the posts (the same list/tuple object when one was passed,
        otherwise the materialized list), to allow chaining.
    """
    if not isinstance(posts, (list, tuple)):
        posts = list(posts)

    # Flatten once so the nested post -> products walk is not repeated below.
    products = [product for post in posts for product in post.products]
    product_ids = {product.id for product in products}
    if not product_ids:
        return posts

    links = session.query(ProductLink).filter(
        ProductLink.product_id.in_(product_ids)
    ).all()

    links_by_product_id = defaultdict(list)
    for link in links:
        links_by_product_id[link.product_id].append(link)

    # Products without matching rows still get an empty, but loaded, collection.
    for product in products:
        set_committed_value(
            product, 'links', links_by_product_id.get(product.id, [])
        )
    return posts
# Load one page of posts, then fill both relationship levels with one IN()
# query each. ORDER BY makes the LIMIT page deterministic; without it the
# database may return any 20 rows.
posts = Post.query.order_by(Post.id).limit(20).all()
posts = products_loader(posts)
posts = links_loader(posts)

# Touch both collections. Because they are lazy='raise_on_sql', any collection
# the loaders failed to populate raises here instead of silently querying.
for post in posts:
    for product in post.products:
        print(product.id, len(product.links))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment