Skip to content

Instantly share code, notes, and snippets.

@jkimbo
Last active April 8, 2023 15:05
Show Gist options
  • Save jkimbo/3118b4bde65e67540d6cd45b9677cab4 to your computer and use it in GitHub Desktop.
Save jkimbo/3118b4bde65e67540d6cd45b9677cab4 to your computer and use it in GitHub Desktop.
Strawberry Relay API

Strawberry Relay API

Proposal for the Strawberry Relay API.

Node

from strawberry import relay

@strawberry.type
class Fruit(relay.Node):
    # Fields of the type; ``id`` is annotated as the Relay node ID.
    id: relay.NodeID[int]
    name: str
    weight: float

    @classmethod
    def resolve_node(cls, info: Info, id: int):
        # ``id`` arrives here already decoded from the relay ID.
        record = get_fruit(id)

        # Copy the three fields from the fetched record onto the type.
        field_names = ("id", "name", "weight")
        return Fruit(**{key: record[key] for key in field_names})


# Placeholder for whatever function fetches a single fruit by its ID:
def get_fruit(id: int) -> Dict[str, Any]:
    """Fetch one fruit record by ID (implementation left out of the example)."""
    ...

Now we can expose it in the schema for retrieval:

@strawberry.type
class Query:
    # Root ``node`` field, declared through relay.NodeField and typed as the
    # generic relay.Node interface rather than a concrete type.
    node: relay.Node = relay.NodeField(description="Get a Node from a NodeID")

This will generate a schema like this:

scalar NodeID

interface Node {
  id: NodeID!
}

type Fruit implements Node {
  id: NodeID!
  name: String!
  weight: Float!
}

type Query {
  """
  Get a Node from a NodeID
  """
  node(id: NodeID!): Node!
}

Notes:

  • I've renamed GlobalID to NodeID just because I think it's a better name. Don't have a strong opinion on it though so happy to go with GlobalID.

Connections

Given our Fruit type above we can create a connection to paginate through a list of them:

import strawberry
from strawberry import relay

class FruitConnection(relay.ListConnection[Fruit]):
    @staticmethod
    def resolve_node(root, info: Info, **kwargs) -> Fruit:
        # Turn one raw item from the resolver into a Fruit, field by field.
        field_names = ("id", "name", "weight")
        return Fruit(**{key: root[key] for key in field_names})

@strawberry.type
class Query:
    # Form 1: declare the connection as an annotated field and hand the
    # resolver to relay.connection.
    fruits: FruitConnection = relay.connection(resolver=get_all_fruits)

    # this is just a different way of defining the above connection
    # (Form 2: decorate the resolver and pass the connection class.) The two
    # forms are shown side by side for comparison; in an actual class the
    # second binding of ``fruits`` would replace the first.
    @relay.connection(FruitConnection)
    def fruits(self, info: Info) -> Iterable[Dict[str, Any]]:
        return get_all_fruits()
scalar NodeID

interface Node {
  id: NodeID!
}

type PageInfo {
  hasNextPage: Boolean!
  hasPreviousPage: Boolean!
  startCursor: String
  endCursor: String
}

type Fruit implements Node {
  id: NodeID!
  name: String!
  weight: Float!
}

type FruitConnectionEdge {
  cursor: String!
  node: Fruit!
}

type FruitConnection {
  pageInfo: PageInfo!
  edges: [FruitConnectionEdge!]!
}

type Query {
  fruits(
    before: String = null
    after: String = null
    first: Int = null
    last: Int = null
  ): FruitConnection!
}

Notes:

  • resolve_node by default is the identity function.
  • The key difference from strawberry-graphql/strawberry#2511 is that you have to pass the connection class to relay.connection and it works much more like strawberry.field.

Examples

Connection with custom arguments
import strawberry
from strawberry import relay

class FruitConnection(relay.ListConnection[Fruit]):
    # Body elided in this example; nothing is overridden here.
    ...

@strawberry.type
class Query:
    @relay.connection(FruitConnection)
    def fruits(self, info: Info, only_in_season: bool) -> Iterable[Dict[str, Any]]:
        # The custom ``only_in_season`` argument is simply forwarded to the
        # data source.
        return get_all_fruits(only_in_season=only_in_season)
scalar NodeID

interface Node {
  id: NodeID!
}

type PageInfo {
  hasNextPage: Boolean!
  hasPreviousPage: Boolean!
  startCursor: String
  endCursor: String
}

type Fruit implements Node {
  id: NodeID!
  name: String!
  weight: Float!
}

type FruitConnectionEdge {
  cursor: String!
  node: Fruit!
}

type FruitConnection {
  pageInfo: PageInfo!
  edges: [FruitConnectionEdge!]!
}

type Query {
  fruits(
    only_in_season: Boolean!
    before: String = null
    after: String = null
    first: Int = null
    last: Int = null
  ): FruitConnection!
}
Direct use with `relay.ListConnection`
import strawberry
from strawberry import relay

@strawberry.type
class Query:
    @relay.connection(relay.ListConnection[Fruit])
    def fruits(self, info: Info, only_in_season: bool) -> Iterable[Fruit]:
        # No connection subclass (and so no resolve_node override) is involved
        # here, so the resolver itself converts each raw dict into a Fruit.
        return (
            Fruit(id=f["id"], name=f["name"], weight=f["weight"])
            for f in get_all_fruits(only_in_season=only_in_season)
        )
Extra connection fields
import strawberry
from strawberry import relay

class FruitConnection(relay.ListConnection[Fruit]):
    # Additional field carried by the connection object itself.
    total_count: int = strawberry.field()

    def resolve_connection(self, info: Info, **kwargs) -> "FruitConnection":
        # Overall number of fruits, independent of the requested page.
        fruit_count = get_num_fruits()

        # resolve_edges runs the resolver, paginates the result and applies
        # `resolve_node` to every item, yielding the list of edges.
        page_edges = self.resolve_edges(info, **kwargs)

        # Page info is derived from the edges that were just resolved.
        page_info = self.resolve_page_info(page_edges, info, **kwargs)

        return FruitConnection(
            edges=page_edges,
            total_count=fruit_count,
            page_info=page_info,
        )

    @staticmethod
    def resolve_node(root, info: Info, **kwargs) -> Fruit:
        # Turn one raw item into a Fruit, field by field.
        field_names = ("id", "name", "weight")
        return Fruit(**{key: root[key] for key in field_names})

@strawberry.type
class Query:
    # ``fruits`` is wired to FruitConnection through the decorator; the
    # resolver itself only returns the raw items.
    @relay.connection(FruitConnection)
    def fruits(self, info: Info) -> Iterable[Dict[str, Any]]:
        return get_all_fruits()
scalar NodeID

interface Node {
  id: NodeID!
}

type PageInfo {
  hasNextPage: Boolean!
  hasPreviousPage: Boolean!
  startCursor: String
  endCursor: String
}

type Fruit implements Node {
  id: NodeID!
  name: String!
  weight: Float!
}

type FruitConnectionEdge {
  cursor: String!
  node: Fruit!
}

type FruitConnection {
  pageInfo: PageInfo!
  edges: [FruitConnectionEdge!]!
  totalCount: Int!
}

type Query {
  fruits(
    before: String = null
    after: String = null
    first: Int = null
    last: Int = null
  ): FruitConnection!
}
Custom edge type
import strawberry
from strawberry import relay

class FruitConnectionEdge(relay.Edge[Fruit]):
    # Extra per-edge field, declared on top of the relay.Edge[Fruit] base.
    is_in_fridge: bool = strawberry.field(description="Flag to mark if the fruit is in the users fridge")

class FruitConnection(relay.ListConnection[Fruit]):
    def resolve_edge(self, value, info: Info, **kwargs) -> FruitConnectionEdge:
        # Build the custom edge: cursor and node come from the connection's
        # own resolvers, the extra flag is read straight off the raw item.
        return FruitConnectionEdge(
            cursor=self.resolve_cursor(value, info, **kwargs),
            node=self.resolve_node(value, info, **kwargs),
            is_in_fridge=value["is_in_fridge"],
        )

    @staticmethod
    def resolve_node(root, info: Info, **kwargs) -> Fruit:
        # Turn one raw item into a Fruit, field by field.
        return Fruit(
            id=root["id"],
            name=root["name"],
            weight=root["weight"],
        )

@strawberry.type
class Query:
    @relay.connection(FruitConnection)
    def fruits(self, info: Info) -> Iterable[Dict[str, Any]]:
        # Items are fetched for the current user taken from ``info``;
        # presumably that is where each item's "is_in_fridge" value
        # originates — TODO confirm.
        return get_all_fruits(info.current_user)
scalar NodeID

interface Node {
  id: NodeID!
}

type PageInfo {
  hasNextPage: Boolean!
  hasPreviousPage: Boolean!
  startCursor: String
  endCursor: String
}

type Fruit implements Node {
  id: NodeID!
  name: String!
  weight: Float!
}

type FruitConnectionEdge {
  cursor: String!
  node: Fruit!
  isInFridge: Boolean!
}

# The connection in this example adds no fields of its own; the extra
# `isInFridge` field lives on FruitConnectionEdge.
type FruitConnection {
  pageInfo: PageInfo!
  edges: [FruitConnectionEdge!]!
}

type Query {
  fruits(
    before: String = null
    after: String = null
    first: Int = null
    last: Int = null
  ): FruitConnection!
}
Simple Django pagination
import strawberry
from strawberry import django

# Strawberry type declared against the Django model ``models.Fruit``.
@strawberry.django.type(models.Fruit)
class Fruit(relay.Node):
    id: relay.NodeID[int]
    # ``auto``: presumably the field type is derived from the Django model —
    # TODO confirm.
    name: auto
    weight: auto

    # resolve_node gets auto generated because we know the Django model and so
    # how to get one by ID and convert it into the Strawberry type

@strawberry.type
class Query:
    @relay.connection(django.ListConnection[Fruit])
    def fruits(self, info: Info, only_in_season: bool) -> QuerySet[models.Fruit]:
        # The resolver returns a filtered QuerySet of the Django model rather
        # than a list of dicts.
        return models.Fruit.objects.filter(only_in_season=only_in_season)
Django cursor pagination
import strawberry
from strawberry import django

# Strawberry type declared against the Django model ``models.Fruit``.
@strawberry.django.type(models.Fruit)
class Fruit(relay.Node):
    id: relay.NodeID[int]
    # ``auto``: presumably the field type is derived from the Django model —
    # TODO confirm.
    name: auto
    weight: auto

class FruitConnection(django.CursorConnection[Fruit]):
    def get_sort(self, info: Info, **kwargs) -> Tuple[str, ...]:
        # Sort keys for the cursor pagination, as a variable-length tuple.
        # NOTE(review): the leading "-" presumably means descending, as in
        # Django's order_by — confirm.
        return ("name", "-weight")

@strawberry.type
class Query:
    @relay.connection(FruitConnection)
    def fruits(self, info: Info, only_in_season: bool) -> QuerySet[models.Fruit]:
        # The resolver returns a filtered QuerySet of the Django model; the
        # decorator hands it to FruitConnection.
        return models.Fruit.objects.filter(only_in_season=only_in_season)

Notes:

  • The sort can be dynamic based on parameters to the connection
  • This connection has to operate on a QuerySet (and so should probably validate it at runtime)
  • Probably worth using the library django-cursor-pagination to implement the actual cursor pagination
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment