Last active
August 16, 2024 07:44
-
-
Save dmontagu/7cfcab8ec9f595c58c6ea49b73aea33d to your computer and use it in GitHub Desktop.
Secure OpenAPI (filter based on available scopes)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any, Callable, Dict, List, Sequence, Set | |
from pydantic.schema import get_model_name_map | |
from starlette.requests import Request | |
from starlette.routing import BaseRoute, Route | |
from fastapi import Depends, FastAPI, routing | |
from fastapi.encoders import jsonable_encoder | |
from fastapi.openapi.models import OpenAPI | |
from fastapi.openapi.utils import get_openapi_path | |
from fastapi.utils import get_flat_models_from_routes, get_model_definitions | |
def use_secure_openapi(app: FastAPI, security_access_dependency: Callable[..., Dict[str, Set[str]]]) -> None:
    """Swap *app*'s OpenAPI schema route for one filtered by the caller's access.

    Every plain ``Route`` registered at ``app.openapi_url`` is removed. A new API
    route is added at the same URL. That route resolves
    ``security_access_dependency`` and returns the schema produced by
    ``_get_secure_openapi`` for the resulting grants.

    Args:
        app: The application whose schema endpoint is replaced in place.
        security_access_dependency: A FastAPI dependency. It returns a mapping
            from security-scheme name to the set of scopes the current caller
            holds.

    Returns:
        None. If ``app.openapi_url`` is falsy (schema disabled), the app is
        left untouched.
    """
    openapi_url = app.openapi_url
    if not openapi_url:
        return

    # Iterate over a snapshot: removing entries from the list being iterated
    # would make the loop skip the route that follows each removed one.
    for route in list(app.routes):
        if isinstance(route, Route) and route.path == openapi_url:
            app.routes.remove(route)

    def _secure_openapi_route(
        request: Request, security_access: Dict[str, Set[str]] = Depends(security_access_dependency)
    ) -> Dict[str, Any]:
        # Read the app from the request so the handler always serves the app it
        # is mounted on.
        current_app: FastAPI = request.app
        return _get_secure_openapi(current_app, security_access)

    # include_in_schema=False keeps the schema endpoint out of its own output.
    app.add_api_route(openapi_url, _secure_openapi_route, include_in_schema=False)
# Memoized filtered schemas, keyed by (id(app), normalized access grants).
# One entry is added per distinct grant combination seen.
# NOTE(review): id() values can be reused after an object is garbage-collected.
# This assumes each app lives for the whole process.
_secure_openapi_cache: Dict[Any, Dict[str, Any]] = {}


def _get_secure_openapi(app: FastAPI, security_access: Dict[str, Set[str]]) -> Dict[str, Any]:
    """Return *app*'s OpenAPI schema filtered for *security_access*, memoized.

    Args:
        app: The application whose routes and metadata describe the schema.
        security_access: Mapping of security-scheme name to the scopes the
            caller holds.

    Returns:
        The JSON-compatible schema dict. The same object is returned for
        every caller whose grants normalize to the same key.
    """
    # Build an order-independent key. Schemes granted no scopes are dropped:
    # _is_access_allowed looks grants up with .get(name, set()), so an empty
    # grant is indistinguishable from an absent one.
    access_key = frozenset(
        (scheme_name, frozenset(scopes))
        for scheme_name, scopes in security_access.items()
        if scopes
    )
    cache_key = (id(app), access_key)
    schema = _secure_openapi_cache.get(cache_key)
    if schema is None:
        schema = _build_secure_openapi(
            title=app.title,
            version=app.version,
            openapi_version=app.openapi_version,
            description=app.description,
            routes=app.routes,
            openapi_prefix=app.openapi_prefix,
            security_access=security_access,
        )
        _secure_openapi_cache[cache_key] = schema
    return schema
def _build_secure_openapi(
    *,
    title: str,
    version: str,
    openapi_version: str = "3.0.2",
    description: str = None,
    routes: Sequence[BaseRoute],
    openapi_prefix: str = "",
    security_access: Dict[str, Set[str]]
) -> Dict:
    """Build an OpenAPI document containing only operations the caller may use.

    Follows the same assembly steps as FastAPI's ``get_openapi``. It drops every
    operation whose ``security`` requirements are not satisfied by
    *security_access*, as decided by ``_is_access_allowed``.

    Component schemas are generated only from routes that keep at least one
    visible operation. Models used solely by hidden endpoints are therefore
    not disclosed either.

    Args:
        title: API title for the ``info`` object.
        version: API version for the ``info`` object.
        openapi_version: Value of the top-level ``openapi`` field.
        description: Optional ``info.description``; omitted when falsy.
        routes: The application's routes. Only ``APIRoute`` instances contribute.
        openapi_prefix: String prepended to every path key.
        security_access: Mapping of security-scheme name to granted scopes.

    Returns:
        The document as a JSON-compatible dict (aliases applied, ``None``
        values omitted).
    """
    info: Dict[str, Any] = {"title": title, "version": version}
    if description:
        info["description"] = description
    output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
    components: Dict[str, Dict] = {}
    paths: Dict[str, Dict] = {}

    # Name every model up front so $ref names stay the same no matter which
    # routes end up visible to this particular caller.
    model_name_map = get_model_name_map(get_flat_models_from_routes(routes))

    visible_routes: List[BaseRoute] = []
    path_level_definitions: Dict[str, Any] = {}
    for route in routes:
        if not isinstance(route, routing.APIRoute):
            continue
        result = get_openapi_path(route=route, model_name_map=model_name_map)
        if not result:
            continue
        path, security_schemes, path_definitions = result
        # get_openapi_path keys operations by lower-case HTTP method. The
        # "security" list lives on each operation, not on the path item, so
        # filtering has to happen per operation.
        allowed_operations = {
            method: operation
            for method, operation in path.items()
            if _is_access_allowed(security_access, operation.get("security", []))
        }
        if not allowed_operations:
            continue
        visible_routes.append(route)
        paths.setdefault(openapi_prefix + route.path_format, {}).update(allowed_operations)
        if security_schemes:
            components.setdefault("securitySchemes", {}).update(security_schemes)
        if path_definitions:
            path_level_definitions.update(path_definitions)

    # Only models reachable from visible routes are emitted. Path-level
    # definitions take precedence, as in the original assembly order.
    definitions = get_model_definitions(
        flat_models=get_flat_models_from_routes(visible_routes),
        model_name_map=model_name_map,
    )
    definitions.update(path_level_definitions)

    if definitions:
        components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
    if components:
        output["components"] = components
    output["paths"] = paths
    return jsonable_encoder(OpenAPI(**output), by_alias=True, include_none=False)
def _is_access_allowed(security_access: Dict[str, Set[str]], route_security: List[Dict[str, List[str]]]) -> bool: | |
for operation_security in route_security: | |
for name, scopes in operation_security.items(): | |
accessible_scopes = security_access.get(name, set()) | |
if any(scope not in accessible_scopes for scope in scopes): | |
return False | |
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import datetime, timedelta | |
from typing import Dict, List, Set | |
import jwt | |
from jwt import PyJWTError | |
from passlib.context import CryptContext | |
from pydantic import BaseModel, ValidationError | |
from starlette.status import HTTP_401_UNAUTHORIZED | |
from fastapi import Depends, FastAPI, HTTPException, Security | |
from fastapi.security import ( | |
OAuth2PasswordBearer, | |
SecurityScopes, | |
) | |
from secure_openapi import use_secure_openapi | |
# Signing key for the example's JWTs. To generate a value like this run:
#     openssl rand -hex 32
# NOTE: a key committed to source is acceptable only for this example; a real
# deployment should load its key from configuration.
SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"
# NOTE(review): not referenced by the visible code in this file; presumably
# meant for a token endpoint's expires_delta — confirm.
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# In-memory stand-in for a user store, keyed by username. get_user() builds a
# UserInDB from these records, so the keys mirror that model's fields.
fake_users_db = {
    "johndoe": {
        "username": "johndoe",
        "full_name": "John Doe",
        "email": "johndoe@example.com",
        "hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW",
        "disabled": False,
    },
    "alice": {
        "username": "alice",
        "full_name": "Alice Chains",
        "email": "alicechains@example.com",
        "hashed_password": "$2b$12$gSvqqUPvlXP2tfVFaWK1Be7DlH.PKZbv5H8KnzzVgXXbVxpva.pFm",
        "disabled": True,
    },
}
# Public user profile. Optional fields default to None.
# (No class docstrings here: pydantic would copy them into the generated schema.)
class User(BaseModel):
    username: str
    email: str = None
    full_name: str = None
    disabled: bool = None


# Stored user record: the public profile plus the password hash.
class UserInDB(User):
    hashed_password: str


# Response body carrying an issued access token.
class Token(BaseModel):
    access_token: str
    token_type: str


# Claims extracted from a decoded JWT: the subject and its granted scopes.
class TokenData(BaseModel):
    username: str = None
    scopes: List[str] = []
# The application instance. Routes below attach to it, and
# use_secure_openapi(app, ...) at the end of the file rewires its schema route.
app = FastAPI()

# Password hashing context shared by verify_password / get_password_hash.
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)

# Bearer-token extractor declaring the scope names that routes request through
# Security(..., scopes=[...]).
# NOTE(review): no "/token" route is defined in the visible code; clients need
# one elsewhere to obtain tokens — confirm.
oauth2_scheme = OAuth2PasswordBearer(
    scopes={
        "me": "Read information about the current user.",
        "items": "Read items.",
    },
    tokenUrl="/token",
)
def verify_password(plain_password, hashed_password):
    """Check *plain_password* against *hashed_password* using ``pwd_context``."""
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
def get_password_hash(password):
    """Hash *password* with ``pwd_context`` and return the result."""
    digest = pwd_context.hash(password)
    return digest
def get_user(db, username: str):
    """Look up *username* in *db* and wrap the stored record in ``UserInDB``.

    Returns ``None`` when *db* has no entry for *username*.
    """
    if username not in db:
        return None
    return UserInDB(**db[username])
def authenticate_user(fake_db, username: str, password: str):
    """Return the stored user if *password* is correct for *username*.

    Args:
        fake_db: User store passed through to ``get_user``.
        username: Name to look up.
        password: Plain-text password to verify.

    Returns:
        The ``UserInDB`` on success. ``False`` when the user does not exist or
        the password does not match.
    """
    user = get_user(fake_db, username)
    if not user:
        # Spend comparable hashing time on unknown usernames, so response
        # latency does not reveal which usernames exist.
        pwd_context.dummy_verify()
        return False
    if not verify_password(password, user.hashed_password):
        return False
    return user
def create_access_token(*, data: dict, expires_delta: timedelta = None):
    """Encode *data* as a signed JWT with an ``exp`` claim.

    Args:
        data: Claims to include. The dict is copied, not mutated.
        expires_delta: Token lifetime. ``None`` means 15 minutes. An explicit
            ``timedelta(0)`` is honoured rather than treated as "not given".

    Returns:
        Whatever ``jwt.encode`` returns for the claims, signed with
        ``SECRET_KEY`` using ``ALGORITHM``.
    """
    from datetime import timezone

    if expires_delta is None:
        expires_delta = timedelta(minutes=15)
    to_encode = data.copy()
    # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated and naive.
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def get_current_user(
    security_scopes: SecurityScopes, token: str = Depends(oauth2_scheme)
):
    """Resolve the bearer token to a stored user holding every requested scope.

    Raises:
        HTTPException: 401 "Could not validate credentials" when the token
            cannot be decoded, lacks a ``sub`` claim, fails ``TokenData``
            validation, or names an unknown user. 401 "Not enough permissions"
            when a scope in *security_scopes* is missing from the token.
    """
    # Advertise the required scopes in the challenge header when there are any.
    authenticate_value = (
        f'Bearer scope="{security_scopes.scope_str}"'
        if security_scopes.scopes
        else "Bearer"
    )
    credentials_exception = HTTPException(
        status_code=HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": authenticate_value},
    )

    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        subject: str = claims.get("sub")
        if subject is None:
            raise credentials_exception
        token_data = TokenData(scopes=claims.get("scopes", []), username=subject)
    except (PyJWTError, ValidationError):
        raise credentials_exception

    user = get_user(fake_users_db, username=token_data.username)
    if user is None:
        raise credentials_exception

    missing_scopes = [s for s in security_scopes.scopes if s not in token_data.scopes]
    if missing_scopes:
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED,
            detail="Not enough permissions",
            headers={"WWW-Authenticate": authenticate_value},
        )
    return user
async def get_current_active_user(
    current_user: User = Security(get_current_user, scopes=["me"])
):
    """Return the authenticated user (requires the "me" scope) unless disabled.

    Raises:
        HTTPException: 400 "Inactive user" when ``current_user.disabled`` is truthy.
    """
    if not current_user.disabled:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
# List the caller's items. Needs an active user plus the "items" scope (on top
# of the "me" scope required by get_current_active_user).
# (Comment, not a docstring: FastAPI would publish a docstring as the
# operation's description.)
@app.get("/users/me/items/")
async def read_own_items(
    current_user: User = Security(get_current_active_user, scopes=["items"])
):
    owned_item = {"item_id": "Foo", "owner": current_user.username}
    return [owned_item]
def get_token_data(token: str = Depends(oauth2_scheme)) -> TokenData:
    """Decode the bearer token into ``TokenData`` without loading the user.

    Raises:
        HTTPException: 401 "Could not validate credentials" when the token
            cannot be decoded, has no ``sub`` claim, or fails validation.
    """
    unauthorized = HTTPException(
        status_code=HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        subject: str = claims.get("sub")
        if subject is None:
            raise unauthorized
        return TokenData(scopes=claims.get("scopes", []), username=subject)
    except (PyJWTError, ValidationError):
        raise unauthorized
def get_security_access(token_data: TokenData = Depends(get_token_data)) -> Dict[str, Set[str]]:
    """Map the token's scopes onto the "OAuth2PasswordBearer" security scheme.

    NOTE(review): get_token_data depends on oauth2_scheme. A request without a
    valid bearer token therefore presumably gets a 401 instead of a reduced
    schema — confirm this is the desired behaviour for /openapi.json and /docs.
    """
    granted_scopes = set(token_data.scopes)
    return {"OAuth2PasswordBearer": granted_scopes}


# Serve the schema filtered by the caller's granted scopes.
use_secure_openapi(app, get_security_access)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment