Last active
December 14, 2023 09:57
-
-
Save WilliamStam/cb9a8b4d52d2cca2d6ca3c4d615a81d0 to your computer and use it in GitHub Desktop.
FastAPI authentication
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import typing | |
from fastapi import Depends, HTTPException | |
from starlette.requests import Request | |
from domain.user.model import ProfileModel | |
from exception import HTTPForbiddenException | |
from permissions import permissions as system_permissions | |
from utilities.permission import Permission | |
from .user import CurrentUser, get_token | |
logger = logging.getLogger(__name__)


class AuthorizationDependency:
    """Callable FastAPI dependency that requires an authenticated user holding every listed permission.

    ``permissions`` may be a single permission (``str`` or ``Permission``), a list of
    either, or ``None`` (meaning: login only, no specific permissions).
    """

    def __init__(
        self,
        permissions: typing.Union[str, list[str], Permission, list[Permission], None] = None
    ):
        self.permissions = permissions

    def get_permissions(self):
        """Return the configured permissions as a fresh list (empty when none were configured)."""
        configured = [] if self.permissions is None else self.permissions
        # a single permission (object or dotted string) is wrapped; anything else is
        # treated as an iterable of permissions
        if isinstance(configured, (Permission, str)):
            return [configured]
        return list(configured)

    async def __call__(
        self,
        request: Request,
        user: CurrentUser
    ) -> ProfileModel:
        """Return ``user`` when authorised.

        Raises:
            HTTPException: 401 when the user has no ``id`` (not logged in).
            HTTPForbiddenException: 403, carrying the ids of the permissions the user lacks.
        """
        if user.id is None:
            logger.debug("User not logged in")
            raise HTTPException(status_code=401)

        if user.has_permissions(self.get_permissions()):
            return user

        # collect the string ids the user does not hold so the 403 can report them
        missing_permissions = [
            str(required)
            for required in self.get_permissions()
            if str(required) not in user.permissions
        ]
        if not missing_permissions:
            return user

        logger.debug(f"User missing permissions '{missing_permissions}'")
        raise HTTPForbiddenException(status_code=403, permissions=missing_permissions)
# Basically: whenever a route requires auth there must be a logged-in user; if you also
# pass permissions, the user must hold every one of them as well.
def Authorization(
    permissions: typing.Union[Permission, list[Permission], None] = None,
    token=Depends(get_token),
) -> Depends:
    """Build a ``Depends(...)`` that enforces login and, optionally, the given permissions.

    Each permission in a list (or a single ``Permission``) is also registered with
    ``system_permissions`` so the full set of required permissions can be listed.
    """
    # NOTE(review): ``token`` is never read here; kept only so existing callers keep working.
    required = [permissions] if isinstance(permissions, Permission) else permissions

    if isinstance(required, list):
        # Register every "required" permission in a global collection so they can be
        # rendered as a tree (permission ids are hierarchical dot notation).
        for perm in required:
            system_permissions.add(perm)

    return Depends(AuthorizationDependency(required))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Serve an OpenAPI schema that hides routes the current user cannot "see".
@router.get("/openapi.json", include_in_schema=False)
async def openapi(request: Request, user: CurrentUser):
    """Return the app's OpenAPI schema filtered to the routes ``user`` may access.

    A route is hidden when one of its dependencies exposes a ``permissions``
    attribute (e.g. an ``AuthorizationDependency``) and the user is either
    anonymous (``user.id is None``) or lacks one of those permissions.
    """
    app = request.app
    visible_routes = [route for route in app.routes if _route_visible_to(route, user)]

    openapi_schema = get_openapi(
        title=app.title,
        version=app.version,
        openapi_version=app.openapi_version,
        description=app.description,
        terms_of_service=app.terms_of_service,
        contact=app.contact,
        license_info=app.license_info,
        routes=visible_routes,
        tags=app.openapi_tags,
        servers=app.servers,
    )
    return JSONResponse(openapi_schema)


def _route_visible_to(route, user) -> bool:
    """Return False if any permission-bearing dependency of ``route`` rejects ``user``."""
    dependencies = getattr(route, "dependencies", None)
    if not isinstance(dependencies, list):
        return True

    for row in dependencies:
        dependency = getattr(row, "dependency", None)
        if not hasattr(dependency, "permissions"):
            continue
        if user.id is None:
            return False
        # Normalise the way the dependency itself does, so a bare string permission is
        # checked as one id rather than being iterated character by character.
        get_permissions = getattr(dependency, "get_permissions", None)
        required = get_permissions() if callable(get_permissions) else dependency.permissions
        if not user.has_permissions(required):
            return False
    return True
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dataclasses | |
from typing import Union | |
@dataclasses.dataclass
class Permission:
    """A node in a dot-notation permission hierarchy.

    ``id`` joins the keys from the root down to this node with ``"."``; e.g. a
    ``"sub"`` permission whose parent has key ``"permission"`` has id
    ``"permission.sub"``. ``repr()`` returns that id, and because no ``__str__``
    is defined, ``str()`` / f-strings return it too. That is the string the
    authorization code compares against a user's permission list.
    """

    # NOTE(review): dataclass(eq=True) without frozen=True sets __hash__ to None, so
    # instances are unhashable — confirm system_permissions.add() does not put them in a set.
    key: str
    description: Union[str, None] = None
    parent: Union["Permission", None] = None

    def __repr__(self) -> str:
        return self.id

    @property
    def id(self) -> str:
        """Dotted path of keys from the root permission down to this one."""
        keys = []
        node = self
        while node is not None:
            keys.append(node.key)
            node = node.parent
        return ".".join(reversed(keys))
# Example hierarchy:
#   permission
#   └── permission.sub
#       ├── permission.sub.supersub
#       └── permission.sub.supersub2
permission = Permission(
    key="permission",
    description="The root permission",
)
sub_permission = Permission(
    key="sub",
    description="im the sub",
    parent=permission,
)
super_sub_permission = Permission(
    key="supersub",
    description="im the super sub sub",
    parent=sub_permission,
)
another_super_sub_permission = Permission(
    key="supersub2",
    description="im the super sub sub",
    parent=sub_permission,
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ProfileModel(BaseModel):
    """The current user as seen by the auth dependencies; ``id is None`` means anonymous."""

    id: Optional[int] = None
    name: Optional[str] = None
    username: Optional[str] = None
    password: Optional[str] = None
    token: Optional[str] = None
    # flat list of granted permission ids, e.g. ["permission", "permission.perm2"]
    permissions: list[str] = Field(default_factory=list)

    def has_permissions(
        self,
        permissions: Union[str, Permission, list[str], list[Permission], None] = None,
    ) -> bool:
        """Return True if the user holds every permission in ``permissions``.

        ``permissions`` may be ``None`` (always True), a single ``str``/``Permission``,
        or an iterable of either. Matching is by exact string id: holding
        ``"permission"`` does not imply holding ``"permission.sub"``.
        """
        if permissions is None:
            return True
        # Wrap a single permission; a bare string must not be iterated char by char.
        if isinstance(permissions, (str, Permission)):
            permissions = [permissions]
        granted = set(self.permissions)
        return all(str(required) in granted for required in permissions)

    def is_authenticated(self) -> bool:
        """Return True when the user has an id, i.e. is logged in."""
        return self.id is not None
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from permissions import permission, sub_permission, super_sub_permission, another_super_sub_permission | |
# Route-level dependencies belong in the decorator's ``dependencies=[...]``. Written as a
# function parameter default, FastAPI would never run them, so auth would not be enforced.

# This route only requires a "logged in user" (its router has no router-level permissions).
public_router = APIRouter(
    tags=["instances"],
    prefix="/instances",
)


@public_router.get("", dependencies=[Authorization()])
async def get_all_instances():
    ...


# Every route on this router requires user.permissions to contain "permission".
router = APIRouter(
    tags=["instances"],
    prefix="/instances",
    dependencies=[Authorization(permission)],
)


# Requires a "logged in user", "permission" from the router, and "permission.sub" for this route.
@router.get("/sub", dependencies=[Authorization(sub_permission)])
async def get_sub_instances():
    ...


# Requires "permission" (router) plus "permission.sub.supersub" and "permission.sub.supersub2".
@router.get(
    "/supersub",
    dependencies=[Authorization([super_sub_permission, another_super_sub_permission])],
)
async def get_supersub_instances():
    ...
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import typing | |
from fastapi import Depends, Security | |
from fastapi.security import ( | |
APIKeyHeader, | |
APIKeyQuery, | |
HTTPAuthorizationCredentials, | |
HTTPBearer, | |
) | |
# noinspection PyUnresolvedReferences | |
from starlette.requests import Request | |
from domain.user.model import ProfileModel | |
from domain.user.repository import get_user_by_token, update_token | |
from domain.user.service import get_user_from_username_password | |
# Token sources exposed as security schemes. With auto_error=False each scheme yields
# None instead of raising when its credential is absent, so get_token can fall through
# to the next source.
QUERY_API_KEY = APIKeyQuery(
    name="token",
    auto_error=False,
    description="Pass the token via a query string item ?token=xxx",
)
BEARER = HTTPBearer(auto_error=False, description="Add the token to your bearer authentication")
HEADER_API_KEY = APIKeyHeader(name="X-API-KEY", auto_error=False, description="Add a header [X-API-KEY] with the token")
# The "token" cookie is read directly from request.cookies in get_token rather than via a scheme:
# COOKIE_API_KEY = APIKeyCookie(name='token', auto_error=False, description="Login via cookie")

logger = logging.getLogger(__name__)
async def get_token(
    request: Request,
    query_api_key: str = Security(QUERY_API_KEY),
    bearer_key: HTTPAuthorizationCredentials = Security(BEARER),
    header_api_key: str = Security(HEADER_API_KEY),
) -> typing.Optional[str]:
    """Return the first token the client supplied, or ``None``.

    Sources are tried in priority order: ``?token=`` query parameter, bearer
    credentials, ``X-API-KEY`` header, then the ``token`` cookie.
    """
    bearer_token = bearer_key.credentials if bearer_key else bearer_key
    candidates = (
        query_api_key,
        bearer_token,
        header_api_key,
        request.cookies.get("token"),
    )
    for candidate in candidates:
        if candidate is not None:
            return candidate
    return None
async def get_current_user(token=Depends(get_token)) -> ProfileModel:
    """Resolve the request's token into a ``ProfileModel``.

    * No token: an anonymous ``ProfileModel()`` (``id is None``) is returned.
    * A token whose first two space-separated parts are both non-empty is treated as
      ``"<username> <password>"`` and checked with ``get_user_from_username_password``
      (any further parts are ignored).
    * Any other token is looked up with ``get_user_by_token`` and then passed to
      ``update_token``.
    """
    user = ProfileModel()
    if token:
        parts = token.split(" ")
        # str.split always returns at least one element, so parts[0] is safe
        username = parts[0]
        password = parts[1] if len(parts) > 1 else False
        if username and password:
            logger.debug(f"Retrieving user by username,password bearer: {username}")
            user = await get_user_from_username_password(username, password)
        else:
            logger.debug("Retrieving user by token")
            user = await get_user_by_token(token=token)
            # NOTE(review): assumes update_token only applies to the token-lookup path — confirm.
            await update_token(token=token)
    # Log identifying fields only: ProfileModel also carries ``password`` and ``token``,
    # which must never be written to the logs.
    logger.debug(
        "current user: id=%s username=%s",
        getattr(user, "id", None),
        getattr(user, "username", None),
    )
    return user


# Annotated alias so routes/dependencies can declare ``user: CurrentUser``.
CurrentUser = typing.Annotated[ProfileModel, Depends(get_current_user)]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
A more detailed write-up of this approach is in fastapi/fastapi#10692.