Skip to content

Instantly share code, notes, and snippets.

@dmontagu
Last active August 16, 2024 07:44
Show Gist options
  • Save dmontagu/7cfcab8ec9f595c58c6ea49b73aea33d to your computer and use it in GitHub Desktop.
Save dmontagu/7cfcab8ec9f595c58c6ea49b73aea33d to your computer and use it in GitHub Desktop.
Secure OpenAPI (filter based on available scopes)
from typing import Any, Callable, Dict, List, Optional, Sequence, Set

from fastapi import Depends, FastAPI, routing
from fastapi.encoders import jsonable_encoder
from fastapi.openapi.models import OpenAPI
from fastapi.openapi.utils import get_openapi_path
from fastapi.utils import get_flat_models_from_routes, get_model_definitions
from pydantic.schema import get_model_name_map
from starlette.requests import Request
from starlette.routing import BaseRoute, Route
def use_secure_openapi(app: FastAPI, security_access_dependency: Callable[..., Dict[str, Set[str]]]) -> None:
    """Replace *app*'s OpenAPI endpoint with one filtered by the caller's access.

    The existing route registered at ``app.openapi_url`` is removed, and a new route is
    added at the same URL. The new route resolves *security_access_dependency* through
    FastAPI's dependency injection and serves the schema produced by
    ``_get_secure_openapi`` for the resulting access mapping. The replacement route is
    itself excluded from the schema.

    Does nothing when ``app.openapi_url`` is falsy (OpenAPI disabled).

    Args:
        app: The application to patch in place.
        security_access_dependency: Dependency returning a mapping of
            security-scheme name -> scopes held by the current caller.
    """
    openapi_url = app.openapi_url
    if not openapi_url:
        return

    # Rebuild the route list in place instead of calling ``remove`` while iterating
    # over it, which skips the element that follows each removed route.
    app.routes[:] = [
        route
        for route in app.routes
        if not (isinstance(route, Route) and route.path == openapi_url)
    ]

    def _secure_openapi_route(
        request: Request, security_access: Dict[str, Set[str]] = Depends(security_access_dependency)
    ) -> Dict[str, Any]:
        # Build (or fetch from the cache) the schema filtered for this caller's access.
        request_app: FastAPI = request.app
        return _get_secure_openapi(request_app, security_access)

    app.add_api_route(openapi_url, _secure_openapi_route, include_in_schema=False)
# Memoized filtered schemas, keyed by (id(app), access signature).
# NOTE(review): entries are never evicted or invalidated. Routes added after the first
# request for a given key are not reflected, and a new app object that reuses a
# garbage-collected app's id() would hit a stale entry.
_secure_openapi_cache: Dict[Any, Dict[str, Any]] = {}


def _get_secure_openapi(app: FastAPI, security_access: Dict[str, Set[str]]) -> Dict[str, Any]:
    """Return the OpenAPI schema for *app* filtered by *security_access*, memoized.

    Args:
        app: Application whose metadata and routes are documented.
        security_access: Mapping of security-scheme name -> scopes held by the caller.

    Returns:
        The schema dict. Equal access maps share one cached object, so callers should
        treat it as read-only.
    """
    # A frozenset of (name, scopes) pairs makes the key independent of the dict's
    # insertion order, so equal access maps always hit the same cache entry.
    access_signature = frozenset(
        (scheme_name, frozenset(scopes)) for scheme_name, scopes in security_access.items()
    )
    cache_key = (id(app), access_signature)
    cached = _secure_openapi_cache.get(cache_key)
    if cached is None:
        cached = _build_secure_openapi(
            title=app.title,
            version=app.version,
            openapi_version=app.openapi_version,
            description=app.description,
            routes=app.routes,
            openapi_prefix=app.openapi_prefix,
            security_access=security_access,
        )
        _secure_openapi_cache[cache_key] = cached
    return cached
def _build_secure_openapi(
    *,
    title: str,
    version: str,
    openapi_version: str = "3.0.2",
    description: Optional[str] = None,
    routes: Sequence[BaseRoute],
    openapi_prefix: str = "",
    security_access: Dict[str, Set[str]],
) -> Dict:
    """Build an OpenAPI document containing only operations *security_access* permits.

    Mirrors FastAPI's own schema generation. Every operation whose ``security``
    requirements fail ``_is_access_allowed`` is dropped, and component schemas are
    collected only from routes with at least one visible operation.

    Args:
        title: ``info.title``.
        version: ``info.version``.
        openapi_version: Value of the top-level ``openapi`` field.
        description: ``info.description``; omitted when falsy.
        routes: Application routes; only ``routing.APIRoute`` instances are documented.
        openapi_prefix: Prefix prepended to every documented path.
        security_access: Mapping of security-scheme name -> scopes held by the caller.

    Returns:
        The JSON-compatible OpenAPI document (aliases applied, ``None`` values dropped).
    """
    info: Dict[str, Any] = {"title": title, "version": version}
    if description:
        info["description"] = description
    output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
    components: Dict[str, Dict] = {}
    paths: Dict[str, Dict] = {}

    # Name models using *all* routes so that $ref names stay the same no matter which
    # subset of operations this caller can see.
    model_name_map = get_model_name_map(get_flat_models_from_routes(routes))

    visible_routes: List[BaseRoute] = []
    extra_definitions: Dict[str, Any] = {}
    for route in routes:
        if not isinstance(route, routing.APIRoute):
            continue
        result = get_openapi_path(route=route, model_name_map=model_name_map)
        if not result:
            continue
        path, security_schemes, path_definitions = result
        # ``path`` is an OpenAPI path item keyed by HTTP method ("get", "post", ...).
        # The ``security`` list lives on each operation, not on the path item, so the
        # check must be done per operation.
        allowed_operations = {
            method: operation
            for method, operation in path.items()
            if _is_access_allowed(security_access, operation.get("security", []))
        }
        if not allowed_operations:
            continue
        visible_routes.append(route)
        paths.setdefault(openapi_prefix + route.path_format, {}).update(allowed_operations)
        if security_schemes:
            components.setdefault("securitySchemes", {}).update(security_schemes)
        if path_definitions:
            extra_definitions.update(path_definitions)

    # Collect schemas only from the visible routes, so hidden endpoints do not expose
    # their request/response models through ``components.schemas``.
    definitions = get_model_definitions(
        flat_models=get_flat_models_from_routes(visible_routes),
        model_name_map=model_name_map,
    )
    definitions.update(extra_definitions)
    if definitions:
        components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
    if components:
        output["components"] = components
    output["paths"] = paths
    return jsonable_encoder(OpenAPI(**output), by_alias=True, include_none=False)
def _is_access_allowed(security_access: Dict[str, Set[str]], route_security: List[Dict[str, List[str]]]) -> bool:
for operation_security in route_security:
for name, scopes in operation_security.items():
accessible_scopes = security_access.get(name, set())
if any(scope not in accessible_scopes for scope in scopes):
return False
return True
from datetime import datetime, timedelta
from typing import Dict, List, Set
import jwt
from jwt import PyJWTError
from passlib.context import CryptContext
from pydantic import BaseModel, ValidationError
from starlette.status import HTTP_401_UNAUTHORIZED
from fastapi import Depends, FastAPI, HTTPException, Security
from fastapi.security import (
OAuth2PasswordBearer,
SecurityScopes,
)
from secure_openapi import use_secure_openapi
# To get a string like this, run:
# openssl rand -hex 32
# NOTE(review): a hard-coded signing key is only acceptable for a demo; load it from
# configuration or secret storage in a real deployment.
SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"
# NOTE(review): not referenced in the visible source; presumably used by a token
# endpoint defined elsewhere.
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# In-memory stand-in for a user table, keyed by username. Each value is passed as
# keyword arguments to UserInDB by get_user().
# The email literals previously held scraped "[email protected]" placeholders and are
# restored to real addresses here.
fake_users_db = {
    "johndoe": {
        "username": "johndoe",
        "full_name": "John Doe",
        "email": "johndoe@example.com",
        "hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW",
        "disabled": False,
    },
    "alice": {
        "username": "alice",
        "full_name": "Alice Chains",
        "email": "alicechains@example.com",
        "hashed_password": "$2b$12$gSvqqUPvlXP2tfVFaWK1Be7DlH.PKZbv5H8KnzzVgXXbVxpva.pFm",
        "disabled": True,
    },
}
# Access-token response model.
# NOTE(review): nothing in the visible source uses this model, and the "/token" URL
# configured on oauth2_scheme has no route here. Presumably the token endpoint lives
# elsewhere; confirm.
class Token(BaseModel):
    access_token: str
    token_type: str
# Claims extracted from a decoded access token. get_current_user and get_token_data
# build it from the JWT "sub" and "scopes" claims.
class TokenData(BaseModel):
    username: str = None  # the token's "sub" claim
    scopes: List[str] = []  # granted scopes; the callers pass [] when the claim is absent
# Public user profile. Only ``username`` is required; the other fields default to None.
class User(BaseModel):
    username: str
    email: str = None
    full_name: str = None
    disabled: bool = None  # when truthy, get_current_active_user rejects the user with HTTP 400
# User record as stored in the (fake) database, including the password hash.
class UserInDB(User):
    hashed_password: str  # checked against the plain password via verify_password()
# Application instance; the routes below are registered on it via decorators.
app = FastAPI()

# OAuth2 password flow: bearer tokens are obtained from "/token", and these are the
# scopes the flow advertises.
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="/token",
    scopes={
        "me": "Read information about the current user.",
        "items": "Read items.",
    },
)

# Password hashing context used by verify_password / get_password_hash.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password, hashed_password):
    """Return whether *plain_password* matches *hashed_password* under ``pwd_context``."""
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def get_password_hash(password):
    """Return the hash of *password* produced by the configured ``pwd_context``."""
    digest = pwd_context.hash(password)
    return digest
def get_user(db, username: str):
    """Look up *username* in *db* and return it as a ``UserInDB``, or ``None`` if absent."""
    try:
        record = db[username]
    except KeyError:
        return None
    return UserInDB(**record)
def authenticate_user(fake_db, username: str, password: str):
    """Return the stored user when *password* is correct for *username*, else ``False``.

    ``False`` is returned both for an unknown username and for a wrong password.
    """
    candidate = get_user(fake_db, username)
    if candidate and verify_password(password, candidate.hashed_password):
        return candidate
    return False
def create_access_token(*, data: dict, expires_delta: timedelta = None):
    """Encode *data* as a JWT signed with ``SECRET_KEY``/``ALGORITHM`` plus an ``exp`` claim.

    Args:
        data: Claims to embed. The dict is copied, so the caller's object is not modified.
        expires_delta: Token lifetime. ``None`` selects the 15-minute default. Any other
            value, including ``timedelta(0)``, is honored as given instead of silently
            falling back to the default.

    Returns:
        The value returned by ``jwt.encode``.
    """
    from datetime import timezone

    to_encode = data.copy()
    if expires_delta is None:
        expires_delta = timedelta(minutes=15)
    # Use an aware UTC "now": ``datetime.utcnow()`` returns a naive value and is
    # deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def get_current_user(
    security_scopes: SecurityScopes, token: str = Depends(oauth2_scheme)
):
    """Resolve the bearer *token* to a stored user and enforce the required scopes.

    *security_scopes* holds the scopes required by the ``Security(...)`` declarations
    that depend on this function. Every one of them must appear in the token's
    ``scopes`` claim.

    Raises:
        HTTPException: 401 "Could not validate credentials" when the token cannot be
            decoded, has no ``sub`` claim, fails ``TokenData`` validation, or names an
            unknown user. 401 "Not enough permissions" when a required scope is missing.
            Both carry a ``WWW-Authenticate`` challenge that names the required scopes,
            if there are any.
    """
    authenticate_value = (
        f'Bearer scope="{security_scopes.scope_str}"' if security_scopes.scopes else "Bearer"
    )

    def unauthorized(detail: str) -> HTTPException:
        # Fresh headers dict per exception, matching the challenge computed above.
        return HTTPException(
            status_code=HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": authenticate_value},
        )

    credentials_exception = unauthorized("Could not validate credentials")
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        subject: str = claims.get("sub")
        if subject is None:
            raise credentials_exception
        token_data = TokenData(scopes=claims.get("scopes", []), username=subject)
    except (PyJWTError, ValidationError):
        raise credentials_exception

    user = get_user(fake_users_db, username=token_data.username)
    if user is None:
        raise credentials_exception
    if any(required not in token_data.scopes for required in security_scopes.scopes):
        raise unauthorized("Not enough permissions")
    return user
async def get_current_active_user(
    current_user: User = Security(get_current_user, scopes=["me"])
):
    """Return the authenticated user, rejecting accounts flagged as disabled.

    Raises:
        HTTPException: 400 "Inactive user" when ``current_user.disabled`` is truthy.
    """
    if not current_user.disabled:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
@app.get("/users/me/items/")
async def read_own_items(
    current_user: User = Security(get_current_active_user, scopes=["items"])
):
    """List the items owned by the current user (a single hard-coded sample here).

    Requires the ``items`` scope in addition to whatever ``get_current_active_user``
    itself requires.
    """
    owned_item = {"item_id": "Foo", "owner": current_user.username}
    return [owned_item]
def get_token_data(token: str = Depends(oauth2_scheme)) -> TokenData:
    """Decode the bearer *token* into ``TokenData`` without enforcing any scopes.

    Unlike ``get_current_user``, this never looks the user up and always uses a plain
    ``Bearer`` challenge.

    Raises:
        HTTPException: 401 "Could not validate credentials" when the token cannot be
            decoded, has no ``sub`` claim, or fails ``TokenData`` validation.
    """
    credentials_exception = HTTPException(
        status_code=HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        subject: str = claims.get("sub")
        if subject is None:
            raise credentials_exception
        return TokenData(scopes=claims.get("scopes", []), username=subject)
    except (PyJWTError, ValidationError):
        raise credentials_exception
def get_security_access(token_data: TokenData = Depends(get_token_data)) -> Dict[str, Set[str]]:
    """Map the caller's token scopes onto this app's OAuth2 security scheme.

    FastAPI keys each operation's security requirement by the scheme's ``scheme_name``,
    which defaults to the class name "OAuth2PasswordBearer". Reading the name from
    ``oauth2_scheme`` keeps this mapping in sync with the schema even if a custom
    ``scheme_name`` is configured.

    NOTE(review): ``get_token_data`` depends on ``oauth2_scheme``, which is built without
    ``auto_error=False``. A request with no bearer token is therefore presumably rejected
    with 401 before this runs, so anonymous clients (e.g. a docs UI loading the schema)
    cannot fetch any schema. Confirm this is intended.
    """
    return {oauth2_scheme.scheme_name: set(token_data.scopes)}
# Swap the app's OpenAPI route for one whose output is filtered by the caller's scopes.
use_secure_openapi(app=app, security_access_dependency=get_security_access)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment