Last active
February 23, 2025 22:22
-
-
Save benc-uk/8e576cf72b2361782060f20917f2e280 to your computer and use it in GitHub Desktop.
Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging

import jwt
from fastapi import FastAPI, Request, status
from fastapi.responses import PlainTextResponse

# Module-wide logger and the FastAPI application that the middleware and
# routes below are registered on.
logger = logging.getLogger(__name__)
app = FastAPI()
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
    """Reject requests without a valid bearer JWT; forward everything else.

    Expects an ``Authorization: Bearer <token>`` header. The token is checked
    with ``validate_token``; if the header is missing or malformed, or
    validation fails for any reason, a plain-text 401 is returned and the
    request never reaches a route handler.

    Args:
        request: The incoming HTTP request.
        call_next: Forwards the request to the next middleware / route handler
            and returns its response.

    Returns:
        The downstream response, or a 401 ``PlainTextResponse``.
    """
    # Response for unauthorized requests
    resp401 = PlainTextResponse("Unauthorized", status_code=status.HTTP_401_UNAUTHORIZED)

    auth_header = request.headers.get("Authorization", "")
    # Split "<scheme> <token>" once. The scheme is matched case-insensitively, and
    # a header with no space or a different scheme yields 401 instead of the
    # IndexError (-> 500) that split("Bearer ")[1] used to raise.
    scheme, _, token = auth_header.partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        return resp401

    try:
        validate_token(token)
    except Exception as e:
        # Fail closed: any problem (bad signature, expired, wrong audience,
        # key lookup/fetch failure, ...) is treated as unauthorized.
        logger.error("ERROR: Problem validating token: %s", e)
        return resp401

    # Here you can do some authorization logic like checking scopes, roles, etc.
    # call_next is deliberately outside the try above, so exceptions raised by
    # route handlers propagate normally instead of being reported as 401s.
    return await call_next(request)
# JWKS endpoint for Azure AD / Entra ID (magic URL; consider moving it to config).
JWKS_URI = "https://login.microsoftonline.com/common/discovery/keys"
# For your API, this is the Application ID (GUID) of the app you have registered.
TOKEN_AUDIENCE = "b79fbf4d-3ef9-4689-8143-76b194e85509"

# Built once at import so the JWK-set cache (cache_jwk_set / lifespan=600 s) is
# actually reused across requests. The old per-call client threw the cache away
# and re-fetched the key set on every request. PyJWT fetches the key set lazily,
# on the first key lookup.
_jwks_client = jwt.PyJWKClient(
    uri=JWKS_URI,
    cache_jwk_set=True,
    lifespan=600,
)


def validate_token(token: str, *, audience: str = TOKEN_AUDIENCE, issuer=None) -> dict:
    """Verify a JWT's signature and audience and return its decoded claims.

    The signing key is selected from the JWKS at ``JWKS_URI`` using the token's
    header, and the token must be RS256-signed.

    NOTE(review): the issuer is only checked when ``issuer`` is given. The
    ``common`` key set is presumably shared across tenants, so without an issuer
    check tokens from other tenants for this audience may be accepted. Confirm
    whether that is intended.

    Args:
        token: The raw encoded JWT (without the "Bearer " prefix).
        audience: Expected ``aud`` claim; defaults to this API's application ID.
        issuer: Optional expected ``iss`` claim; not validated when None.

    Returns:
        The decoded claims as a dict.

    Raises:
        jwt.PyJWTError: If the signing key cannot be resolved or the token
            fails validation.
    """
    signing_key = _jwks_client.get_signing_key_from_jwt(token)
    return jwt.decode(
        token,
        signing_key.key,
        # This is the algorithm that Azure AD uses and lots of other OIDC providers
        algorithms=["RS256"],
        audience=audience,
        issuer=issuer,
    )
# Trivial route used to demonstrate the auth middleware.
@app.get("/")
def read_root():
    """Return a fixed greeting (only reached once the middleware has accepted the token)."""
    greeting = {"Hello": "World"}
    return greeting
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import requests
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, Request

# Load variables from a local .env file into os.environ before any
# configuration below is read from the environment.
load_dotenv()

app = FastAPI()
# The SNS_TOPIC_ARN env var must be set (a missing variable fails at import with
# KeyError). It may hold a single topic ARN or several separated by commas;
# surrounding whitespace and empty entries are ignored, so an empty value allows
# no topics at all.
allowed_topics = [
    arn.strip()
    for arn in os.environ["SNS_TOPIC_ARN"].split(",")
    if arn.strip()
]
# Main endpoint to receive SNS messages
@app.post("/sns/subscription")
async def index(request: Request):
    """Handle an SNS HTTP(S) delivery: confirm subscriptions, accept notifications.

    Only requests that carry the ``x-amz-sns-message-type`` header and have a
    JSON object body whose ``TopicArn`` is in ``allowed_topics`` are processed.
    Every reply keeps the ``{"message": ..., "status": ...}`` JSON shape. The HTTP
    status code now equals the ``status`` field (previously every reply was HTTP
    200 regardless), so SNS and other callers see the real outcome.

    NOTE(review): the SNS message signature is still not verified, so the body
    is untrusted input.
    """
    import asyncio
    import functools
    from urllib.parse import urlparse

    from fastapi.responses import JSONResponse

    def reply(message: str, status_code: int) -> JSONResponse:
        # Same JSON body as before, now with a matching HTTP status code.
        return JSONResponse(
            {"message": message, "status": status_code}, status_code=status_code
        )

    if 'x-amz-sns-message-type' not in request.headers:
        print('A regular HTTP request, will be ignored!')
        return reply("Request filtered out", 403)

    try:
        body = await request.json()
    except ValueError:
        # Malformed JSON (incl. invalid UTF-8) used to escape as an unhandled 500.
        return reply("Invalid JSON body", 400)
    if not isinstance(body, dict):
        return reply("Invalid JSON body", 400)

    message_type = request.headers.get('x-amz-sns-message-type', '')
    # .get(): a body without TopicArn is "not allowed" rather than a KeyError/500.
    topic_arn = body.get("TopicArn")
    if topic_arn not in allowed_topics:
        return reply("Topic not allowed", 403)

    # More checks can be added here, like checking the signature of the message

    if message_type == 'SubscriptionConfirmation':
        print('Subscription confirmation in progress...')
        subscribe_url = body.get('SubscribeURL') or ''
        # The URL comes from an unverified body. Only follow HTTPS links to an SNS
        # endpoint, so this handler can't be used to make the server fetch
        # arbitrary URLs.
        parsed = urlparse(subscribe_url)
        host = (parsed.hostname or '').lower()
        is_sns_host = host.startswith('sns.') and host.endswith(
            ('.amazonaws.com', '.amazonaws.com.cn')
        )
        if parsed.scheme != 'https' or not is_sns_host:
            return reply("Invalid SubscribeURL", 400)

        # Confirm the subscription by sending a GET request to the SubscribeURL.
        # requests is blocking, so run it in the default thread pool rather than
        # stalling the event loop, and bound it with a timeout.
        loop = asyncio.get_running_loop()
        try:
            response = await loop.run_in_executor(
                None, functools.partial(requests.get, subscribe_url, timeout=10)
            )
        except requests.RequestException as e:
            print(f'Subscription confirmation request failed: {e}')
            return reply("Failed to subscribe to the topic", 500)

        if response.status_code == 200:
            return reply("Successfully subscribed to the topic", 200)
        return reply("Failed to subscribe to the topic", 500)

    sns_message = body.get('Message', '')
    # len() of a str counts characters, not bytes.
    print(f'SNS message received, length: {len(sns_message)} characters')
    # ============================================
    # Your message processing logic goes here!
    # ============================================
    return reply("Message received", 200)
# Health check endpoint for Front Door and other load balancing probes
@app.get("/healthz")
async def healthz():
    """Liveness probe: always report OK.

    Named ``healthz`` rather than ``index`` so it no longer rebinds the
    module-level name of the ``index`` SNS endpoint handler. The route itself
    (GET /healthz -> {"status": "OK"}) is unchanged.
    """
    return {"status": "OK"}
if __name__ == "__main__":
    # Serve the `app` object from main.py on all interfaces, port 8000.
    uvicorn.run("main:app", port=8000, host="0.0.0.0")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment