This gist is a companion to this blog post; you should start there!
Last active
May 28, 2024 17:57
-
-
Save catwell/e47e70b47550ba2fb07d04a41bb8baf0 to your computer and use it in GitHub Desktop.
Avoiding N+1 queries in Strawberry GraphQL with DataLoaders
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# DataLoaders version | |
import logging | |
from collections import defaultdict | |
from functools import cached_property | |
from typing import Any | |
import strawberry | |
from quart_cors import cors | |
from quart_db import QuartDB | |
from strawberry.dataloader import DataLoader | |
from strawberry.quart.views import GraphQLView as QuartGraphQLView | |
from strawberry.types import Info | |
from quart import Quart, Request, Response | |
# Application wiring. INFO level makes the "Got N ..." log lines emitted by
# the resolvers/loaders visible, which is how query counts are observed.
app = Quart("sfdl")
app.logger.setLevel(logging.INFO)
# Resolvers and loaders open connections explicitly with ``db.connection()``,
# so quart-db's automatic per-request connection is turned off.
app.config.update(
    QUART_DB_DATABASE_URL="postgresql://postgres@localhost/cwl_sfdl",
    QUART_DB_AUTO_REQUEST_CONNECTION=False,
)
db = QuartDB(app)
cors(app, allow_origin="*", allow_methods=["GET", "POST"])
@strawberry.type
class Song:
    """A row of the ``songs`` table exposed as a GraphQL object type."""

    id: int
    name: str
    # Id of the album this song belongs to (``songs.album_id``).
    album_id: int
@strawberry.type
class Album:
    """A row of the ``albums`` table exposed as a GraphQL object type."""

    id: int
    name: str
    # Id of the band that released this album (``albums.band_id``).
    band_id: int

    @strawberry.field
    async def songs(self, info: Info) -> list[Song]:
        """Resolve this album's songs through the request's batching loader.

        Calls made for many albums in the same request are coalesced by the
        ``songs_for_albums`` DataLoader found in the GraphQL context.
        """
        return await info.context["dataloaders"].songs_for_albums.load(self.id)
@strawberry.type
class Band:
    """A row of the ``bands`` table exposed as a GraphQL object type."""

    id: int
    name: str

    @strawberry.field
    async def albums(self, info: Info) -> list[Album]:
        """Resolve this band's albums through the request's batching loader.

        Calls made for many bands in the same request are coalesced by the
        ``albums_for_bands`` DataLoader found in the GraphQL context.
        """
        return await info.context["dataloaders"].albums_for_bands.load(self.id)
@strawberry.type
class Query:
    """GraphQL root query type."""

    @strawberry.field
    async def bands(self) -> list[Band]:
        """Return every band, ordered by id.

        Without ``ORDER BY`` PostgreSQL may return rows in any order, which
        would make the response (and comparisons between runs) unstable.
        """
        query = """
            SELECT id, name
            FROM bands
            ORDER BY id
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query)
        bands = [Band(**row) for row in result]
        app.logger.info(f"Got {len(bands)} bands.")
        return bands
class DataLoaders:
    """Batching loaders for the one-to-many relations of the schema.

    Each ``DataLoader`` is built lazily, once per ``DataLoaders`` instance,
    through ``cached_property``; ``GraphQLView.get_context`` creates a new
    instance for every context, so batching/caching does not leak across
    requests.
    """

    @staticmethod
    async def load_songs_for_albums(keys: list[int]) -> list[list[Song]]:
        """Fetch the songs of all albums in ``keys`` with a single query.

        Returns one list per key, aligned with ``keys`` as DataLoader
        requires; albums without songs get an empty list. Songs inside each
        list are ordered by id.
        """
        query = """
            SELECT id, name, album_id
            FROM songs
            WHERE album_id = ANY(:keys)
            ORDER BY id
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query, {"keys": keys})
        # Post-process after the connection has been released.
        songs = [Song(**row) for row in result]
        app.logger.info(f"Got {len(songs)} songs.")
        by_key: defaultdict[int, list[Song]] = defaultdict(list)
        for song in songs:
            by_key[song.album_id].append(song)
        return [by_key[k] for k in keys]

    @staticmethod
    async def load_albums_for_bands(keys: list[int]) -> list[list[Album]]:
        """Fetch the albums of all bands in ``keys`` with a single query.

        Returns one list per key, aligned with ``keys`` as DataLoader
        requires; bands without albums get an empty list. Albums inside each
        list are ordered by id.
        """
        query = """
            SELECT id, name, band_id
            FROM albums
            WHERE band_id = ANY(:keys)
            ORDER BY id
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query, {"keys": keys})
        # Post-process after the connection has been released.
        albums = [Album(**row) for row in result]
        app.logger.info(f"Got {len(albums)} albums.")
        by_key: defaultdict[int, list[Album]] = defaultdict(list)
        for album in albums:
            by_key[album.band_id].append(album)
        return [by_key[k] for k in keys]

    @cached_property
    def songs_for_albums(self) -> DataLoader[int, list[Song]]:
        """Loader mapping an album id to that album's songs."""
        return DataLoader(self.load_songs_for_albums)

    @cached_property
    def albums_for_bands(self) -> DataLoader[int, list[Album]]:
        """Loader mapping a band id to that band's albums."""
        return DataLoader(self.load_albums_for_bands)
class GraphQLView(QuartGraphQLView):
    """Strawberry Quart view whose context carries a fresh ``DataLoaders``."""

    async def get_context(self, request: Request, response: Response) -> dict[str, Any]:
        """Build the resolver context: request, response and new loaders."""
        context: dict[str, Any] = dict(request=request, response=response)
        # A new instance per call keeps DataLoader caches request-scoped.
        context["dataloaders"] = DataLoaders()
        return context
# Expose the GraphQL endpoint (with the GraphiQL IDE) at the site root.
schema = strawberry.Schema(query=Query)
view = GraphQLView.as_view("graphql_view", schema=schema, graphql_ide="graphiql")
app.add_url_rule("/", view_func=view)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from quart_db import Connection | |
async def create_schema(cnx: Connection) -> None:
    """Create the ``bands``, ``albums`` and ``songs`` tables.

    The foreign-key columns are ``NOT NULL`` because the GraphQL types expose
    ``band_id`` / ``album_id`` as non-nullable ``int``; a NULL would make
    resolution fail. They are also indexed because every resolver and loader
    filters on them, and PostgreSQL does not index referencing columns
    automatically.
    """
    await cnx.execute(
        """
        CREATE TABLE bands (
            id bigint PRIMARY KEY,
            name text NOT NULL
        );
        CREATE TABLE albums (
            id bigint PRIMARY KEY,
            name text NOT NULL,
            band_id bigint NOT NULL REFERENCES bands(id)
        );
        CREATE INDEX albums_band_id_idx ON albums (band_id);
        CREATE TABLE songs (
            id bigint GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
            name text NOT NULL,
            album_id bigint NOT NULL REFERENCES albums(id)
        );
        CREATE INDEX songs_album_id_idx ON songs (album_id);
        """
    )
async def populate(cnx: Connection) -> None:
    """Insert the sample bands, albums and songs.

    Band and album ids are given explicitly so the foreign keys below can
    refer to them; song ids come from the identity column. Album 6 has no
    songs.
    """
    seed_sql = """
        INSERT INTO bands (id, name) VALUES
        (1, 'Dark Tranquillity'),
        (2, 'Pineapple Thief'),
        (3, 'Wintersun');
        INSERT INTO albums (id, name, band_id) VALUES
        (1, 'Haven', 1),
        (2, 'Fiction', 1),
        (3, 'Atoma', 1),
        (4, 'Time I', 3),
        (5, 'Your Wilderness', 2),
        (6, 'Versions of the Truth', 2);
        INSERT INTO songs (name, album_id) VALUES
        ('The Wonders at Your Feet', 1),
        ('Not Built to Last', 1),
        ('Indifferent Suns', 1),
        ('At Loss for Words', 1),
        ('Terminus', 2),
        ('Inside the Particle Storm', 2),
        ('Focus Shift', 2),
        ('Forward Momentum', 3),
        ('Caves and Embers', 3),
        ('When Mountains Fall', 4),
        ('Sons of Winter and Stars', 4),
        ('Land of Snow and Sorrow', 4),
        ('Time', 4),
        ('The Final Thing on My Mind', 5),
        ('Tear You Up', 5);
    """
    # All three INSERT statements are sent together in one execute call.
    await cnx.execute(seed_sql)
async def migrate(cnx: Connection) -> None:
    """Create the tables, then load the sample data, on the same connection."""
    for step in (create_schema, populate):
        await step(cnx)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Naive version | |
import logging | |
from typing import Any | |
import strawberry | |
from quart_cors import cors | |
from quart_db import QuartDB | |
from strawberry.quart.views import GraphQLView as QuartGraphQLView | |
from quart import Quart, Request, Response | |
# Application wiring. INFO level makes the "Got N ..." log lines emitted by
# the resolvers visible, which is how the N+1 query pattern is observed.
app = Quart("sfdl")
app.logger.setLevel(logging.INFO)
# Resolvers open connections explicitly with ``db.connection()``, so
# quart-db's automatic per-request connection is turned off.
app.config.update(
    QUART_DB_DATABASE_URL="postgresql://postgres@localhost/cwl_sfdl",
    QUART_DB_AUTO_REQUEST_CONNECTION=False,
)
db = QuartDB(app)
cors(app, allow_origin="*", allow_methods=["GET", "POST"])
@strawberry.type
class Song:
    """A row of the ``songs`` table exposed as a GraphQL object type."""

    id: int
    name: str
    # Id of the album this song belongs to (``songs.album_id``).
    album_id: int
@strawberry.type
class Album:
    """A row of the ``albums`` table exposed as a GraphQL object type."""

    id: int
    name: str
    # Id of the band that released this album (``albums.band_id``).
    band_id: int

    @strawberry.field
    async def songs(self) -> list[Song]:
        """Return this album's songs, ordered by id.

        Issues one query per resolved album; this is the N+1 pattern that
        the DataLoader version batches away.
        """
        query = """
            SELECT id, name, album_id
            FROM songs
            WHERE album_id = :album_id
            ORDER BY id
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query, {"album_id": self.id})
        songs = [Song(**row) for row in result]
        app.logger.info(f"Got {len(songs)} songs.")
        return songs
@strawberry.type
class Band:
    """A row of the ``bands`` table exposed as a GraphQL object type."""

    id: int
    name: str

    @strawberry.field
    async def albums(self) -> list[Album]:
        """Return this band's albums, ordered by id.

        Issues one query per resolved band; this is the N+1 pattern that the
        DataLoader version batches away.
        """
        query = """
            SELECT id, name, band_id
            FROM albums
            WHERE band_id = :band_id
            ORDER BY id
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query, {"band_id": self.id})
        albums = [Album(**row) for row in result]
        app.logger.info(f"Got {len(albums)} albums.")
        return albums
@strawberry.type
class Query:
    """GraphQL root query type."""

    @strawberry.field
    async def bands(self) -> list[Band]:
        """Return every band, ordered by id.

        Without ``ORDER BY`` PostgreSQL may return rows in any order, which
        would make the response (and comparisons between runs) unstable.
        """
        query = """
            SELECT id, name
            FROM bands
            ORDER BY id
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query)
        bands = [Band(**row) for row in result]
        app.logger.info(f"Got {len(bands)} bands.")
        return bands
class GraphQLView(QuartGraphQLView):
    """Strawberry Quart view exposing the request and response to resolvers."""

    async def get_context(self, request: Request, response: Response) -> dict[str, Any]:
        """Build the resolver context from the current request/response."""
        return dict(request=request, response=response)
# Expose the GraphQL endpoint (with the GraphiQL IDE) at the site root.
schema = strawberry.Schema(query=Query)
view = GraphQLView.as_view("graphql_view", schema=schema, graphql_ide="graphiql")
app.add_url_rule("/", view_func=view)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment