Created
February 23, 2023 19:10
-
-
Save guibeira/547b12eb88f1133ca493093e5d3e70b0 to your computer and use it in GitHub Desktop.
A simple solution for batch-creating and querying nested objects
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from uuid import uuid4 | |
from collections import defaultdict | |
import uuid | |
from pydantic import BaseModel, Field | |
from pydantic.types import UUID4 | |
from sqlalchemy import Column, MetaData, Table, create_engine, text, ForeignKey | |
from sqlalchemy.sql import select | |
from sqlalchemy.dialects import postgresql | |
DATABASE_NAME = "sqlalchemy_test"
# Server-level URL with no database selected; it is used only to issue the
# DROP/CREATE DATABASE statements below.
DATABASE_HOST = "postgresql+psycopg2://postgres:admin@localhost:5432/"
DATABASE_URL = DATABASE_HOST + DATABASE_NAME

## Drop and recreate the database each run for a quick development loop
psql_engine = create_engine(DATABASE_HOST)
# PostgreSQL refuses DROP/CREATE DATABASE inside a transaction block, so the
# connection is switched to AUTOCOMMIT for these two statements.
with psql_engine.connect().execution_options(
    isolation_level="AUTOCOMMIT"
) as connection:
    connection.execute(text(f"DROP DATABASE IF EXISTS {DATABASE_NAME}"))
    connection.execute(text(f"CREATE DATABASE {DATABASE_NAME}"))
# The administrative engine has done its job; release its pooled connection
# instead of keeping it open for the rest of the run.
psql_engine.dispose()

# Engine bound to the freshly created database; used by everything below.
engine = create_engine(DATABASE_URL)
## Pydantic models: a parent that owns a list of children


class Child(BaseModel):
    # Primary key; a random UUID is generated when none is supplied.
    id: UUID4 = Field(default_factory=uuid.uuid4)


class Parent(BaseModel):
    # Primary key; a random UUID is generated when none is supplied.
    id: UUID4 = Field(default_factory=uuid.uuid4)
    # Nested child objects belonging to this parent (required field).
    children: list[Child]
## Table definitions (SQLAlchemy Core) and schema creation
metadata_obj = MetaData()

# "parent": a single UUID primary-key column.
parent_table = Table(
    "parent",
    metadata_obj,
    Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
)

# "child": UUID primary key, a mandatory foreign key to parent.id, and a
# creation timestamp that the database server fills in via now().
child_table = Table(
    "child",
    metadata_obj,
    Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
    Column(
        "parent_id",
        postgresql.UUID(as_uuid=True),
        ForeignKey("parent.id"),
        nullable=False,
    ),
    Column(
        "created_at",
        postgresql.TIMESTAMP,
        server_default=text("now()"),
    ),
)

# Emit CREATE TABLE for every table registered on metadata_obj.
metadata_obj.create_all(engine)
class ParentSocket:
    """Persistence helper for ``Parent`` objects and their nested ``Child`` rows.

    Reads and writes through the module-level ``engine``, ``parent_table`` and
    ``child_table``.
    """

    @staticmethod
    def _insert_in_batches(conn, table, records: list[dict], limit: int) -> int:
        """Insert ``records`` into ``table``, at most ``limit`` rows per statement.

        Returns the sum of the row counts reported for each statement. Nothing
        is executed when ``records`` is empty.
        """
        inserted = 0
        for start in range(0, len(records), limit):
            batch = records[start : start + limit]
            inserted += conn.execute(table.insert().values(batch)).rowcount
        return inserted

    def create_many(self, objs: list[Parent], limit=100) -> int:
        """Bulk create a list of parent objects and return the number of rows inserted.

        Parents and all of their children are written in a single transaction,
        so a failure leaves nothing behind.

        Args:
            objs: Parent objects to persist, together with their children.
            limit: Maximum number of rows sent in one INSERT statement.

        Returns:
            Total number of inserted rows (parents plus children).

        Raises:
            ValueError: If ``limit`` is smaller than 1.
        """
        # A negative limit would yield an empty range and silently skip rows.
        if limit < 1:
            raise ValueError("limit must be a positive integer")
        # Guard: passing an empty list to .values() does not produce a no-op
        # INSERT, so return early without touching the database.
        if not objs:
            return 0

        parent_records = [{"id": obj.id} for obj in objs]
        child_records = [
            {"id": child.id, "parent_id": obj.id}
            for obj in objs
            for child in obj.children
        ]

        with engine.begin() as conn:
            # Parents go first so the child foreign keys can be satisfied.
            total_inserted = self._insert_in_batches(
                conn, parent_table, parent_records, limit
            )
            total_inserted += self._insert_in_batches(
                conn, child_table, child_records, limit
            )
        return total_inserted

    def query(self, id: UUID4 | list[UUID4]) -> list[Parent]:
        """Query either a single id or multiple ids from the parent table.

        Args:
            id: One parent id, or a list of parent ids.

        Returns:
            One ``Parent`` per matching row, ordered by parent id, each with
            its children ordered by child id. Parents without children are
            returned with an empty ``children`` list. Ids that match no row
            are absent from the result.
        """
        ids = [id] if isinstance(id, uuid.UUID) else list(id)
        if not ids:
            return []

        # LEFT OUTER JOIN keeps parents that have no child rows; an inner join
        # would drop them from the result entirely.
        stmt = (
            select(parent_table.c.id, child_table.c.id.label("child_id"))
            .select_from(parent_table.outerjoin(child_table))
            .where(parent_table.c.id.in_(ids))
            .order_by(parent_table.c.id, child_table.c.id)
        )

        # Group the child ids by parent id, preserving the query's ordering.
        children_by_parent_id: dict = {}
        with engine.connect() as conn:
            for parent_id, child_id in conn.execute(stmt):
                child_ids = children_by_parent_id.setdefault(parent_id, [])
                # child_id is NULL for the single row of a childless parent.
                if child_id is not None:
                    child_ids.append(child_id)

        # Map each parent id to a Parent object with its children.
        return [
            Parent(
                id=parent_id,
                children=[Child(id=child_id) for child_id in child_ids],
            )
            for parent_id, child_ids in children_by_parent_id.items()
        ]
parent_socket = ParentSocket()

# Build three parents owning two, one and three children respectively.
parent1 = Parent(children=[Child(), Child()])
parent2 = Parent(children=[Child()])
parent3 = Parent(children=[Child(), Child(), Child()])

# 3 parent rows + 6 child rows are expected to be inserted.
assert parent_socket.create_many([parent1, parent2, parent3]) == 9

# Lookup by a single id and by a list of ids.
assert parent_socket.query(id=parent1.id)[0].id == parent1.id
assert len(parent_socket.query(id=[parent1.id, parent2.id])) == 2

# Write any additional tests as desired!
# Every parent must come back with exactly the children it was created with.
assert all(
    len(parent_socket.query(id=parent.id)[0].children) == expected_children
    for parent, expected_children in (
        (parent1, 2),
        (parent2, 1),
        (parent3, 3),
    )
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment