Skip to content

Instantly share code, notes, and snippets.

@haoflynet
Created August 14, 2018 09:33
Show Gist options
  • Save haoflynet/85ec02eb003f7e4a3a53bcdcdb7dc8b1 to your computer and use it in GitHub Desktop.
Save haoflynet/85ec02eb003f7e4a3a53bcdcdb7dc8b1 to your computer and use it in GitHub Desktop.
graphene-sqlalchemy使用示例,教程见https://haofly.net/python-graphql
import graphene
import sqlalchemy as sa
from graphene import String
from graphene_sqlalchemy import SQLAlchemyObjectType, SQLAlchemyConnectionField
from promise import Promise
from promise.dataloader import DataLoader
from sqlalchemy import Column, BigInteger, ForeignKey
from sqlalchemy.orm import relationship

from db import Base, db_session  # project-local module: define these yourself
# --- SQLAlchemy models ---
class UserModel(Base):
    """SQLAlchemy model mapped to the ``users`` table."""

    __tablename__ = "users"

    id = Column(BigInteger, primary_key=True)
    # Must be SQLAlchemy's String type. The bare name `String` in this module
    # is graphene's GraphQL scalar (used by UsersConnectionField), which is
    # not a valid column type.
    name = Column(sa.String(255))
    # One user -> many PostModel rows (joined through PostModel.user_id).
    # NOTE(review): backref="posts" also adds a `posts` attribute on PostModel
    # that actually refers to the owning user. "user" would be a clearer name,
    # but renaming it changes the model API and the auto-generated Post GraphQL
    # fields, so it is kept as-is here.
    posts = relationship("PostModel", backref="posts")
class PostModel(Base):
    """SQLAlchemy model mapped to the ``posts`` table; each row references a user."""

    __tablename__ = "posts"

    id = Column(BigInteger, primary_key=True)
    # ForeignKey targets "<table name>.<column>", so this must name the "users"
    # table (UserModel.__tablename__), not the "User" class / GraphQL type.
    user_id = Column(BigInteger, ForeignKey("users.id"))
# --- GraphQL schema types ---
class Post(SQLAlchemyObjectType):
    """GraphQL object type backed by PostModel."""

    # SQLAlchemyObjectType reads the backing model from this inner Meta class.
    class Meta:
        model = PostModel
class User(SQLAlchemyObjectType):
    """GraphQL object type backed by UserModel, with batched post loading."""

    posts = graphene.List(Post)

    def resolve_posts(self, info, **args):
        """Resolve the posts that belong to this user.

        ``self`` is the UserModel row being resolved. When the request context
        provides a ``"post_data_loader"``, the lookup goes through it, so the
        posts of every user in the response are fetched by one batched query.
        Without a loader, the ``UserModel.posts`` relationship is used instead.

        Returns:
            A Promise resolving to a list of PostModel rows (loader path), or
            the relationship's list (fallback path).
        """
        loader = info.context.get("post_data_loader")
        if loader is None:
            return self.posts
        # The key is the user id. PostsDataLoader.batch_load_fn already maps
        # each key to a *list* of posts, so the result must not be wrapped again.
        return loader.load(self.id)

    class Meta:
        model = UserModel


# --- Connection field ---
class UsersConnectionField(SQLAlchemyConnectionField):
    """Connection field that declares a ``uuid`` argument and applies limit/offset."""

    def __init__(self, type, *args, **kwargs):
        # `type` shadows the builtin but is kept so keyword callers still work.
        # setdefault lets an explicit `uuid` from the caller win instead of
        # raising "got multiple values for keyword argument 'uuid'".
        kwargs.setdefault("uuid", String())
        super().__init__(type, *args, **kwargs)

    @classmethod
    def get_query(cls, model, info, sort=None, **args):
        """Build the base query for *model*, then apply optional paging.

        Overrides the parent ``get_query`` to support custom arguments.
        ``sort`` is accepted for signature compatibility but is not forwarded:
        the parent is always called with ``None``.

        Args:
            model: SQLAlchemy model class to query.
            info: GraphQL resolve info, passed through to the parent.
            sort: Ignored.
            **args: Field arguments. ``limit`` and ``offset`` are applied when
                present; all arguments are also forwarded to the parent.

        Returns:
            The (possibly limited/offset) SQLAlchemy query.
        """
        query = super().get_query(model, info, None, **args)
        # NOTE(review): `uuid` is declared in __init__ but never used to filter
        # here. No ORDER BY is added either, so limit/offset pages may be
        # unstable unless the parent query orders rows — confirm.
        if "limit" in args:
            query = query.limit(args["limit"])
        if "offset" in args:
            query = query.offset(args["offset"])
        return query


# --- DataLoader ---
class PostsDataLoader(DataLoader):
    """Batch-load the posts of many users with a single query."""

    def batch_load_fn(self, keys):
        """Load posts for a batch of user ids.

        Args:
            keys: User ids gathered from ``load()`` calls
                (``User.resolve_posts`` passes ``UserModel.id``).

        Returns:
            A Promise resolving to one list of PostModel rows per key, in the
            same order as *keys*. Users without posts get an empty list.
        """
        query = (
            db_session.query(PostModel)
            .filter(PostModel.user_id.in_(keys))
            .order_by(PostModel.id)  # deterministic order within each user's list
        )
        posts_by_user = {}
        for post in query.all():
            posts_by_user.setdefault(post.user_id, []).append(post)
        return Promise.resolve([posts_by_user.get(key, []) for key in keys])
class Query(graphene.ObjectType):
    """Root query exposing ``users(limit, offset)``."""

    users = graphene.List(
        User,
        limit=graphene.Int(),
        offset=graphene.Int(),
    )

    def resolve_users(self, info, **kwargs):
        # Delegate limit/offset handling to the connection field's query
        # builder (sort explicitly None), then run the query eagerly.
        return UsersConnectionField.get_query(
            UserModel, info, None, **kwargs
        ).all()
schema = graphene.Schema(query=Query)


if __name__ == "__main__":
    # Demo: fetch 10 users starting at offset 20, each with its posts.
    # Only fields that the Post type actually exposes (id, userId) are selected.
    query = """
    query {
      users (limit:10, offset:20) {
        id,
        name,
        posts {
          id,
          userId
        }
      }
    }
    """
    result = schema.execute(
        query,
        # "session" is presumably what graphene-sqlalchemy's default query
        # builder reads; "post_data_loader" is read by User.resolve_posts.
        # A fresh loader per execution keeps its cache request-scoped.
        context_value={"session": db_session, "post_data_loader": PostsDataLoader()},
    )
    print(result.errors)
    print(result.data)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment