Skip to content

Instantly share code, notes, and snippets.

@gaardhus
Last active August 16, 2024 13:12
Show Gist options
  • Save gaardhus/1b2f703c1c32e2cac1caa909c386b403 to your computer and use it in GitHub Desktop.
Save gaardhus/1b2f703c1c32e2cac1caa909c386b403 to your computer and use it in GitHub Desktop.
Added sync client as well
from __future__ import annotations
import inspect
from functools import wraps
from typing import List, Optional, Type
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import Session, SQLModel, create_engine, select
from sqlmodel.ext.asyncio.session import AsyncSession
class BaseClient:
    """Common base for the generated clients; owns a synchronous engine.

    Args:
        connection_string: SQLAlchemy database URL passed to ``create_engine``.
    """

    def __init__(self, connection_string: str):
        sync_engine = create_engine(connection_string)
        self.engine = sync_engine
        # Per-instance model registry. Nothing in this module appends to it;
        # NOTE(review): confirm whether callers rely on it.
        self.models: list[Type[SQLModel]] = list()
class Client(BaseClient):
    """Synchronous client.

    Per-model CRUD methods (``get_<model>``, ``create_<model>``, ...) are
    attached to this class by ``create_sync_methods``.
    """

    def create_db_and_tables(self):
        """Create every table registered on ``SQLModel.metadata`` via ``self.engine``."""
        metadata = SQLModel.metadata
        metadata.create_all(self.engine)
class AsyncClient(BaseClient):
    """Asynchronous client backed by an ``AsyncEngine``.

    Per-model CRUD coroutines (``get_<model>``, ``create_<model>``, ...) are
    attached to this class by ``create_async_methods``.

    Args:
        connection_string: database URL; used for both the inherited sync
            engine and the async engine.
    """

    def __init__(self, connection_string: str):
        # The base class still builds a sync engine from the same URL.
        # NOTE(review): with an async-driver URL that engine is presumably
        # unusable for sync work — confirm nothing depends on it.
        super().__init__(connection_string)
        self.async_engine = create_async_engine(connection_string)

    async def create_db_and_tables(self):
        """Create every table on ``SQLModel.metadata`` inside one async transaction."""
        async with self.async_engine.begin() as connection:
            # metadata.create_all is sync-only; run_sync bridges it onto the
            # async connection.
            await connection.run_sync(SQLModel.metadata.create_all)
def create_method_with_model_signature(func, model: Type[SQLModel]):
    """Wrap ``func`` so introspection shows ``model``'s fields as its parameters.

    The generated CRUD methods accept ``**kwargs``. For those, the returned
    wrapper's ``__signature__`` lists ``self`` followed by every model field as
    a keyword-only parameter. Required fields get no default; optional fields
    get the field's default. This is only metadata for ``inspect.signature`` /
    ``help()``: calls are forwarded to ``func`` unchanged.

    A function that does not accept ``**kwargs`` keeps its own signature.
    Advertising fields it cannot receive would be wrong.

    Coroutine functions get an ``async`` wrapper so that
    ``inspect.iscoroutinefunction`` still reports them as coroutines.

    Args:
        func: the method implementation; its first parameter is ``self``.
        model: the SQLModel class whose fields describe the keyword arguments.

    Returns:
        The wrapper function, with ``__signature__`` set when applicable.
    """
    if inspect.iscoroutinefunction(func):

        @wraps(func)
        async def wrapper(self, *args, **kwargs):
            return await func(self, *args, **kwargs)

    else:

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            return func(self, *args, **kwargs)

    sig = inspect.signature(func)
    accepts_var_keyword = any(
        param.kind is inspect.Parameter.VAR_KEYWORD
        for param in sig.parameters.values()
    )
    if not accepts_var_keyword:
        # inspect.signature follows __wrapped__ (set by @wraps) to func's own
        # signature, which is the truthful one here.
        return wrapper

    # Pydantic v2 exposes ``model_fields``; ``__fields__`` is its deprecated
    # alias there and the only name available on older versions.
    model_fields = getattr(model, "model_fields", None)
    if model_fields is None:
        model_fields = model.__fields__

    new_params = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY)]
    for field_name, field in model_fields.items():
        # Parameter.empty means "no default", i.e. a required keyword.
        default = inspect.Parameter.empty if field.is_required() else field.default
        new_params.append(
            inspect.Parameter(
                field_name,
                inspect.Parameter.KEYWORD_ONLY,
                annotation=field.annotation,
                default=default,
            )
        )

    # Keeps func's return annotation; only the parameter list is replaced.
    wrapper.__signature__ = sig.replace(parameters=new_params)
    return wrapper
def generate_async_client_class(db_models: List[Type[SQLModel]]) -> Type[AsyncClient]:
    """Attach async CRUD methods for each model and return ``AsyncClient``.

    The methods are set directly on the module-level ``AsyncClient`` class; no
    subclass is created. Repeated calls therefore accumulate methods on the
    same class.

    Args:
        db_models: SQLModel classes to generate methods for.

    Returns:
        The (mutated) ``AsyncClient`` class itself.
    """
    client_cls = AsyncClient
    for db_model in db_models:
        create_async_methods(client_cls, db_model)
    return client_cls
def generate_sync_client_class(db_models: List[Type[SQLModel]]) -> Type[Client]:
    """Attach synchronous CRUD methods for each model and return ``Client``.

    The methods are set directly on the module-level ``Client`` class; no
    subclass is created. Repeated calls therefore accumulate methods on the
    same class.

    Args:
        db_models: SQLModel classes to generate methods for.

    Returns:
        The (mutated) ``Client`` class itself.
    """
    client_cls = Client
    for db_model in db_models:
        create_sync_methods(client_cls, db_model)
    return client_cls
def create_expression_from_kwargs(model, **kwargs):
    """Yield one filter expression per keyword argument, for ``select().where``.

    For each ``key=value`` pair, the column attribute ``key`` of ``model`` is
    combined with ``value`` as follows:

    * ``(operand, op)`` tuple -> ``op(column, operand)``, e.g. ``(5, operator.gt)``.
      Only the first two tuple elements are used.
    * list / set / frozenset  -> ``column.in_(values)``.
    * anything else           -> ``column == value``.

    Raises:
        KeyError: if ``model`` has no attribute named ``key``.
    """
    for key, value in kwargs.items():
        # getattr follows the MRO, so attributes inherited from a base class
        # resolve too. ``model.__dict__`` only sees the class's own namespace.
        try:
            column = getattr(model, key)
        except AttributeError:
            raise KeyError(key) from None
        if isinstance(value, tuple):
            yield value[1](column, value[0])
        elif isinstance(value, (list, set, frozenset)):
            yield column.in_(list(value))
        else:
            yield column == value
def create_async_methods(AsyncClient: Type[AsyncClient], model: Type[SQLModel]):
    """Attach async CRUD coroutines for ``model`` to the given client class.

    For a model named ``Hero`` this sets the following on ``AsyncClient``:
    ``create_hero``, ``get_hero``, ``get_all_heros``, ``update_hero`` and
    ``delete_hero``. Each generated coroutine opens its own ``AsyncSession`` on
    ``self.async_engine``. Keyword filters are converted to SQL conditions by
    ``create_expression_from_kwargs``.

    Args:
        AsyncClient: the client class to mutate.
        model: the SQLModel table class the methods operate on.
    """
    lower_model_name = model.__name__.lower()

    def split_update_kwargs(kwargs):
        """Split ``kwargs`` into primary-key filters and the values to write.

        Raises:
            TypeError: if no primary-key column is among ``kwargs``.
        """
        # Resolved at call time rather than at method-generation time.
        primary_keys = set(model.__table__.primary_key.columns.keys())
        filters = {k: v for k, v in kwargs.items() if k in primary_keys}
        if not filters:
            raise TypeError(
                f"update_{lower_model_name}() requires a primary-key argument "
                f"({', '.join(sorted(primary_keys))})"
            )
        values = {k: v for k, v in kwargs.items() if k not in primary_keys}
        return filters, values

    # Generate create method
    def create_create_method(model):
        async def create_method(self: AsyncClient, *args, **kwargs) -> model:
            """Insert a new row built from the arguments and return it refreshed."""
            async with AsyncSession(self.async_engine) as session:
                db_item = model(*args, **kwargs)
                session.add(db_item)
                await session.commit()
                # Reload server-generated values (e.g. the primary key).
                await session.refresh(db_item)
                return db_item

        return create_method_with_model_signature(create_method, model)

    setattr(AsyncClient, f"create_{lower_model_name}", create_create_method(model))

    # Generate get method
    def create_get_method(model):
        async def get_method(self: AsyncClient, **kwargs) -> Optional[model]:
            """Return the row matching ``kwargs``, or ``None`` if none matches.

            ``one_or_none`` raises ``MultipleResultsFound`` if several rows match.
            """
            async with AsyncSession(self.async_engine) as session:
                statement = select(model).where(
                    *create_expression_from_kwargs(model, **kwargs)
                )
                result = await session.exec(statement)
                return result.one_or_none()

        return create_method_with_model_signature(get_method, model)

    setattr(AsyncClient, f"get_{lower_model_name}", create_get_method(model))

    # Generate get_all method
    def create_get_all_method(model):
        async def get_all_method(self: AsyncClient) -> List[model]:
            """Return every row of the table."""
            async with AsyncSession(self.async_engine) as session:
                result = await session.exec(select(model))
                return result.all()

        return create_method_with_model_signature(get_all_method, model)

    setattr(AsyncClient, f"get_all_{lower_model_name}s", create_get_all_method(model))

    # Generate update method
    def create_update_method(model):
        async def update_method(self: AsyncClient, **kwargs) -> Optional[model]:
            """Update the row identified by its primary key.

            Primary-key kwargs select the row; every other kwarg is written to
            it. Returns the refreshed row, or ``None`` if no row matched.
            """
            # Filtering and writing with the same kwargs (the old behaviour)
            # could only ever re-write values the row already had.
            filters, values = split_update_kwargs(kwargs)
            async with AsyncSession(self.async_engine) as session:
                statement = select(model).where(
                    *create_expression_from_kwargs(model, **filters)
                )
                result = await session.exec(statement)
                db_item = result.one_or_none()
                if db_item is None:
                    return None
                for key, value in values.items():
                    setattr(db_item, key, value)
                await session.commit()
                await session.refresh(db_item)
                return db_item

        return create_method_with_model_signature(update_method, model)

    setattr(AsyncClient, f"update_{lower_model_name}", create_update_method(model))

    # Generate delete method
    def create_delete_method(model):
        async def delete_method(self: AsyncClient, **kwargs) -> bool:
            """Delete the row matching ``kwargs``; return whether a row was deleted.

            Raises:
                TypeError: if no filter is given. Without one, an arbitrary
                    single row would be deleted.
            """
            if not kwargs:
                raise TypeError(
                    f"delete_{lower_model_name}() requires at least one filter argument"
                )
            async with AsyncSession(self.async_engine) as session:
                statement = select(model).where(
                    *create_expression_from_kwargs(model, **kwargs)
                )
                result = await session.exec(statement)
                db_item = result.one_or_none()
                if db_item is None:
                    return False
                await session.delete(db_item)
                await session.commit()
                return True

        return create_method_with_model_signature(delete_method, model)

    setattr(AsyncClient, f"delete_{lower_model_name}", create_delete_method(model))
def create_sync_methods(Client: Type[Client], model: Type[SQLModel]):
    """Attach synchronous CRUD methods for ``model`` to the given client class.

    For a model named ``Hero`` this sets the following on ``Client``:
    ``get_hero``, ``get_all_heros``, ``create_hero``, ``update_hero`` and
    ``delete_hero``. Each generated method opens its own ``Session`` on
    ``self.engine``. Keyword filters are converted to SQL conditions by
    ``create_expression_from_kwargs``.

    Args:
        Client: the client class to mutate.
        model: the SQLModel table class the methods operate on.
    """
    lower_model_name = model.__name__.lower()
    # Sentinel so an explicit ``delete_<model>(id=None)`` still filters on
    # ``id IS NULL`` instead of being treated as "no id given".
    _missing = object()

    def split_update_kwargs(kwargs):
        """Split ``kwargs`` into primary-key filters and the values to write.

        Raises:
            TypeError: if no primary-key column is among ``kwargs``.
        """
        # Resolved at call time rather than at method-generation time.
        primary_keys = set(model.__table__.primary_key.columns.keys())
        filters = {k: v for k, v in kwargs.items() if k in primary_keys}
        if not filters:
            raise TypeError(
                f"update_{lower_model_name}() requires a primary-key argument "
                f"({', '.join(sorted(primary_keys))})"
            )
        values = {k: v for k, v in kwargs.items() if k not in primary_keys}
        return filters, values

    # Generate get method
    def create_get_method(model):
        def get_method(self: Client, **kwargs) -> Optional[model]:
            """Return the row matching ``kwargs``, or ``None`` if none matches.

            ``one_or_none`` raises ``MultipleResultsFound`` if several rows match.
            """
            with Session(self.engine) as session:
                statement = select(model).where(
                    *create_expression_from_kwargs(model, **kwargs)
                )
                return session.exec(statement).one_or_none()

        return create_method_with_model_signature(get_method, model)

    setattr(Client, f"get_{lower_model_name}", create_get_method(model))

    # Generate get_all method
    def create_get_all_method(model):
        def get_all_method(self: Client) -> List[model]:
            """Return every row of the table."""
            with Session(self.engine) as session:
                return session.exec(select(model)).all()

        return create_method_with_model_signature(get_all_method, model)

    setattr(Client, f"get_all_{lower_model_name}s", create_get_all_method(model))

    # Generate create method
    def create_create_method(model):
        def create_method(self: Client, **kwargs) -> model:
            """Insert a new row built from ``kwargs`` and return it refreshed."""
            with Session(self.engine) as session:
                db_item = model(**kwargs)
                session.add(db_item)
                session.commit()
                # Reload server-generated values (e.g. the primary key).
                session.refresh(db_item)
                return db_item

        return create_method_with_model_signature(create_method, model)

    setattr(Client, f"create_{lower_model_name}", create_create_method(model))

    # Generate update method
    def create_update_method(model):
        def update_method(self: Client, **kwargs) -> Optional[model]:
            """Update the row identified by its primary key.

            Primary-key kwargs select the row; every other kwarg is written to
            it. Returns the refreshed row, or ``None`` if no row matched.
            """
            # Filtering and writing with the same kwargs (the old behaviour)
            # could only ever re-write values the row already had.
            filters, values = split_update_kwargs(kwargs)
            with Session(self.engine) as session:
                statement = select(model).where(
                    *create_expression_from_kwargs(model, **filters)
                )
                db_item = session.exec(statement).one_or_none()
                if db_item is None:
                    return None
                for key, value in values.items():
                    setattr(db_item, key, value)
                session.commit()
                session.refresh(db_item)
                return db_item

        return create_method_with_model_signature(update_method, model)

    setattr(Client, f"update_{lower_model_name}", create_update_method(model))

    # Generate delete method
    def create_delete_method(model):
        def delete_method(self: Client, id: int = _missing, **kwargs) -> bool:
            """Delete the matching row; return whether a row was deleted.

            ``id`` (positional or keyword, as before) filters on ``model.id``.
            Any further kwargs are extra filters, matching the async client.

            Raises:
                TypeError: if neither ``id`` nor any other filter is given.
            """
            criteria = dict(kwargs)
            if id is not _missing:
                criteria["id"] = id
            if not criteria:
                raise TypeError(
                    f"delete_{lower_model_name}() requires 'id' or another filter argument"
                )
            with Session(self.engine) as session:
                statement = select(model).where(
                    *create_expression_from_kwargs(model, **criteria)
                )
                db_item = session.exec(statement).one_or_none()
                if db_item is None:
                    return False
                session.delete(db_item)
                session.commit()
                return True

        return create_method_with_model_signature(delete_method, model)

    setattr(Client, f"delete_{lower_model_name}", create_delete_method(model))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment