Last active
August 16, 2024 13:12
-
-
Save gaardhus/1b2f703c1c32e2cac1caa909c386b403 to your computer and use it in GitHub Desktop.
Added sync client as well
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import inspect | |
from functools import wraps | |
from typing import List, Optional, Type | |
from sqlalchemy.ext.asyncio import create_async_engine | |
from sqlmodel import Session, SQLModel, create_engine, select | |
from sqlmodel.ext.asyncio.session import AsyncSession | |
class BaseClient:
    """Common state shared by the sync and async clients."""

    def __init__(self, connection_string: str):
        # The model registry starts out empty; nothing in this class populates it.
        self.models: list[Type[SQLModel]] = []
        # A synchronous SQLAlchemy engine for the given database URL.
        engine = create_engine(connection_string)
        self.engine = engine
class Client(BaseClient):
    """Synchronous client; per-model CRUD methods are attached by ``create_sync_methods``."""

    def create_db_and_tables(self):
        """Create every table registered on ``SQLModel.metadata`` using ``self.engine``."""
        metadata = SQLModel.metadata
        metadata.create_all(self.engine)
class AsyncClient(BaseClient):
    """Asynchronous client; per-model CRUD coroutines are attached by ``create_async_methods``."""

    def __init__(self, connection_string: str):
        # NOTE(review): BaseClient also builds a *sync* engine from this same URL;
        # confirm that is intended when an async driver URL is passed.
        super().__init__(connection_string)
        # Engine used by the coroutine methods.
        self.async_engine = create_async_engine(connection_string)

    async def create_db_and_tables(self):
        """Create every table registered on ``SQLModel.metadata`` using ``self.async_engine``."""
        # MetaData.create_all is a synchronous API, so it is run on the async
        # connection through run_sync inside a transaction opened by begin().
        async with self.async_engine.begin() as connection:
            await connection.run_sync(SQLModel.metadata.create_all)
def create_method_with_model_signature(func, model: Type[SQLModel]):
    """Wrap ``func`` so its introspected signature lists ``model``'s fields.

    The wrapper forwards every call unchanged to ``func(self, *args, **kwargs)``.
    Only the ``__signature__`` metadata (what ``inspect.signature``, ``help()``
    and IDEs display) is rewritten to ``(self, /, *, <field>, <field>=default, ...)``.
    Required fields get no default; optional fields get their declared default.

    When ``func`` is a coroutine function the wrapper is one too, so
    ``inspect.iscoroutinefunction`` still reports the generated method as async.

    Args:
        func: Method implementation to wrap.
        model: Model class whose fields become keyword-only parameters. Each
            field object must provide ``is_required()``, ``annotation`` and
            ``default`` (the Pydantic v2 ``FieldInfo`` API).

    Returns:
        The wrapper function, carrying ``func``'s name and docstring via ``functools.wraps``.
    """
    if inspect.iscoroutinefunction(func):

        @wraps(func)
        async def wrapper(self, *args, **kwargs):
            return await func(self, *args, **kwargs)

    else:

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            return func(self, *args, **kwargs)

    sig = inspect.signature(func)

    # Pydantic v2 names the field mapping ``model_fields``; ``__fields__`` is its
    # deprecated alias, so prefer the former and fall back to the latter.
    model_fields = getattr(model, "model_fields", None)
    if model_fields is None:
        model_fields = model.__fields__

    new_params = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY)]
    for field_name, field in model_fields.items():
        if field.is_required():
            param = inspect.Parameter(
                field_name,
                inspect.Parameter.KEYWORD_ONLY,
                annotation=field.annotation,
            )
        else:
            param = inspect.Parameter(
                field_name,
                inspect.Parameter.KEYWORD_ONLY,
                annotation=field.annotation,
                default=field.default,
            )
        new_params.append(param)

    # inspect.signature() checks __signature__ before following __wrapped__,
    # so this replaces what wraps() would otherwise expose.
    wrapper.__signature__ = sig.replace(parameters=new_params)
    return wrapper
def generate_async_client_class(db_models: List[Type[SQLModel]]) -> Type[AsyncClient]:
    """Attach async CRUD methods for every model in ``db_models`` and return ``AsyncClient``.

    The module-level ``AsyncClient`` class itself is mutated in place; no new
    class is created, so repeated calls accumulate methods on the same class.
    """
    client_cls = AsyncClient
    for db_model in db_models:
        create_async_methods(client_cls, db_model)
    return client_cls
def generate_sync_client_class(db_models: List[Type[SQLModel]]) -> Type[Client]:
    """Attach synchronous CRUD methods for every model in ``db_models`` and return ``Client``.

    The module-level ``Client`` class itself is mutated in place; no new class
    is created, so repeated calls accumulate methods on the same class.
    """
    client_cls = Client
    for db_model in db_models:
        create_sync_methods(client_cls, db_model)
    return client_cls
def create_expression_from_kwargs(model, **kwargs):
    """Yield one filter expression per keyword argument, for use in ``select(...).where(*...)``.

    Each ``key=value`` pair becomes a condition on ``model``'s attribute ``key``:

    * ``(operand, op)`` tuple -> ``op(column, operand)``, e.g. ``age=(18, operator.ge)``;
    * ``list`` / ``set`` / ``frozenset`` -> ``column.in_(values)``;
    * anything else -> ``column == value``.

    Raises:
        KeyError: (on iteration) if ``model`` has no attribute named ``key``.
    """
    for key, value in kwargs.items():
        # getattr runs class-level descriptors and finds attributes inherited from
        # base classes; a raw ``model.__dict__[key]`` lookup misses inherited
        # attributes and returns descriptor objects rather than SQL expressions.
        try:
            column = getattr(model, key)
        except AttributeError:
            raise KeyError(key) from None

        if isinstance(value, tuple):
            # (operand, binary_operator) form, e.g. ("abc", operator.ne).
            yield value[1](column, value[0])
        elif isinstance(value, (list, set, frozenset)):
            yield column.in_(list(value))
        else:
            yield column == value
def create_async_methods(AsyncClient: Type[AsyncClient], model: Type[SQLModel]):
    """Attach async CRUD coroutine methods for ``model`` to the given client class.

    For a model class named ``Hero`` this sets ``create_hero``, ``get_hero``,
    ``get_all_heros``, ``update_hero`` and ``delete_hero`` on ``AsyncClient``.
    Every method opens its own ``AsyncSession`` on ``self.async_engine``; filter
    keywords are turned into SQL conditions by ``create_expression_from_kwargs``.

    Args:
        AsyncClient: Client class to mutate in place (this parameter shadows the
            module-level ``AsyncClient`` name).
        model: SQLModel table class the generated methods operate on.
    """
    lower_model_name = model.__name__.lower()

    def split_update_kwargs(kwargs):
        """Split update keywords into ``(filters, new_values)``.

        Primary-key keywords select the row; every other keyword is written to
        it. Filtering on *all* keywords could only match a row that already holds
        the new values, which makes the update a no-op. That fallback is kept
        only when no primary-key keyword is supplied.
        """
        # Assumes mapped attribute names equal column keys (the SQLModel default).
        primary_keys = {column.key for column in model.__table__.primary_key.columns}
        filters = {key: value for key, value in kwargs.items() if key in primary_keys}
        if not filters:
            return kwargs, kwargs
        new_values = {
            key: value for key, value in kwargs.items() if key not in primary_keys
        }
        return filters, new_values

    # Generate create method
    def create_create_method(model):
        async def create_method(self: AsyncClient, *args, **kwargs) -> model:
            """Insert a new row built from the arguments and return it."""
            async with AsyncSession(self.async_engine) as session:
                db_item = model(*args, **kwargs)
                session.add(db_item)
                await session.commit()
                # Reload the row so its attributes are populated before the session closes.
                await session.refresh(db_item)
                return db_item

        return create_method_with_model_signature(create_method, model)

    setattr(AsyncClient, f"create_{lower_model_name}", create_create_method(model))

    # Generate get method
    def create_get_method(model):
        async def get_method(self: AsyncClient, **kwargs) -> Optional[model]:
            """Return the single row matching the filters, or ``None`` if none matches.

            Raises ``MultipleResultsFound`` if more than one row matches.
            """
            async with AsyncSession(self.async_engine) as session:
                statement = select(model).where(
                    *create_expression_from_kwargs(model, **kwargs)
                )
                result = await session.exec(statement)
                return result.one_or_none()

        return create_method_with_model_signature(get_method, model)

    setattr(AsyncClient, f"get_{lower_model_name}", create_get_method(model))

    # Generate get_all method
    def create_get_all_method(model):
        async def get_all_method(self: AsyncClient) -> List[model]:
            """Return every row of the model's table."""
            async with AsyncSession(self.async_engine) as session:
                statement = select(model)
                result = await session.exec(statement)
                return result.all()

        # Not wrapped with the model signature: this method takes no field
        # arguments, so advertising them would be misleading.
        return get_all_method

    setattr(AsyncClient, f"get_all_{lower_model_name}s", create_get_all_method(model))

    # Generate update method
    def create_update_method(model):
        async def update_method(self: AsyncClient, **kwargs) -> Optional[model]:
            """Update the row identified by its primary-key keywords with the other keywords.

            Returns the refreshed row, or ``None`` if no row matches.
            """
            filters, new_values = split_update_kwargs(kwargs)
            async with AsyncSession(self.async_engine) as session:
                statement = select(model).where(
                    *create_expression_from_kwargs(model, **filters)
                )
                result = await session.exec(statement)
                db_item = result.one_or_none()
                if db_item is None:
                    return None
                for key, value in new_values.items():
                    setattr(db_item, key, value)
                await session.commit()
                await session.refresh(db_item)
                return db_item

        return create_method_with_model_signature(update_method, model)

    setattr(AsyncClient, f"update_{lower_model_name}", create_update_method(model))

    # Generate delete method
    def create_delete_method(model):
        async def delete_method(self: AsyncClient, **kwargs) -> bool:
            """Delete the single row matching the filters.

            Returns ``True`` if a row was deleted and ``False`` if none matched.
            """
            async with AsyncSession(self.async_engine) as session:
                statement = select(model).where(
                    *create_expression_from_kwargs(model, **kwargs)
                )
                result = await session.exec(statement)
                db_item = result.one_or_none()
                if db_item is None:
                    return False
                await session.delete(db_item)
                await session.commit()
                return True

        return create_method_with_model_signature(delete_method, model)

    setattr(AsyncClient, f"delete_{lower_model_name}", create_delete_method(model))
def create_sync_methods(Client: Type[Client], model: Type[SQLModel]):
    """Attach synchronous CRUD methods for ``model`` to the given client class.

    For a model class named ``Hero`` this sets ``get_hero``, ``get_all_heros``,
    ``create_hero``, ``update_hero`` and ``delete_hero`` on ``Client``. Every
    method opens its own ``Session`` on ``self.engine``; filter keywords are
    turned into SQL conditions by ``create_expression_from_kwargs``.

    Args:
        Client: Client class to mutate in place (this parameter shadows the
            module-level ``Client`` name).
        model: SQLModel table class the generated methods operate on.
    """
    lower_model_name = model.__name__.lower()
    # Sentinel so that an explicit ``delete_<model>(id=None)`` still filters on
    # ``id`` as before, rather than being treated as "no id given".
    missing = object()

    def split_update_kwargs(kwargs):
        """Split update keywords into ``(filters, new_values)``.

        Primary-key keywords select the row; every other keyword is written to
        it. Filtering on *all* keywords could only match a row that already holds
        the new values, which makes the update a no-op. That fallback is kept
        only when no primary-key keyword is supplied.
        """
        # Assumes mapped attribute names equal column keys (the SQLModel default).
        primary_keys = {column.key for column in model.__table__.primary_key.columns}
        filters = {key: value for key, value in kwargs.items() if key in primary_keys}
        if not filters:
            return kwargs, kwargs
        new_values = {
            key: value for key, value in kwargs.items() if key not in primary_keys
        }
        return filters, new_values

    # Generate get method
    def create_get_method(model):
        def get_method(self: Client, **kwargs) -> Optional[model]:
            """Return the single row matching the filters, or ``None`` if none matches.

            Raises ``MultipleResultsFound`` if more than one row matches.
            """
            with Session(self.engine) as session:
                statement = select(model).where(
                    *create_expression_from_kwargs(model, **kwargs)
                )
                result = session.exec(statement)
                return result.one_or_none()

        return create_method_with_model_signature(get_method, model)

    setattr(Client, f"get_{lower_model_name}", create_get_method(model))

    # Generate get_all method
    def create_get_all_method(model):
        def get_all_method(self: Client) -> List[model]:
            """Return every row of the model's table."""
            with Session(self.engine) as session:
                statement = select(model)
                return session.exec(statement).all()

        # Not wrapped with the model signature: this method takes no field
        # arguments, so advertising them would be misleading.
        return get_all_method

    setattr(Client, f"get_all_{lower_model_name}s", create_get_all_method(model))

    # Generate create method
    def create_create_method(model):
        def create_method(self: Client, **kwargs) -> model:
            """Insert a new row built from the keyword arguments and return it."""
            with Session(self.engine) as session:
                db_item = model(**kwargs)
                session.add(db_item)
                session.commit()
                # Reload the row so its attributes are populated before the session closes.
                session.refresh(db_item)
                return db_item

        return create_method_with_model_signature(create_method, model)

    setattr(Client, f"create_{lower_model_name}", create_create_method(model))

    # Generate update method
    def create_update_method(model):
        def update_method(self: Client, **kwargs) -> Optional[model]:
            """Update the row identified by its primary-key keywords with the other keywords.

            Returns the refreshed row, or ``None`` if no row matches.
            """
            filters, new_values = split_update_kwargs(kwargs)
            with Session(self.engine) as session:
                statement = select(model).where(
                    *create_expression_from_kwargs(model, **filters)
                )
                result = session.exec(statement)
                db_item = result.one_or_none()
                if db_item is None:
                    return None
                for key, value in new_values.items():
                    setattr(db_item, key, value)
                session.commit()
                session.refresh(db_item)
                return db_item

        return create_method_with_model_signature(update_method, model)

    setattr(Client, f"update_{lower_model_name}", create_update_method(model))

    # Generate delete method
    def create_delete_method(model):
        def delete_method(self: Client, id: int = missing, **kwargs) -> bool:
            """Delete the single row matching ``id`` and/or the filter keywords.

            ``id`` may still be passed positionally, as before. Returns ``True`` if
            a row was deleted and ``False`` if none matched.

            Raises:
                TypeError: if neither ``id`` nor any filter keyword is given.
            """
            filters = dict(kwargs)
            if id is not missing:
                filters["id"] = id
            if not filters:
                raise TypeError(
                    f"delete_{lower_model_name}() requires an 'id' or at least one filter keyword"
                )
            with Session(self.engine) as session:
                statement = select(model).where(
                    *create_expression_from_kwargs(model, **filters)
                )
                result = session.exec(statement)
                db_item = result.one_or_none()
                if db_item is None:
                    return False
                session.delete(db_item)
                session.commit()
                return True

        return create_method_with_model_signature(delete_method, model)

    setattr(Client, f"delete_{lower_model_name}", create_delete_method(model))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment