Skip to content

Instantly share code, notes, and snippets.

@encryptblockr
Forked from nilsdebruin/fastapi_basic.py
Created March 18, 2021 06:43
Show Gist options
  • Save encryptblockr/59e8dd6eab1046a33e8ff4fc37b5c5fb to your computer and use it in GitHub Desktop.
Save encryptblockr/59e8dd6eab1046a33e8ff4fc37b5c5fb to your computer and use it in GitHub Desktop.
FastAPI Basic Authentication example, based on the official OAuth2 + JWT tutorial: https://fastapi.tiangolo.com/tutorial/security/oauth2-jwt/
import base64
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from jwt import PyJWTError
from passlib.context import CryptContext
from pydantic import BaseModel
from fastapi import Depends, FastAPI, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.security import OAuth2PasswordRequestForm, OAuth2
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.openapi.models import HTTPBase as HTTPBaseModel
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
from fastapi.openapi.utils import get_openapi
from starlette.status import HTTP_403_FORBIDDEN
from starlette.responses import RedirectResponse, Response, JSONResponse
from starlette.requests import Request
# Key used to sign and verify the JWT access tokens issued by this module.
# to get a string like this run:
# openssl rand -hex 32
# NOTE(review): the key is committed in source; a real deployment would
# presumably load it from the environment or a secret store - confirm.
SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
# Algorithm name passed to jwt.encode / jwt.decode.
ALGORITHM = "HS256"
# Lifetime, in minutes, of the access tokens created by the login routes.
ACCESS_TOKEN_EXPIRE_MINUTES = 30
# In-memory stand-in for a user table, keyed by username.  Each record carries
# the fields of ``UserInDB``.
fake_users_db = {
    "johndoe": {
        "username": "johndoe",
        "full_name": "John Doe",
        # Restored from the "[email protected]" placeholder that web-page
        # e-mail obfuscation left in the copied source.
        "email": "johndoe@example.com",
        # Password hash in bcrypt ("$2b$") format; compared via pwd_context.
        "hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW",
        "disabled": False,
    }
}
class Token(BaseModel):
    """Response body of the ``/token`` endpoint."""

    # The encoded JWT handed back to the client.
    access_token: str
    # Token scheme label; the login route fills in "bearer".
    token_type: str
class TokenData(BaseModel):
    """Claims extracted from a decoded access token."""

    # Value of the token's "sub" claim.  Declared Optional because the default
    # is None; a bare ``str`` annotation would contradict that default.
    username: Optional[str] = None
class User(BaseModel):
    """Public view of a user account (no password material)."""

    username: str
    # The remaining fields default to None, so they are declared Optional;
    # otherwise a record that stores an explicit None can fail validation.
    email: Optional[str] = None
    full_name: Optional[str] = None
    # Truthy value makes get_current_active_user reject the account.
    disabled: Optional[bool] = None
class UserInDB(User):
    """A ``User`` as stored in the user table, including the password hash."""

    # Hash that verify_password compares a submitted password against.
    hashed_password: str
class OAuth2PasswordBearerCookie(OAuth2):
    """OAuth2 password-flow scheme that reads the bearer token from either the
    ``Authorization`` request header or an ``Authorization`` cookie.

    The header is consulted first; the cookie is only used when the header
    does not carry a ``Bearer`` credential.
    """

    def __init__(
        self,
        tokenUrl: str,
        scheme_name: str = None,
        scopes: dict = None,
        auto_error: bool = True,
    ):
        """Declare a password flow pointing at *tokenUrl*.

        Args:
            tokenUrl: URL of the endpoint that issues tokens.
            scheme_name: Forwarded unchanged to the ``OAuth2`` base class.
            scopes: Scope mapping for the flow; a falsy value becomes ``{}``.
            auto_error: When true, ``__call__`` raises on missing credentials
                instead of returning ``None``.
        """
        password_flow = {"tokenUrl": tokenUrl, "scopes": scopes if scopes else {}}
        super().__init__(
            flows=OAuthFlowsModel(password=password_flow),
            scheme_name=scheme_name,
            auto_error=auto_error,
        )

    async def __call__(self, request: Request) -> Optional[str]:
        """Return the bearer token carried by *request*.

        Returns:
            The credential part of the first ``Bearer`` value found (header,
            then cookie), or ``None`` when nothing is found and
            ``auto_error`` is false.

        Raises:
            HTTPException: 403 when nothing is found and ``auto_error`` is
                true.
        """
        # Both sources are read up front; order in the tuple sets precedence.
        candidates = (
            request.headers.get("Authorization"),
            request.cookies.get("Authorization"),
        )
        for raw_value in candidates:
            scheme, credential = get_authorization_scheme_param(raw_value)
            if scheme.lower() == "bearer":
                return credential

        if self.auto_error:
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
            )
        return None
class BasicAuth(SecurityBase):
    """Security dependency that extracts HTTP Basic credentials from the
    ``Authorization`` request header."""

    def __init__(self, scheme_name: Optional[str] = None, auto_error: bool = True):
        """Configure the scheme.

        Args:
            scheme_name: Name to register the scheme under; defaults to the
                class name.
            auto_error: When true, ``__call__`` raises on missing credentials
                instead of returning ``None``.
        """
        # FastAPI reads ``.model`` from security dependencies while building
        # the OpenAPI schema; without it, schema generation for any route
        # that depends on this object fails with AttributeError.
        self.model = HTTPBaseModel(scheme="basic")
        self.scheme_name = scheme_name or self.__class__.__name__
        self.auto_error = auto_error

    async def __call__(self, request: Request) -> Optional[str]:
        """Return the still-encoded credential part of a ``Basic`` header.

        Returns:
            The text following ``Basic`` in the ``Authorization`` header, or
            ``None`` when the header is absent or uses another scheme and
            ``auto_error`` is false.

        Raises:
            HTTPException: 403 when no Basic credential is present and
                ``auto_error`` is true.
        """
        authorization = request.headers.get("Authorization")
        scheme, param = get_authorization_scheme_param(authorization)
        if authorization and scheme.lower() == "basic":
            return param

        if self.auto_error:
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
            )
        return None
# Basic-auth dependency; auto_error=False makes it return None instead of
# raising when the request carries no Basic credentials.
basic_auth = BasicAuth(auto_error=False)
# Password hashing/verification context configured for bcrypt.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Bearer-token dependency (header or cookie) whose token endpoint is /token.
oauth2_scheme = OAuth2PasswordBearerCookie(tokenUrl="/token")
# The built-in docs/redoc/openapi URLs are set to None; this module registers
# its own /docs and /openapi.json routes that require an authenticated user.
app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)
def verify_password(plain_password, hashed_password):
    """Check *plain_password* against *hashed_password* via ``pwd_context``.

    Returns:
        The result of ``pwd_context.verify`` for the pair.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def get_password_hash(password):
    """Hash *password* with ``pwd_context`` and return the resulting hash."""
    hashed = pwd_context.hash(password)
    return hashed
def get_user(db, username: str):
    """Look up *username* in *db* and wrap its record in ``UserInDB``.

    Args:
        db: Mapping of username to a dict of ``UserInDB`` fields.
        username: Key to look up.

    Returns:
        A ``UserInDB`` built from the stored record, or ``None`` when *db*
        has no entry for *username*.
    """
    if username not in db:
        return None
    return UserInDB(**db[username])
def authenticate_user(fake_db, username: str, password: str):
    """Validate a username/password pair against *fake_db*.

    Returns:
        The matching ``UserInDB`` when the user exists and the password
        verifies against its stored hash; ``False`` otherwise.
    """
    candidate = get_user(fake_db, username)
    # The password is only checked when the lookup produced a user.
    if candidate and verify_password(password, candidate.hashed_password):
        return candidate
    return False
def create_access_token(*, data: dict, expires_delta: Optional[timedelta] = None):
    """Encode *data* as a signed JWT with an expiry claim.

    Args:
        data: Claims to embed; the dict is copied, not mutated.
        expires_delta: Token lifetime.  When omitted (``None``) the token
            expires after 15 minutes.

    Returns:
        The value produced by ``jwt.encode`` using ``SECRET_KEY`` and
        ``ALGORITHM``.
    """
    # Compare against None rather than truthiness: timedelta(0) is falsy and
    # would otherwise be silently replaced by the 15-minute default.
    lifetime = expires_delta if expires_delta is not None else timedelta(minutes=15)

    to_encode = data.copy()
    # Timezone-aware "now" instead of the naive, deprecated datetime.utcnow().
    to_encode["exp"] = datetime.now(timezone.utc) + lifetime
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the bearer *token* to a user record.

    Returns:
        The ``UserInDB`` whose username matches the token's ``sub`` claim.

    Raises:
        HTTPException: 403 when the token cannot be decoded, has no ``sub``
            claim, or names a user missing from ``fake_users_db``.
    """
    credentials_exception = HTTPException(
        status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials"
    )

    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except PyJWTError:
        raise credentials_exception

    subject: str = claims.get("sub")
    if subject is None:
        raise credentials_exception

    token_data = TokenData(username=subject)
    account = get_user(fake_users_db, username=token_data.username)
    if account is None:
        raise credentials_exception
    return account
async def get_current_active_user(current_user: User = Depends(get_current_user)):
    """Return the authenticated user unless the account is disabled.

    Raises:
        HTTPException: 400 when ``current_user.disabled`` is truthy.
    """
    if not current_user.disabled:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
@app.get("/")
async def homepage():
    """Serve the landing message at the site root (no authentication)."""
    greeting = "Welcome to the security test!"
    return greeting
@app.post("/token", response_model=Token)
async def route_login_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange form-posted username/password for a bearer access token.

    Returns:
        A dict with ``access_token`` and ``token_type`` ("bearer").

    Raises:
        HTTPException: 400 when the credentials do not authenticate.
    """
    account = authenticate_user(fake_users_db, form_data.username, form_data.password)
    if not account:
        raise HTTPException(status_code=400, detail="Incorrect username or password")

    signed_token = create_access_token(
        data={"sub": account.username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": signed_token, "token_type": "bearer"}
@app.get("/logout")
async def route_logout_and_remove_cookie():
    """Redirect to ``/`` and delete the ``Authorization`` cookie.

    The cookie is deleted for the ``localtest.me`` domain, the same domain
    the basic-login route sets it on.
    """
    redirect_home = RedirectResponse(url="/")
    redirect_home.delete_cookie("Authorization", domain="localtest.me")
    return redirect_home
def _basic_auth_challenge() -> Response:
    """Build the 401 response that asks the client for HTTP Basic credentials."""
    return Response(headers={"WWW-Authenticate": "Basic"}, status_code=401)


@app.get("/login_basic")
async def login_basic(auth: Optional[str] = Depends(basic_auth)):
    """Log in with HTTP Basic credentials and store the token in a cookie.

    Args:
        auth: Encoded credential part of the ``Authorization: Basic`` header
            as returned by ``basic_auth``, or ``None`` when absent.

    Returns:
        A 401 response carrying ``WWW-Authenticate: Basic`` when credentials
        are missing, malformed, or rejected; otherwise a redirect to
        ``/docs`` that sets an http-only ``Authorization`` cookie holding
        ``Bearer <token>``.
    """
    if not auth:
        return _basic_auth_challenge()

    # Only the decode step is guarded.  binascii.Error (bad base64) and
    # UnicodeDecodeError (non-ASCII payload) are both ValueError subclasses.
    # Unexpected failures elsewhere now propagate instead of being turned
    # into a 401 by a bare ``except:``.
    try:
        decoded = base64.b64decode(auth).decode("ascii")
    except ValueError:
        return _basic_auth_challenge()

    username, _, password = decoded.partition(":")
    user = authenticate_user(fake_users_db, username, password)
    if not user:
        # Same 401 challenge the client received before, now returned
        # explicitly rather than via a raised-and-swallowed HTTPException.
        return _basic_auth_challenge()

    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": username}, expires_delta=access_token_expires
    )
    token = jsonable_encoder(access_token)

    # Keep the cookie lifetime in step with the token lifetime rather than
    # repeating a literal number of seconds.
    cookie_lifetime_seconds = int(access_token_expires.total_seconds())

    response = RedirectResponse(url="/docs")
    response.set_cookie(
        "Authorization",
        value=f"Bearer {token}",
        domain="localtest.me",
        httponly=True,
        max_age=cookie_lifetime_seconds,
        expires=cookie_lifetime_seconds,
    )
    return response
@app.get("/openapi.json")
async def get_open_api_endpoint(current_user: User = Depends(get_current_active_user)):
    """Serve the OpenAPI schema for all registered routes.

    ``current_user`` is not read; the dependency is declared so that only an
    authenticated, active user can fetch the schema.
    """
    # The version is passed as a string: the OpenAPI "info.version" field is
    # textual, and an int is rejected by stricter model validation.
    schema = get_openapi(title="FastAPI", version="1", routes=app.routes)
    return JSONResponse(schema)
@app.get("/docs")
async def get_documentation(current_user: User = Depends(get_current_active_user)):
    """Serve the Swagger UI page backed by ``/openapi.json``.

    ``current_user`` is not read; the dependency is declared so that only an
    authenticated, active user can open the page.
    """
    swagger_page = get_swagger_ui_html(openapi_url="/openapi.json", title="docs")
    return swagger_page
@app.get("/users/me/", response_model=User)
async def read_users_me(current_user: User = Depends(get_current_active_user)):
    """Return the authenticated, active user, serialised through ``User``."""
    me = current_user
    return me
@app.get("/users/me/items/")
async def read_own_items(current_user: User = Depends(get_current_active_user)):
    """Return a one-element list with a fixed item owned by the current user."""
    owned_item = {"item_id": "Foo", "owner": current_user.username}
    return [owned_item]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment