# FastAPI context-aware shell with model discovery and other auto-imports.
#
# Source: GitHub gist ddahan/ae496b168a96ff0f40edd952477114ba,
# created September 8, 2024 20:29 (it can be saved locally or opened in GitHub Desktop).
#
# Note: GitHub flags this file as containing bidirectional Unicode text that may be
# interpreted or compiled differently than it appears. Review it in an editor that
# reveals hidden Unicode characters.
# Requires Python 3.12+, IPython, rich
import datetime
import hashlib
import importlib
import inspect
import json
import logging
import os
import pkgutil
import random
import re
import shutil
import sys
from collections.abc import Callable
from contextlib import contextmanager
from pathlib import Path
from typing import Any

from fastapi import FastAPI
from IPython import embed  # type: ignore
from rich import box
from rich import print as rprint
from rich.table import Table
from sqlmodel import SQLModel

from app.core.config import get_settings  # Change this to get your own settings
from app.core.database import get_session  # Change this to get your own db session
type ImportDict = dict[str, Any] | |
MODELS_PATH = "app.models" | |
def detect_models() -> dict[str, type[SQLModel]]:
    """
    Detect all SQLModel classes exposed under MODELS_PATH.

    If MODELS_PATH is a package, every direct, non-package submodule is
    imported and scanned (sub-packages are skipped). If it is a plain module,
    that module itself is scanned. Every class found in a scanned module's
    namespace that subclasses SQLModel (excluding SQLModel itself) is
    collected, including classes a module merely imports from elsewhere.

    Returns:
        A dict mapping each class's attribute name in its module to the class.
        When two modules expose different classes under the same name, the
        one scanned last wins.
    """
    models: dict[str, type[SQLModel]] = {}

    package = importlib.import_module(MODELS_PATH)
    # Use the package's __path__ rather than the directory of its __file__:
    # namespace packages (no __init__.py) have __file__ set to None, and
    # str(None) would silently turn the scan into a scan of the current working
    # directory. __path__ also lists every directory of a split package.
    search_path = getattr(package, "__path__", None)

    if search_path is None:
        # MODELS_PATH names a single module (e.g. app/models.py): scan it directly.
        modules = (package,)
    else:
        # Lazily import each plain submodule just before it is inspected.
        modules = (
            importlib.import_module(f"{MODELS_PATH}.{module_info.name}")
            for module_info in pkgutil.iter_modules(search_path)
            if not module_info.ispkg
        )

    for module in modules:
        for name, obj in inspect.getmembers(module, inspect.isclass):
            if issubclass(obj, SQLModel) and obj is not SQLModel:
                models[name] = obj
    return models
def shell():
    """
    Run a context-aware user interactive shell (similar to Django shell_plus).

    The IPython namespace is preloaded with three groups of names:

    * main objects: ``app`` (a FastAPI instance), ``session`` (obtained from
      ``get_session()``) and ``settings`` (from ``get_settings()``);
    * a fixed set of standard-library modules;
    * every SQLModel class returned by ``detect_models()``.

    A table per group is printed before IPython starts. If a name appears in
    several groups, the later group wins (models override standard-library
    modules, which override the main objects).

    ``get_session()`` is treated as a generator that yields one session. It is
    wrapped with ``contextlib.contextmanager``, so when the shell exits the
    generator is resumed past its ``yield`` (or, if an exception escapes, has
    that exception thrown into it). Its teardown code therefore runs, instead
    of the generator being left suspended forever.
    """

    def _print_dict(
        modules_dict: ImportDict,
        title: str,
        show_method: Callable[[Any], str] = lambda _: "",
    ) -> None:
        """Print ``modules_dict`` as a two-column table: name, show_method(obj).

        Rows are sorted by name; nothing is printed for an empty dict.
        """
        if not modules_dict:
            return
        table = Table(
            # str.capitalize() would lowercase the rest of the title and render
            # "SQLModel" as "Sqlmodel"; only upper-case the first character.
            title=f"{title[:1].upper()}{title[1:]} imports",
            title_justify="left",
            title_style="bold green underline",
            show_header=False,
            box=box.SIMPLE,
            row_styles=["cyan"],
        )
        for name in sorted(modules_dict):
            table.add_row(name, show_method(modules_dict[name]))
        rprint(table)

    with contextmanager(get_session)() as session:
        # Build dicts of objects to be injected into the shell namespace
        main_imports: ImportDict = {
            # NOTE(review): this is a brand-new, empty FastAPI instance, not the
            # project's own application object -- import your app here if you
            # need its routes or state.
            "app": FastAPI(),
            "session": session,
            "settings": get_settings(),
        }
        python_imports: ImportDict = {
            "os": os,
            "sys": sys,
            "logging": logging,
            "json": json,
            "datetime": datetime,
            "random": random,
            "re": re,
            "hashlib": hashlib,
            "shutil": shutil,
        }
        models_imports: ImportDict = detect_models()

        # Display imported elements
        print("Starting IPython shell with these preloaded variables...\n")
        _print_dict(
            main_imports,
            "main",
            show_method=lambda obj: f"{obj.__class__.__module__}.{obj.__class__.__name__}",
        )
        _print_dict(
            python_imports,
            "python built-in",
            # NOTE: for module objects this is the module of the `module` type
            # itself, so every row renders as "builtins".
            show_method=lambda obj: f"{obj.__class__.__module__}",
        )
        _print_dict(models_imports, "SQLModel", show_method=lambda obj: f"{obj.__module__}")

        # Start IPython shell with preloaded variables; the session stays open
        # for as long as the shell runs.
        all_imports = main_imports | python_imports | models_imports
        embed(user_ns=all_imports, colors="Neutral")
if __name__ == "__main__":
    # Launch the interactive shell only when this file is executed directly,
    # never as a side effect of importing it.
    shell()
# To comment on the original gist, sign up for a free GitHub account
# or sign in to an existing one.