|
""" |
|
WebAuthn Implementation with FastAPI |
|
|
|
Implement WebAuthn (Web Authentication) in a FastAPI application. |
|
Provide endpoints for registration and authentication. |
|
""" |
|
|
|
import base64
import hashlib
import logging
import os
import secrets
import sqlite3
import subprocess
import threading
import time
import uuid
from datetime import datetime, timedelta, UTC
from sqlite3 import Connection
from typing import Dict, Optional

import jwt
import webauthn
from fastapi import FastAPI, HTTPException, Depends
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordBearer
from fastapi.staticfiles import StaticFiles
from jwt.exceptions import PyJWTError
from pydantic import BaseModel
from webauthn.helpers.structs import (
    PublicKeyCredentialCreationOptions,
    PublicKeyCredentialRequestOptions,
    PublicKeyCredentialUserEntity,
    PublicKeyCredentialRpEntity,
    RegistrationCredential,
    AuthenticationCredential,
)

from common import DatabaseInterface, https_server
|
|
|
# Module logger used by the database wrapper and the endpoints below.
# NOTE(review): named 'main' rather than __name__; kept in case logging
# configuration elsewhere (e.g. in `common`) targets the 'main' logger.
logger = logging.getLogger('main')
|
|
|
|
|
class Server:
    """
    Application facade over a ``DatabaseInterface``.

    Registers users (password-less WebAuthn users or password users),
    verifies passwords, and issues/checks per-user challenges.

    New passwords are stored as salted PBKDF2-HMAC-SHA256 hashes in the form
    ``pbkdf2_sha256$<iterations>$<salt hex>$<digest hex>``. Stored hashes
    without that prefix are treated as legacy unsalted SHA-256 hex digests
    and are still accepted, so previously stored users keep working.
    """

    # Marker at the start of salted hashes produced by this class.
    _PBKDF2_PREFIX = "pbkdf2_sha256"

    # PBKDF2 work factor for newly created hashes. The value is recorded in
    # each hash, so raising it later does not invalidate existing hashes.
    PASSWORD_HASH_ITERATIONS = 600_000

    def __init__(self, db: DatabaseInterface):
        self.database = db
        # NOTE(review): not read by this class; get_current_user and
        # revoke_token use the module-level ``revoked_tokens`` set instead.
        self.revoked_tokens = set()

    def register_user(self, username: str, password: Optional[str] = None) -> bool:
        """
        Register a new user.

        Args:
            username: Name of the new account.
            password: ``None`` for a WebAuthn-only account, which is stored
                with an empty password hash that ``authenticate_user`` never
                accepts. Otherwise, the plaintext password to hash and store.

        Returns:
            The result of ``database.add_user``: True on success, False if
            the user could not be added (e.g. the username is taken).
        """
        if password is None:
            # WebAuthn registration: no password, empty hash.
            return self.database.add_user(username, "")
        return self.database.add_user(username, self._hash_password_salted(password))

    def authenticate_user(self, username: str, password: str) -> bool:
        """
        Return True if ``password`` is correct for ``username``.

        Both salted PBKDF2 hashes and legacy unsalted SHA-256 digests are
        verified. The result is always False for unknown users and for
        WebAuthn-only users (empty hash).
        """
        user = self.database.get_user(username)
        if not user:
            return False
        # Assumes get_user returns a mapping with a "password_hash" key, as
        # SQLiteDatabase.get_user does -- confirm for other implementations.
        stored_hash = user.get("password_hash") or ""
        if not stored_hash:
            return False
        if stored_hash.startswith(self._PBKDF2_PREFIX + "$"):
            return self._verify_salted_hash(password, stored_hash)
        # Legacy path: unsalted SHA-256 hex digest, compared in constant time.
        return secrets.compare_digest(
            stored_hash.encode(), self._hash_password(password).encode()
        )

    def create_challenge(self, username: str) -> Optional[str]:
        """
        Create and store a new random challenge for an existing user.

        Returns:
            The 64-character hex challenge, or None if the user does not
            exist or the challenge could not be stored.
        """
        if not self.database.get_user(username):
            return None

        challenge = secrets.token_hex(32)
        if self.database.add_challenge(username, challenge):
            return challenge
        return None

    def verify_challenge(self, username: str, challenge: str) -> bool:
        """
        Return True if ``challenge`` matches the one stored for ``username``.

        The comparison is constant-time, so response timing does not reveal
        how much of the challenge matched.
        """
        stored_challenge = self.database.get_challenge(username)
        if stored_challenge is None or not isinstance(challenge, str):
            return False
        return secrets.compare_digest(stored_challenge.encode(), challenge.encode())

    def _hash_password_salted(self, password: str) -> str:
        """Hash ``password`` with a fresh random salt using PBKDF2-HMAC-SHA256."""
        salt = secrets.token_bytes(16)
        iterations = self.PASSWORD_HASH_ITERATIONS
        digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, iterations)
        return f"{self._PBKDF2_PREFIX}${iterations}${salt.hex()}${digest.hex()}"

    @classmethod
    def _verify_salted_hash(cls, password: str, stored_hash: str) -> bool:
        """Check ``password`` against a hash from ``_hash_password_salted``."""
        try:
            _, iterations_text, salt_hex, digest_hex = stored_hash.split("$")
            iterations = int(iterations_text)
            salt = bytes.fromhex(salt_hex)
            expected = bytes.fromhex(digest_hex)
        except ValueError:
            # Malformed stored hash: never authenticates.
            return False
        if iterations < 1:
            return False
        candidate = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, iterations)
        return secrets.compare_digest(candidate, expected)

    @staticmethod
    def _hash_password(password: str) -> str:
        """
        Hash a password using unsalted SHA-256 (legacy format).

        Kept only to verify hashes stored before salted hashing was added;
        new passwords are hashed with ``_hash_password_salted``.
        """
        return hashlib.sha256(password.encode()).hexdigest()
|
|
|
|
|
class SQLiteDatabase(DatabaseInterface):
    """
    SQLite implementation of the ``DatabaseInterface``.

    Users, pending challenges and credentials are kept in three tables. By
    default the database lives in memory, which is for development only: all
    data is lost when the process exits.

    All threads share one connection, and every access is serialized with a
    lock. A per-thread ``:memory:`` connection would instead give each thread
    its own separate, empty database. Users registered from one thread would
    then be invisible to the others.
    """

    def __init__(self, path: str = ":memory:"):
        """
        Args:
            path: SQLite database path. Defaults to a private in-memory
                database. The connection is opened lazily on first use.
        """
        self._path = path
        # Re-entrant: public methods hold the lock and then read
        # ``self.conn``, which acquires it again.
        self._lock = threading.RLock()
        self._conn: Optional[Connection] = None

    @property
    def conn(self) -> Connection:
        """Return the shared connection, creating it and its tables on first use."""
        with self._lock:
            if self._conn is None:
                # check_same_thread=False because the connection is shared
                # across threads; self._lock serializes access instead.
                conn = sqlite3.connect(self._path, check_same_thread=False)
                try:
                    self._init_tables(conn)
                except sqlite3.Error:
                    conn.close()
                    raise
                self._conn = conn
            return self._conn

    def _init_tables(self, conn: Connection) -> None:
        """
        Create the users, challenges and credentials tables if missing.

        NOTE: SQLite only enforces the FOREIGN KEY clauses when
        ``PRAGMA foreign_keys = ON`` is set, which this class does not do.
        """
        cursor = conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS users (
                username TEXT PRIMARY KEY,
                password_hash TEXT NOT NULL
            )
        """)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS challenges (
                username TEXT PRIMARY KEY,
                challenge TEXT NOT NULL,
                FOREIGN KEY (username) REFERENCES users(username)
            )
        """)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS credentials (
                user_id TEXT PRIMARY KEY,
                token TEXT NOT NULL,
                FOREIGN KEY (user_id) REFERENCES users(username)
            )
        """)
        conn.commit()

    def add_user(self, username: str, password_hash: str) -> bool:
        """
        Insert a user.

        Returns:
            True on success. False if the username already exists, or if
            another SQLite error occurred (that case is logged with its
            traceback).
        """
        with self._lock:
            conn = self.conn
            try:
                conn.execute(
                    "INSERT INTO users (username, password_hash) VALUES (?, ?)",
                    (username, password_hash),
                )
                conn.commit()
                return True
            except sqlite3.IntegrityError:
                conn.rollback()
                return False
            except sqlite3.Error:  # Catch other potential SQLite errors
                conn.rollback()
                logger.exception("Database error while adding user")
                return False

    def get_user(self, username: str) -> Optional[Dict]:
        """Return ``{"username", "password_hash"}`` for a user, or None if absent."""
        with self._lock:
            row = self.conn.execute(
                "SELECT username, password_hash FROM users WHERE username = ?",
                (username,),
            ).fetchone()
        if row is None:
            return None
        return {"username": row[0], "password_hash": row[1]}

    def verify_credentials(self, username: str, password_hash: str) -> bool:
        """Return True if a user row matches both ``username`` and ``password_hash``."""
        with self._lock:
            row = self.conn.execute(
                "SELECT 1 FROM users WHERE username = ? AND password_hash = ?",
                (username, password_hash),
            ).fetchone()
        return row is not None

    def add_challenge(self, username: str, challenge: str) -> bool:
        """
        Store (or replace) the pending challenge for ``username``.

        Returns:
            True on success, False on an SQLite error (logged).
        """
        with self._lock:
            conn = self.conn
            try:
                conn.execute(
                    """INSERT OR REPLACE INTO challenges
                    (username, challenge) VALUES (?, ?)""",
                    (username, challenge),
                )
                conn.commit()
                return True
            except sqlite3.Error:
                conn.rollback()
                logger.exception("Database error while adding challenge")
                return False

    def get_challenge(self, username: str) -> Optional[str]:
        """Return the pending challenge for ``username``, or None if there is none."""
        with self._lock:
            row = self.conn.execute(
                "SELECT challenge FROM challenges WHERE username = ?",
                (username,),
            ).fetchone()
        return row[0] if row else None

    def clear_users(self):
        """Delete every row from the users table."""
        with self._lock:
            self.conn.execute("DELETE FROM users")
            self.conn.commit()

    def clear_credentials(self):
        """Delete every row from the credentials table."""
        with self._lock:
            self.conn.execute("DELETE FROM credentials")
            self.conn.commit()

    def clear_challenges(self):
        """Delete every row from the challenges table."""
        with self._lock:
            self.conn.execute("DELETE FROM challenges")
            self.conn.commit()
|
|
|
|
|
database = SQLiteDatabase()

server = Server(database)

# Initialize FastAPI
app = FastAPI(title="FastAPI WebAuthn Example")

# Process-local stores for the WebAuthn flow; everything is lost on restart.
# A real application would persist these in a database.
users_db = {}
credentials_db = {}

# Pending challenges keyed by username (in a real app, use Redis or similar,
# with an expiry).
challenge_db = {}

# Revoked JWTs: added by revoke_token, checked by get_current_user.
revoked_tokens = set()

# Relying-party configuration.
# NOTE(review): browsers generally refuse an IP address as the RP ID. If
# registration fails in the browser, serve the app on "localhost" (or a
# real domain) and update both the RP ID and the origin below.
RELYING_PARTY_ID = "127.0.0.1"  # Your domain name
RELYING_PARTY_NAME = "FastAPI WebAuthn Example"
# Must equal the browser's serialized origin (scheme://host[:port]) exactly.
# Serialized origins never end with "/", so the previous value
# "https://127.0.0.1:8000/" could never match the origin in clientDataJSON.
RELYING_PARTY_ORIGIN = "https://127.0.0.1:8000"

# JWT settings. If SECRET_KEY is not set in the environment, a random key is
# generated per process, so issued tokens stop validating after a restart.
SECRET_KEY = os.environ.get("SECRET_KEY") or secrets.token_urlsafe(32)
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
|
|
|
|
# Pydantic models for request validation |
|
class RegisterStartRequest(BaseModel):
    """
    JSON body accepted by ``POST /register/start``.

    Attributes:
        username: Name of the account the client wants to create.
    """

    username: str
|
|
|
|
|
class RegisterCompleteRequest(BaseModel):
    """
    JSON body accepted by ``POST /register/complete``.

    Attributes:
        username: Account whose registration is being finished.
        credential: The new public-key credential, as produced by the
            authenticator and serialized by the client; must include
            ``rawId``.
    """

    username: str
    credential: dict
|
|
|
|
|
class LoginStartRequest(BaseModel):
    """
    JSON body accepted by ``POST /login/start``.

    Attributes:
        username: Account the client wants to sign in to.
    """

    username: str
|
|
|
|
|
class LoginCompleteRequest(BaseModel):
    """
    JSON body accepted by ``POST /login/complete``.

    Attributes:
        username: Account being signed in to.
        credential: The authenticator's assertion, serialized by the client;
            must include ``rawId``.
    """

    username: str
    credential: dict
|
|
|
|
|
# Utility functions |
|
def generate_challenge() -> str:
    """
    Return a fresh random challenge for a WebAuthn ceremony.

    The challenge is 32 bytes from the ``secrets`` CSPRNG, rendered as
    unpadded base64url text (43 characters).
    """
    # token_urlsafe(n) base64url-encodes n random bytes and strips the "="
    # padding.
    return secrets.token_urlsafe(32)
|
|
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """
    Create a signed JWT access token.

    Args:
        data: Claims to embed (e.g. ``{"sub": username}``). The mapping is
            copied, so the caller's dict is not modified.
        expires_delta: Token lifetime. ``None`` (the default) means
            ``ACCESS_TOKEN_EXPIRE_MINUTES``.

    Returns:
        The encoded token, signed with ``SECRET_KEY`` using ``ALGORITHM``,
        with an ``exp`` claim set to now (UTC) plus the lifetime.
    """
    # Test against None, not truthiness: timedelta(0) is falsy and would
    # otherwise silently become the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode = {**data, "exp": datetime.now(UTC) + expires_delta}
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
|
|
|
|
# Bearer-token dependency used by get_current_user and revoke_token
# (FastAPI's OAuth2PasswordBearer extracts the token from the request).
# NOTE(review): tokenUrl="token" refers to a /token endpoint this file does
# not define; tokens are issued by /login/complete. Presumably this only
# affects the generated OpenAPI docs -- confirm and update if needed.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
|
|
|
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """
    FastAPI dependency that resolves a bearer JWT to its username.

    Returns:
        The username stored in the token's ``sub`` claim.

    Raises:
        HTTPException: 401 "Token has been revoked" if the token is in
            ``revoked_tokens``. 401 "Could not validate credentials" if the
            token fails to decode or verify, lacks a ``sub`` claim, or names
            a user not present in ``users_db``.
    """
    invalid = HTTPException(
        status_code=401,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # The revocation list is consulted before any signature/expiry checks.
    if token in revoked_tokens:
        raise HTTPException(
            status_code=401,
            detail="Token has been revoked",
            headers={"WWW-Authenticate": "Bearer"},
        )

    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except PyJWTError as exc:
        raise invalid from exc

    subject: str = claims.get("sub")
    if subject is None or subject not in users_db:
        raise invalid
    return subject
|
|
|
|
|
# API Endpoints |
|
@app.post("/register/start")
async def register_start(request: RegisterStartRequest):
    """Start the registration process.

    Reserves ``request.username`` in the ``server`` database as a
    password-less user and in ``users_db`` (with a fresh UUID and an empty
    credential list). It then stores a new challenge in ``challenge_db`` and
    returns WebAuthn creation options for the browser.

    Raises:
        HTTPException: 400 "User already exists" if the name is already in
            ``users_db``, or if ``server.register_user`` returns False.

    NOTE(review): if ``/register/complete`` never succeeds, the name stays
    reserved with no credentials, and retrying this endpoint then fails with
    "User already exists".
    """
    logger.info("Start the registration process")

    # Check if user already exists in users_db
    if request.username in users_db:
        raise HTTPException(status_code=400, detail="User already exists")

    # Add a password-less user to the server database; False means the
    # insert failed (e.g. the name is already taken there).
    if not server.register_user(request.username):
        raise HTTPException(status_code=400, detail="User already exists")

    # Generate a new user ID
    user_id = str(uuid.uuid4())

    # Store the user (in a real app, you'd save to a database)
    users_db[request.username] = {
        "id": user_id,
        "username": request.username,
        "credentials": []
    }

    # Generate a challenge and remember it for /register/complete
    challenge = generate_challenge()
    challenge_db[request.username] = challenge

    # Create WebAuthn registration options.
    # NOTE(review): py_webauthn types `challenge` as bytes and expects struct
    # objects for pub_key_cred_params / authenticator_selection. Passing a
    # str and plain dicts relies on the installed version's coercion --
    # confirm against the pinned webauthn release.
    options = PublicKeyCredentialCreationOptions(
        rp=PublicKeyCredentialRpEntity(id=RELYING_PARTY_ID,
                                       name=RELYING_PARTY_NAME),
        user=PublicKeyCredentialUserEntity(
            id=str(user_id).encode('utf-8'),  # UUID text as UTF-8 bytes
            name=request.username,
            display_name=request.username,
        ),
        challenge=challenge,
        pub_key_cred_params=[
            {"type": "public-key", "alg": -7},  # ES256
            {"type": "public-key", "alg": -257}  # RS256
        ],
        timeout=60000,  # milliseconds
        attestation="direct",
        authenticator_selection={
            "authenticator_attachment": "platform",
            # or "cross-platform" for security keys
            "require_resident_key": False,
            "user_verification": "preferred",
        },
        exclude_credentials=[],  # No credentials to exclude for a new user
    )

    # Copy every public attribute of `options` except rp and user.
    # NOTE(review): dir() also lists public methods of the options object,
    # and some values (e.g. bytes) are not JSON-serializable, so JSONResponse
    # may fail here. webauthn.options_to_json(options) is the library's
    # serializer. Check what the frontend expects before switching: it
    # currently receives no rp/user fields.
    options_dict = {
        key: getattr(options, key)
        for key in dir(options)
        if not key.startswith("_") and key not in ("rp", "user")
    }
    return JSONResponse(content=options_dict)
|
|
|
|
|
@app.post("/register/complete")
async def register_complete(request: RegisterCompleteRequest):
    """
    Finish WebAuthn registration for ``request.username``.

    Verifies the authenticator's attestation against the challenge issued by
    ``/register/start``. On success it stores the new credential's public key
    and sign count in ``credentials_db``, links the credential ID to the user
    in ``users_db``, and discards the challenge.

    Raises:
        HTTPException: 400 if the user is unknown, no challenge is pending,
            or parsing/verification of the credential fails.
    """
    logger.info("Complete the registration process")
    username = request.username

    logger.debug("Check if user exists")
    if username not in users_db:
        raise HTTPException(status_code=400, detail="User does not exist")

    logger.debug("Get the challenge")
    challenge = challenge_db.get(username)
    if not challenge:
        raise HTTPException(status_code=400, detail="No challenge found")

    logger.debug("Parse the credential")
    try:
        d = dict(request.credential)
        # Browsers send camelCase "rawId"; the struct field is raw_id.
        d['raw_id'] = d.pop('rawId')
        credential = RegistrationCredential(**d)

        logger.debug("Verify the registration")
        registration_verification = webauthn.verify_registration_response(
            credential=credential,
            expected_challenge=challenge,
            expected_origin=RELYING_PARTY_ORIGIN,
            expected_rp_id=RELYING_PARTY_ID,
            require_user_verification=False,
        )

        logger.debug("Store the credential")
        credential_id = registration_verification.credential_id
        public_key = registration_verification.credential_public_key

        logger.debug("In a real app, store these in a secure database")
        credentials_db[credential_id] = {
            "username": username,
            "public_key": public_key,
            "sign_count": registration_verification.sign_count,
        }

        logger.debug("Associate credential with user")
        users_db[username]["credentials"].append(credential_id)

        logger.debug("Clean up the challenge")
        # pop() instead of del: a challenge that is already gone must not
        # turn a registration that was just stored into a 400 error.
        challenge_db.pop(username, None)

        logger.debug("Registration successful")
        return {"status": "success", "message": "Registration successful"}

    except Exception as e:
        # logger.exception keeps the traceback; logging only str(e) made
        # verification failures hard to diagnose.
        logger.exception("Registration failed: %s", e)
        raise HTTPException(
            status_code=400,
            detail=f'Registration failed: {str(e)}'
        ) from e
|
|
|
|
|
@app.post("/login/start")
async def login_start(request: LoginStartRequest):
    """Start the login process.

    Builds an allow-list of the credentials registered to
    ``request.username``, stores a fresh challenge in ``challenge_db``
    (replacing any pending one), and returns WebAuthn request options.

    Raises:
        HTTPException: 400 if the user is unknown or has no registered
            credentials.
    """
    logger.info("Start the login process")
    username = request.username

    logger.debug("Check if user exists")
    if username not in users_db:
        logger.debug("User does not exist")
        raise HTTPException(status_code=400, detail="User does not exist")

    logger.debug("Get user's credentials")
    # IDs appended by /register/complete (its verification credential_id).
    user_credential_ids = users_db[username]["credentials"]
    if not user_credential_ids:
        raise HTTPException(status_code=400,
                            detail="No credentials found for user")

    logger.debug("Create a list of allowed credentials")
    allowed_credentials = []
    for cred_id in user_credential_ids:
        allowed_credentials.append({
            "type": "public-key",
            "id": cred_id,
        })

    logger.debug("Generate a challenge")
    challenge = generate_challenge()
    challenge_db[username] = challenge

    logger.debug("Create WebAuthn authentication options")
    # NOTE(review): as in register_start, `challenge` is a str and
    # allow_credentials holds plain dicts, while py_webauthn types these as
    # bytes / descriptor structs -- confirm against the pinned version.
    options = PublicKeyCredentialRequestOptions(
        challenge=challenge,
        timeout=60000,  # milliseconds
        rp_id=RELYING_PARTY_ID,
        allow_credentials=allowed_credentials,
        user_verification="preferred",
    )

    logger.debug("Return stuff")
    # Copy every public attribute of `options` ("rp"/"user" do not apply to
    # request options; the filter mirrors register_start).
    # NOTE(review): dir() also lists public methods, and credential IDs may
    # be bytes; both would make JSONResponse fail. Consider
    # webauthn.options_to_json(options) after checking the frontend contract.
    options_dict = {
        key: getattr(options, key)
        for key in dir(options)
        if not key.startswith("_") and key not in ("rp", "user")
    }
    return JSONResponse(content=options_dict)
|
|
|
|
|
@app.post("/login/complete")
async def login_complete(request: LoginCompleteRequest):
    """
    Finish WebAuthn authentication for ``request.username`` and issue a JWT.

    Verifies the authenticator's assertion against the challenge issued by
    ``/login/start`` and the public key stored at registration. On success it
    updates the credential's sign count and returns a bearer access token.

    The pending challenge is removed before verification, so each challenge
    can be used for at most one attempt. After a failure the client must
    call ``/login/start`` again.

    Raises:
        HTTPException: 400 in these cases:
            - the user is unknown, or no challenge is pending;
            - the credential is not registered to this user
              ("Invalid credential");
            - parsing or verification fails ("Authentication failed: ...").
    """
    logger.info("Complete the login process")
    username = request.username

    logger.debug("Check if user exists")
    if username not in users_db:
        raise HTTPException(status_code=400, detail="User does not exist")

    logger.debug("Get the challenge")
    # Consume the challenge up front so it is single-use.
    challenge = challenge_db.pop(username, None)
    if not challenge:
        raise HTTPException(status_code=400, detail="No challenge found")

    logger.debug("Parse the credential")
    try:
        d = dict(request.credential)
        # Browsers send camelCase "rawId"; the struct field is raw_id.
        d['raw_id'] = d.pop('rawId')
        credential = AuthenticationCredential(**d)

        logger.debug("Get credential data")
        # credentials_db is keyed by the registration result's credential_id,
        # which is raw bytes in py_webauthn. credential.id is the base64url
        # text form of the same ID and never matches, so use raw_id.
        credential_id = credential.raw_id
        cred_data = credentials_db.get(credential_id)
        if not cred_data or cred_data["username"] != username:
            raise HTTPException(status_code=400, detail="Invalid credential")

        logger.debug("Verify the authentication")
        auth_verification = webauthn.verify_authentication_response(
            credential=credential,
            expected_challenge=challenge,
            expected_origin=RELYING_PARTY_ORIGIN,
            expected_rp_id=RELYING_PARTY_ID,
            credential_public_key=cred_data["public_key"],
            credential_current_sign_count=cred_data["sign_count"],
            require_user_verification=False,
        )

        logger.debug("Update the sign count")
        credentials_db[credential_id]["sign_count"] = \
            auth_verification.new_sign_count

        # Create access token
        access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        access_token = create_access_token(
            data={"sub": username},  # 'sub' is standard JWT claim for subject
            expires_delta=access_token_expires
        )

        logger.debug("Login successful")
        return {
            "status": "success",
            "message": "Login successful",
            "access_token": access_token,
            "token_type": "bearer"
        }

    except HTTPException:
        # Already a client-facing error ("Invalid credential"); re-raise it
        # unchanged instead of wrapping it as "Authentication failed: 400: ...".
        raise
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f'Authentication failed: {str(e)}'
        ) from e
|
|
|
|
|
@app.get("/protected")
async def protected_route(current_user: str = Depends(get_current_user)):
    """
    Example endpoint reachable only with a valid, unrevoked bearer token.

    Args:
        current_user: Username resolved from the token by ``get_current_user``.

    Returns:
        dict: ``message`` greets ``current_user``. ``data`` holds a fixed
        secret string and the current UTC time in ISO-8601 form.
    """
    greeting = f"Hello {current_user}"
    payload = {
        "secret_value": "This data is only accessible with a valid token",
        "timestamp": datetime.now(UTC).isoformat(),
    }
    return {"message": greeting, "data": payload}
|
|
|
|
|
@app.post("/token/revoke")
async def revoke_token(token: str = Depends(oauth2_scheme)):
    """
    Revoke a JWT so that ``get_current_user`` rejects it from now on.

    Only tokens with a valid signature that have not expired are recorded.
    Previously any caller could push arbitrary strings into the in-memory
    ``revoked_tokens`` set and grow it without bound. Revoking a token that
    is already revoked succeeds again, so the call is idempotent.

    Raises:
        HTTPException: 401 if the token cannot be decoded, fails signature
            verification, or has expired.
    """
    if token not in revoked_tokens:
        try:
            jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        except PyJWTError as exc:
            raise HTTPException(
                status_code=401,
                detail="Could not validate credentials",
                headers={"WWW-Authenticate": "Bearer"},
            ) from exc
        revoked_tokens.add(token)
    return {"status": "success", "message": "Token revoked successfully"}
|
|
|
def _build_docs_site() -> None:
    """Run ``mkdocs build`` (best effort) to regenerate the ``site`` directory."""
    try:
        # Argument list, no shell; failures are logged instead of ignored.
        result = subprocess.run(["mkdocs", "build"], check=False)
    except OSError:
        logger.warning("Could not run 'mkdocs build'; using existing 'site' directory")
        return
    if result.returncode != 0:
        logger.warning("'mkdocs build' exited with status %d", result.returncode)


# Build the MkDocs site, then serve it at "/". The mount is added after the
# API routes above, so those routes are matched first.
_build_docs_site()
app.mount("/", StaticFiles(directory="site", html=True), name="site")
|
|
|
|
|
if __name__ == "__main__":
    # Use a distinct name for the running HTTP server. The endpoints look up
    # the module-level `server` (the Server facade) when each request runs.
    # Binding `as server` here would replace it, so register_start would
    # call register_user on the HTTP server object.
    with https_server(app, host="0.0.0.0", port=8000) as http_server:
        # Server is running here (presumably in the background -- see
        # common.https_server).
        try:
            # Keep the main thread alive until the server asks to exit.
            while not http_server.should_exit:
                time.sleep(1)
        except KeyboardInterrupt:
            logger.info("Received shutdown signal")