Skip to content

Instantly share code, notes, and snippets.

@xyb
Forked from ddahan/shell.py
Created September 30, 2024 04:29
Show Gist options
  • Save xyb/93df52b6af3c7f1ac0ce4032f248e9c2 to your computer and use it in GitHub Desktop.
Save xyb/93df52b6af3c7f1ac0ce4032f248e9c2 to your computer and use it in GitHub Desktop.
FastAPI context-aware shell with model discovery and other auto-imports
# Requires Python 3.12+ (for the `type` alias statement), plus FastAPI, SQLModel, IPython and rich
import datetime
import hashlib
import importlib
import inspect
import json
import logging
import os
import pkgutil
import random
import re
import shutil
import sys
from collections.abc import Callable
from pathlib import Path
from typing import Any
from fastapi import FastAPI
from IPython import embed # type: ignore
from rich import box
from rich import print as rprint
from rich.table import Table
from sqlmodel import SQLModel
from app.core.config import get_settings # Change this to get your own settings
from app.core.database import get_session # Change this to get your own db session
# Namespace mapping handed to the IPython shell: variable name -> preloaded object.
type ImportDict = dict[str, Any]
# Dotted import path of the package whose modules are scanned for SQLModel classes.
MODELS_PATH = "app.models"
def detect_models(models_path: str | None = None) -> dict[str, type[SQLModel]]:
    """
    Detect all SQLModel classes in the direct submodules of a package.

    Every direct, non-package submodule of ``models_path`` is imported and
    scanned. ``models_path`` defaults to ``MODELS_PATH``, read at call time.
    Sub-packages are skipped, and the package's own ``__init__`` is not
    scanned. Any class found on a scanned module that subclasses SQLModel is
    collected, except SQLModel itself. This includes classes merely imported
    into that module, not only those defined there.

    Returns a dict with model names as keys and model classes as values. When
    two modules expose classes with the same name, the one seen last wins.

    Raises whatever ``importlib.import_module`` raises if the package or one
    of its modules cannot be imported. Raises AttributeError if
    ``models_path`` names a plain module rather than a package.
    """
    if models_path is None:
        # Looked up at call time so reassigning MODELS_PATH still takes effect.
        models_path = MODELS_PATH
    models: dict[str, type[SQLModel]] = {}
    package = importlib.import_module(models_path)
    # Iterate the package's __path__ rather than Path(__file__).parent.
    # For a namespace package (no __init__.py) __file__ is None or absent, and
    # Path(str(None)).parent resolves to the current working directory.
    # __path__ also lists every directory of a multi-portion namespace package.
    for module_info in pkgutil.iter_modules(package.__path__):
        if module_info.ispkg:
            continue
        module = importlib.import_module(f"{models_path}.{module_info.name}")
        for name, obj in inspect.getmembers(module, inspect.isclass):
            if issubclass(obj, SQLModel) and obj is not SQLModel:
                models[name] = obj
    return models
def shell():
    """
    Run a context-aware user interactive shell (similar to Django shell_plus).

    Preloads the following into the IPython namespace:
    - a FastAPI app;
    - a database session taken from the ``get_session()`` generator;
    - the settings from ``get_settings()``;
    - a few standard-library modules;
    - every SQLModel class returned by ``detect_models()``.

    It prints these as tables and then starts IPython with them as
    ``user_ns``. The session generator is closed once IPython exits, even if
    setup or the shell raises.
    """

    def _print_dict(
        modules_dict: ImportDict,
        title: str,
        show_method: Callable[[Any], str] = lambda _: "",
    ) -> None:
        """Print modules_dict as a rich table titled "<Title> imports"; skip if empty.

        Rows are sorted by name; the second column is ``show_method(obj)``.
        """
        if not modules_dict:
            return
        table = Table(
            # Upper-case only the first character: str.capitalize() would also
            # lower-case the rest, turning "SQLModel" into "Sqlmodel".
            title=f"{title[:1].upper()}{title[1:]} imports",
            title_justify="left",
            title_style="bold green underline",
            show_header=False,
            box=box.SIMPLE,
            row_styles=["cyan"],
        )
        for name in sorted(modules_dict):
            table.add_row(name, show_method(modules_dict[name]))
        rprint(table)

    # NOTE(review): this is a fresh, empty FastAPI instance, not the project's
    # own application object. Import your app here if you need its routes.
    app = FastAPI()
    # Keep a reference to the get_session() generator. With a bare
    # next(get_session()), the generator is dropped right away. CPython then
    # finalizes it immediately, running whatever cleanup follows its `yield`
    # before the shell even starts. Assumes get_session() returns a generator,
    # as the original comment states.
    session_source = get_session()
    session = next(session_source)
    try:
        main_imports: ImportDict = {
            "app": app,
            "session": session,
            "settings": get_settings(),
        }
        python_imports: ImportDict = {
            "os": os,
            "sys": sys,
            "logging": logging,
            "json": json,
            "datetime": datetime,
            "random": random,
            "re": re,
            "hashlib": hashlib,
            "shutil": shutil,
        }
        models_imports: ImportDict = detect_models()

        # Display imported elements
        print("Starting IPython shell with these preloaded variables...\n")
        _print_dict(
            main_imports,
            "main",
            show_method=lambda obj: f"{obj.__class__.__module__}.{obj.__class__.__name__}",
        )
        _print_dict(
            python_imports,
            "python built-in",
            # NOTE(review): for module objects this is the module of the module
            # type itself, so every row reads "builtins".
            show_method=lambda obj: f"{obj.__class__.__module__}",
        )
        _print_dict(models_imports, "SQLModel", show_method=lambda obj: f"{obj.__module__}")

        # Start IPython shell with preloaded variables; on a name clash the
        # right-hand dict of the union wins.
        all_imports = main_imports | python_imports | models_imports
        embed(user_ns=all_imports, colors="Neutral")
    finally:
        # Deterministically run get_session()'s post-yield cleanup now that the
        # shell is done with the session.
        session_source.close()
if __name__ == "__main__":
    # Launch the interactive shell when this file is executed as a script.
    shell()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment