|
from __future__ import annotations

import decimal
import inspect
import logging
import os
import signal
import sys
import threading
import time
from contextlib import contextmanager
from dataclasses import (
    dataclass,
    field,
)
from datetime import datetime
from functools import lru_cache
from typing import (
    Any,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Tuple,
)

import pydantic
import typer
import uvicorn as uvicorn
import yaml
from fastapi import (
    BackgroundTasks,
    Depends,
    FastAPI,
    HTTPException,
    Response,
    status,
)
from fastapi.responses import PlainTextResponse
from gunicorn.glogging import Logger
from loguru import logger
from pydantic import BaseModel
from pydantic_sqlalchemy import sqlalchemy_to_pydantic
from sqlalchemy import (
    ARRAY,
    DECIMAL,
    TEXT,
    TIMESTAMP,
    BigInteger,
    Boolean,
    CheckConstraint,
    Column,
    Date,
    DateTime,
    Enum,
    Float,
    ForeignKey,
    Index,
    Integer,
    Numeric,
    PrimaryKeyConstraint,
    String,
    Table,
    Text,
    UniqueConstraint,
    and_,
    create_engine,
    engine,
    event,
    func,
    or_,
)
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import (
    Session,
    registry,
    relationship,
    sessionmaker,
)
from sqlalchemy.schema import Index
from starlette.middleware.cors import CORSMiddleware
from uvicorn.workers import UvicornWorker

import settings
|
|
|
# from cmath import log |
|
|
|
|
|
class ReloaderThread(threading.Thread):
    """Daemon thread that interrupts this process once a worker is no longer alive.

    Every ``sleep_interval`` seconds it checks ``worker.alive``; whenever that
    flag is false it sends ``SIGINT`` to the current process.

    Args:
        worker: Object exposing a boolean ``alive`` attribute (the owning
            uvicorn worker in this module).
        sleep_interval: Seconds to sleep between checks.
    """

    def __init__(self, worker: UvicornWorker, sleep_interval: float = 1.0):
        # ``daemon=True`` replaces ``Thread.setDaemon(True)``, which is
        # deprecated since Python 3.10. A daemon thread does not keep the
        # interpreter alive at shutdown.
        super().__init__(daemon=True)
        self._worker = worker
        self._interval = sleep_interval

    def run(self) -> None:
        while True:
            if not self._worker.alive:
                # Interrupt our own process; presumably the server running in
                # the worker treats SIGINT as a shutdown request.
                # NOTE: the loop keeps running, so SIGINT is re-sent every
                # interval for as long as the worker stays not-alive.
                os.kill(os.getpid(), signal.SIGINT)
            time.sleep(self._interval)
|
|
|
|
|
class RestartableUvicornWorker(UvicornWorker):
    """Gunicorn worker running uvicorn that also honours gunicorn's reload mode.

    Configures uvicorn to use the ``uvloop`` event loop and the ``httptools``
    HTTP implementation. When ``cfg.reload`` is enabled, a ``ReloaderThread``
    is started that sends SIGINT to the process once ``self.alive`` is false.
    """

    CONFIG_KWARGS = {
        "loop": "uvloop",
        "http": "httptools",
        # Optionally add "log_config": <dict loaded from logging.yaml> here to
        # configure uvicorn's logging (load the YAML inside a ``with`` block).
    }

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Annotations describe each individual arg/kwarg value; the previous
        # ``List[Any]`` / ``Dict[str, Any]`` wrongly claimed every argument
        # was itself a list / dict.
        super().__init__(*args, **kwargs)
        self._reloader_thread = ReloaderThread(self)

    def run(self) -> None:
        # The reloader thread is only needed when gunicorn runs with --reload.
        if self.cfg.reload:
            self._reloader_thread.start()
        super().run()
|
|
|
|
|
class InterceptHandler(logging.Handler):
    """Forward standard-library ``logging`` records to loguru.

    Adapted from the loguru documentation recipe:
    https://loguru.readthedocs.io/en/stable/overview.html#entirely-compatible-with-standard-logging
    """

    def emit(self, record: logging.LogRecord) -> None:
        # Use the matching loguru level name when loguru knows it; otherwise
        # fall back to the numeric stdlib level.
        try:
            level = logger.level(record.levelname).name
        except ValueError:
            level = record.levelno

        # Walk out of this method and every frame belonging to the ``logging``
        # package so loguru attributes the message to the code that actually
        # issued the log call. ``depth`` counts frames above ``emit`` (loguru's
        # depth 0 is the frame calling ``.log``, i.e. this method).
        # ``inspect.currentframe()`` is used instead of
        # ``logging.currentframe()`` because the latter's starting frame has
        # changed between Python versions.
        frame, depth = inspect.currentframe(), 0
        while frame is not None and (
            depth == 0 or frame.f_code.co_filename == logging.__file__
        ):
            frame = frame.f_back
            depth += 1

        logger.opt(depth=depth, exception=record.exc_info).log(
            level, record.getMessage()
        )
|
|
|
|
|
class GunicornLogger(Logger):
    """Gunicorn logger class that sends gunicorn's error and access logs to loguru.

    NOTE(review): ``setup`` does not call gunicorn's own ``Logger.setup``, so
    any handlers/levels that base implementation derives from ``cfg`` are not
    applied — confirm this is intended.
    """

    def setup(self, cfg) -> None:
        # One InterceptHandler instance is shared by both gunicorn loggers.
        intercept = InterceptHandler()
        # InterceptHandler.emit forwards ``record.getMessage()`` and never calls
        # ``self.format``, so this formatter is attached but not applied.
        intercept.setFormatter(
            logging.Formatter("%(asctime)s %(name)-12s %(levelname)-8s %(message)s")
        )

        # Same order as before: error log first, then access log; each gets
        # the handler and then the configured level.
        for gunicorn_log in (self.error_log, self.access_log):
            gunicorn_log.addHandler(intercept)
            gunicorn_log.setLevel(settings.LOG_LEVEL)
|
|
|
# Replace loguru's sinks with a single stdout sink at the configured level.
# This runs at import time so it is in place before gunicorn starts logging.
logger.configure(
    handlers=[
        dict(sink=sys.stdout, level=settings.LOG_LEVEL),
    ]
)
|
|
|
|
|
@lru_cache()
def get_engine() -> engine.Engine:
    """Create the SQLAlchemy engine on first call and return the cached one afterwards.

    ``echo=True`` makes SQLAlchemy log every SQL statement it emits, and
    ``pool_pre_ping=True`` checks pooled connections before handing them out.
    """
    engine_options = dict(echo=True, pool_pre_ping=True)
    # SQLite-only option, intentionally disabled:
    # connect_args={"check_same_thread": False}
    return create_engine(settings.SQLALCHEMY_DATABASE_URL, **engine_options)
|
|
|
|
|
@lru_cache()
def _get_session_factory() -> sessionmaker:
    """Return the process-wide ``sessionmaker`` bound to ``get_engine()``.

    The factory holds only configuration, so it is built once and reused
    instead of being re-created for every request.
    """
    return sessionmaker(
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
        bind=get_engine(),
    )


def get_db() -> Generator[Session, None, None]:
    """Yield a session for one unit of work (used via ``Depends(get_db)``).

    When the consumer resumes the generator normally the session is
    committed; if an exception is thrown into it the session is rolled back
    and the exception re-raised. The session is always closed afterwards.
    """
    # Explicit type because sessionmaker.__call__ stub is Any
    session: Session = _get_session_factory()()
    try:
        yield session
        session.commit()
    except BaseException:
        # Explicit form of the previous bare ``except:``: roll back on any
        # interruption, including GeneratorExit / cancellation, then re-raise.
        session.rollback()
        raise
    finally:
        session.close()
|
|
|
|
|
# Central SQLAlchemy registry. Classes decorated with ``@mapper_registry.mapped``
# register their tables on ``mapper_registry.metadata``.
mapper_registry: registry = registry()
|
|
|
|
|
@dataclass
class SurrogatePK:
    """Dataclass mixin contributing an autoincrementing integer primary key ``id``."""

    # Tells SQLAlchemy to read column definitions from each field's
    # ``metadata["sa"]`` entry.
    __sa_dataclass_metadata_key__ = "sa"

    # Excluded from ``__init__`` and ``None`` until the database assigns a
    # key; annotated ``Optional[int]`` because ``None`` is the default.
    id: Optional[int] = field(
        init=False,
        default=None,
        metadata={"sa": Column(Integer, primary_key=True, autoincrement=True)},
    )
|
|
|
|
|
@dataclass
class TimeStampMixin:
    """Dataclass mixin contributing ``created_at`` / ``updated_at`` DateTime columns.

    Both values come from ``datetime.now`` (naive, server-local time), both on
    the Python side and as SQLAlchemy column defaults. ``updated_at`` is also
    refreshed by SQLAlchemy whenever the row is updated.
    """

    # Tells SQLAlchemy to read column definitions from each field's
    # ``metadata["sa"]`` entry.
    __sa_dataclass_metadata_key__ = "sa"

    created_at: datetime = field(
        default_factory=datetime.now,
        metadata={"sa": Column(DateTime, default=datetime.now)},
    )
    # ``onupdate`` uses the same clock as the insert default. It previously
    # used ``datetime.utcnow``, which mixed local and UTC naive timestamps in
    # the same column (``utcnow`` is also deprecated since Python 3.12).
    updated_at: datetime = field(
        default_factory=datetime.now,
        metadata={
            "sa": Column(DateTime, default=datetime.now, onupdate=datetime.now)
        },
    )
|
|
|
|
|
@mapper_registry.mapped
@dataclass
class User(SurrogatePK, TimeStampMixin):
    """Mapped ``user`` table: surrogate ``id``, timestamps, and two short text fields."""

    __tablename__ = "user"

    __sa_dataclass_metadata_key__ = "sa"

    # Both columns are String(50) and default to None, hence ``Optional[str]``.
    title: Optional[str] = field(default=None, metadata={"sa": Column(String(50))})
    description: Optional[str] = field(
        default=None, metadata={"sa": Column(String(50))}
    )
|
|
|
|
|
# Pydantic model generated from the ``User`` mapping; used as an API response model.
UserPyd = sqlalchemy_to_pydantic(User)

# Create any tables registered on ``mapper_registry`` that do not exist yet.
# NOTE(review): this runs at import time, so importing this module connects
# to the database configured in ``settings``.
mapper_registry.metadata.create_all(bind=get_engine())

# The ASGI application (debug mode is hard-coded on).
app = FastAPI(debug=True)
|
|
|
|
|
@app.exception_handler(Exception)
async def validation_exception_handler(request, exc):
    """Turn any unhandled exception into a generic plain-text HTTP 500.

    Despite its name, this handler is registered for the base ``Exception``
    class, not only validation errors. The exception and its traceback are
    logged at ERROR level (previously only ``str(exc)`` at DEBUG, which hid
    server errors at normal log levels). The client receives a fixed message
    so internal details are not exposed.
    """
    logger.opt(exception=exc).error(
        "Unhandled exception while handling {} {}", request.method, request.url.path
    )
    return PlainTextResponse("Something went wrong", status_code=500)
|
|
|
|
|
# Typer command group; executed by the ``__main__`` guard at the end of the module.
cli: typer.Typer = typer.Typer()
|
|
|
|
|
@cli.command()
def db_init_models():
    """Drop and recreate every table registered on the mapper registry.

    Destructive: all rows in those tables are lost.
    """
    # ``mapper_registry.metadata`` is the same MetaData a generated
    # declarative base would expose, so no throwaway base class is needed.
    metadata = mapper_registry.metadata
    db_engine = get_engine()
    metadata.drop_all(bind=db_engine)
    metadata.create_all(bind=db_engine)
    typer.echo("Done")
|
|
|
|
|
@cli.command()
def nothing(name: str):
    """Placeholder command: accepts a NAME argument, ignores it, and prints "Done"."""
    done_message = "Done"
    print(done_message)
|
|
|
|
|
@app.get("/items", response_model=List[UserPyd])
def read_items(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
    """Return a page of ``User`` rows.

    Args:
        skip: Number of rows to skip (SQL OFFSET); must be >= 0.
        limit: Maximum number of rows to return (SQL LIMIT); must be >= 0.
        db: Request-scoped session provided by ``get_db``.

    Raises:
        HTTPException: 422 if ``skip`` or ``limit`` is negative.
    """
    if skip < 0 or limit < 0:
        # Negative OFFSET/LIMIT must not reach the database: some backends
        # reject it (surfacing as a 500), and SQLite treats a negative LIMIT
        # as "no limit".
        raise HTTPException(
            status_code=422,
            detail="skip and limit must be non-negative",
        )
    return db.query(User).offset(skip).limit(limit).all()
|
|
|
|
|
@app.post("/users/", response_model=UserPyd, status_code=status.HTTP_201_CREATED)
@app.get(
    "/users/",
    response_model=UserPyd,
    status_code=status.HTTP_201_CREATED,
    deprecated=True,
)
def create_user(email: Optional[str] = None, db: Session = Depends(get_db)):
    """Insert a new ``User`` row and return it.

    Creating a resource is not a safe operation, so it is now exposed via
    POST. The original GET route is still registered for backward
    compatibility but is marked deprecated in the OpenAPI schema.

    NOTE(review): ``email`` is accepted but unused, and the title is a
    hard-coded placeholder — confirm the intended request payload.
    """
    user = User(title="sss")
    db.add(user)
    # Commit here (``get_db`` also commits after the endpoint returns) so a
    # commit failure is raised inside the endpoint rather than during
    # dependency teardown.
    db.commit()

    return user
|
|
|
|
|
if __name__ == "__main__":
    # Run as a script: dispatch command-line arguments to the Typer command group.
    cli()