Skip to content

Instantly share code, notes, and snippets.

@benc-uk
Last active February 23, 2025 22:22
Show Gist options
  • Save benc-uk/8e576cf72b2361782060f20917f2e280 to your computer and use it in GitHub Desktop.
Save benc-uk/8e576cf72b2361782060f20917f2e280 to your computer and use it in GitHub Desktop.
Python

Python Things

from fastapi import FastAPI, Request, status
from fastapi.responses import PlainTextResponse
import jwt
import logging
# The FastAPI application; routes and the HTTP middleware are attached to it.
app = FastAPI()

# Logger named after this module, used to report token validation problems.
logger = logging.getLogger(__name__)
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
    """Reject any request that does not carry a valid bearer token.

    The ``Authorization`` header must look like ``Bearer <token>``. The token
    is checked with ``validate_token``. Requests failing any check receive a
    plain-text 401 response. Valid requests are forwarded with ``call_next``.

    Args:
        request: The incoming HTTP request.
        call_next: Forwards the request to the next middleware / route.

    Returns:
        The downstream response, or a 401 ``PlainTextResponse``.
    """
    # Response for unauthorized requests
    resp401 = PlainTextResponse("Unauthorized", status_code=status.HTTP_401_UNAUTHORIZED)

    # Split "<scheme> <token>". partition() never raises (unlike split(...)[1]),
    # so a missing or malformed header yields 401 instead of an IndexError / 500.
    scheme, _, token = request.headers.get("Authorization", "").partition(" ")
    token = token.strip()
    # Authentication scheme names are case-insensitive (RFC 7235, section 2.1).
    if scheme.lower() != "bearer" or not token:
        return resp401

    try:
        decoded_token = validate_token(token)
    except Exception as e:
        # Any failure while fetching the signing key or verifying the token is
        # treated as "not authenticated" rather than a server error.
        logger.error("ERROR: Problem validating token: %s", e)
        return resp401

    if not decoded_token:
        return resp401

    # Here you can do some authorization logic like checking scopes, roles, etc.
    # But we don't, we just chain the request to the next middleware.
    # call_next is deliberately outside the try block: exceptions raised by
    # downstream handlers must not be disguised as 401 Unauthorized.
    return await call_next(request)
# JWKS endpoint that publishes the Azure AD signing keys.
# Magic URL you might want to put in a config file or constant
_JWKS_URI = "https://login.microsoftonline.com/common/discovery/keys"
# For your API, this will be the Application ID (GUID) of the client you have registered
_DEFAULT_AUDIENCE = "b79fbf4d-3ef9-4689-8143-76b194e85509"

# Shared JWKS client, created on first use. It must outlive a single request:
# creating a new client per call discards its cached key set
# (cache_jwk_set / lifespan) and re-downloads the keys every time.
_jwks_client = None


def _get_jwks_client():
    """Return the process-wide PyJWKClient, creating it on the first call."""
    global _jwks_client
    if _jwks_client is None:
        _jwks_client = jwt.PyJWKClient(
            uri=_JWKS_URI,
            cache_jwk_set=True,
            lifespan=600,
        )
    return _jwks_client


def validate_token(token: str, audience: str = _DEFAULT_AUDIENCE):
    """Verify a JWT's RS256 signature and audience, returning its claims.

    Args:
        token: The raw encoded JWT (without the "Bearer " prefix).
        audience: Expected ``aud`` claim. Defaults to the registered app ID.

    Returns:
        The decoded claims, as returned by ``jwt.decode``.

    Raises:
        An exception (typically a ``jwt.PyJWTError`` subclass) when the signing
        key cannot be resolved or the token fails verification.
    """
    signing_key = _get_jwks_client().get_signing_key_from_jwt(token)
    # NOTE(review): the "common" endpoint serves keys for all tenants and no
    # issuer is checked here. Any tenant's token for this audience is therefore
    # accepted. Consider passing issuer=... if the API is single-tenant.
    return jwt.decode(
        token,
        signing_key.key,
        # This is the algorithm that Azure AD uses and lots of other OIDC providers
        algorithms=["RS256"],
        audience=audience,
    )
# Just a simple endpoint to demonstrate the middleware
@app.get("/")
def read_root():
return {"Hello": "World"}
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
from fastapi import FastAPI, Request
from dotenv import load_dotenv
import uvicorn
import requests
import os
# Load settings from a .env file into the environment. This has to happen
# before the SNS_TOPIC_ARN lookup below.
load_dotenv()

app = FastAPI()

# Topics whose messages this service accepts. The SNS_TOPIC_ARN env var must
# be set: if it is missing, importing this module raises KeyError.
allowed_topics = []
allowed_topics.append(os.environ['SNS_TOPIC_ARN'])
# Seconds to wait for AWS when confirming a subscription. Without a timeout,
# requests.get can block indefinitely.
_SUBSCRIBE_TIMEOUT_SECONDS = 10


def _is_trusted_subscribe_url(url):
    """Return True only for an HTTPS URL on an SNS regional endpoint.

    SubscribeURL comes from the unauthenticated request body. It is checked
    before the server fetches it, so the endpoint cannot be used to make
    requests to arbitrary hosts (SSRF).
    """
    import re
    from urllib.parse import urlparse

    if not isinstance(url, str):
        return False
    try:
        parsed = urlparse(url)
    except ValueError:
        return False
    host = parsed.hostname or ""
    return parsed.scheme == "https" and re.fullmatch(
        r"sns\.[a-z0-9-]+\.amazonaws\.com(\.cn)?", host
    ) is not None


# Main endpoint to receive SNS messages
@app.post("/sns/subscription")
async def index(request: Request):
    """Receive an SNS HTTP(S) notification or subscription confirmation.

    The handler ignores requests without the SNS message-type header and
    messages whose TopicArn is not in ``allowed_topics``. For a
    ``SubscriptionConfirmation`` it fetches the SubscribeURL to confirm. It
    treats every other message type as a notification.

    Note: ``status`` is reported only inside the JSON body. The HTTP status
    code of these responses is FastAPI's default.
    """
    import asyncio

    if 'x-amz-sns-message-type' not in request.headers:
        print('A regular HTTP request, will be ignored!')
        return {"message": "Request filtered out", "status": 403}

    body = await request.json()
    message_type = request.headers.get('x-amz-sns-message-type', '')

    # .get() so that a body without TopicArn (or a non-object body) is
    # rejected instead of raising and producing an HTTP 500.
    topic_arn = body.get("TopicArn") if isinstance(body, dict) else None
    if topic_arn not in allowed_topics:
        return {"message": "Topic not allowed", "status": 403}

    # More checks can be added here, like checking the signature of the message
    # Anyhow, this is just a simple example

    if message_type == 'SubscriptionConfirmation':
        print('Subscription confirmation in progress...')
        subscribe_url = body.get('SubscribeURL')
        if not _is_trusted_subscribe_url(subscribe_url):
            return {"message": "Invalid SubscribeURL", "status": 403}

        # Confirm the subscription by sending a GET request to the SubscribeURL.
        # requests is blocking, so run it in a worker thread to keep the event
        # loop responsive.
        loop = asyncio.get_running_loop()
        try:
            response = await loop.run_in_executor(
                None,
                lambda: requests.get(subscribe_url, timeout=_SUBSCRIBE_TIMEOUT_SECONDS),
            )
        except requests.RequestException as e:
            print(f'Subscription confirmation request failed: {e}')
            return {"message": "Failed to subscribe to the topic", "status": 500}

        if response.status_code == 200:
            return {"message": "Successfully subscribed to the topic", "status": 200}
        return {"message": "Failed to subscribe to the topic", "status": 500}

    sns_message = body.get('Message') or ''
    # Encode before measuring so the logged size really is in bytes, not characters.
    print(f'SNS message received, length: {len(str(sns_message).encode("utf-8"))} bytes')

    # ============================================
    # Your message processing logic goes here!
    # ============================================

    return {"message": "Message received", "status": 200}
# Health check endpoint for Front Door and other load balancing probes.
# Named healthz (not index) so it no longer shadows the SNS handler's name.
@app.get("/healthz")
async def healthz():
    """Liveness probe: always reports OK."""
    return {"status": "OK"}
if __name__ == '__main__':
    # Pass the app object rather than the 'main:app' import string. The string
    # only resolves when this file is named main.py. It also makes uvicorn
    # import the module a second time, re-running load_dotenv and app setup.
    uvicorn.run(app, host='0.0.0.0', port=8000)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment