Skip to content

Instantly share code, notes, and snippets.

@dasiths
Created March 8, 2025 17:35
Show Gist options
  • Save dasiths/c85f74458b57aeb45188782b186de471 to your computer and use it in GitHub Desktop.
Save dasiths/c85f74458b57aeb45188782b186de471 to your computer and use it in GitHub Desktop.
How to handle custom, business-specific context propagation in FastAPI and httpx using OpenTelemetry
import os
from typing import Optional, Dict, Any
from contextvars import ContextVar
from fastapi import FastAPI, Request, Depends, HTTPException
import httpx
import logging
import uvicorn
from functools import wraps
# OpenTelemetry imports
from opentelemetry import trace, metrics
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.sdk.trace import TracerProvider, ReadableSpan
from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanProcessor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.resources import Resource
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.instrumentation.httpx import HTTPXInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor
# HTTP headers that carry business context between services
BUSINESS_UNIT_HEADER = "X-Business-Unit"
CLIENT_ID_HEADER = "X-Client-ID"
# Per-request business attributes. ContextVar values are context-local
# (each asyncio task runs in its own copy of the context); both default to
# None when no value has been set.
business_unit_var: ContextVar[Optional[str]] = ContextVar("business_unit", default=None)
client_id_var: ContextVar[Optional[str]] = ContextVar("client_id", default=None)
# Module-level logger; the root logger is configured for INFO output below.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# The FastAPI application exposed by this module
app = FastAPI(title="OpenTelemetry Example with contextvars")
# Define a custom span processor that adds business context attributes to spans
class BusinessContextSpanProcessor(SpanProcessor):
    """Stamp business context onto every span when it starts.

    On span start, reads ``business_unit_var`` and ``client_id_var`` from the
    current context and, when non-empty, sets them as the ``business_unit``
    and ``client_id`` span attributes. Exports and buffers nothing itself.
    """

    def on_start(self, span: trace.Span, parent_context: Optional[Any] = None) -> None:
        """Copy non-empty business context values onto ``span`` as attributes.

        ``span`` is the live span being started, so it is writable
        (``ReadableSpan`` has no ``set_attribute``).
        """
        for attribute_name, context_var in (
            ("business_unit", business_unit_var),
            ("client_id", client_id_var),
        ):
            value = context_var.get()
            if value:
                span.set_attribute(attribute_name, value)

    def on_end(self, span: ReadableSpan) -> None:
        """Nothing to do when a span ends."""

    def shutdown(self) -> None:
        """Nothing to release: this processor holds no resources."""

    def force_flush(self, timeout_millis: int = 30000) -> bool:
        """Report success: nothing is buffered, so there is nothing to flush."""
        # Returned explicitly rather than relying on the base-class default.
        # NOTE(review): the SDK's multi-span-processor appears to stop at the
        # first processor whose force_flush result is falsy, which would skip
        # flushing processors registered after this one — confirm against the
        # installed opentelemetry-sdk version.
        return True
# Initialize OpenTelemetry
def setup_opentelemetry():
    """Configure global tracing and metrics providers and create instruments.

    Installs a ``TracerProvider`` with ``BusinessContextSpanProcessor`` and an
    OTLP gRPC batch span exporter. Also installs a ``MeterProvider`` whose
    metrics are exported periodically over OTLP gRPC to the same endpoint.
    The endpoint comes from the ``OTLP_ENDPOINT`` environment variable
    (default ``localhost:4317``).

    Returns:
        tuple: ``(tracer, meter, request_counter)`` for the ``platform.api``
        scope, where ``request_counter`` is the ``http.requests`` counter.
    """
    # Local import: only this function needs it.
    from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader

    otlp_endpoint = os.getenv("OTLP_ENDPOINT", "localhost:4317")
    # NOTE(review): with no URL scheme the gRPC exporters may attempt a TLS
    # connection; a plaintext local collector may need "http://..." or
    # insecure=True — confirm for the target deployment.

    # Create a resource with service information
    resource = Resource.create({
        "service.name": "platform-api",
        "service.version": "1.0.0",
    })

    # Traces: stamp business context on span start, export in batches.
    trace_provider = TracerProvider(resource=resource)
    trace_provider.add_span_processor(BusinessContextSpanProcessor())
    trace_provider.add_span_processor(
        BatchSpanProcessor(OTLPSpanExporter(endpoint=otlp_endpoint))
    )
    trace.set_tracer_provider(trace_provider)

    # Metrics: a MeterProvider only exports through the readers passed to its
    # constructor, so the exporter must be wrapped in a reader and registered.
    # A bare, unattached exporter is never used.
    metric_reader = PeriodicExportingMetricReader(
        OTLPMetricExporter(endpoint=otlp_endpoint)
    )
    meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
    metrics.set_meter_provider(meter_provider)

    tracer = trace.get_tracer("platform.api")
    meter = metrics.get_meter("platform.api")
    # Example metric: counts incoming HTTP requests
    request_counter = meter.create_counter(
        name="http.requests",
        description="Counts HTTP requests",
        unit="1",
    )
    return tracer, meter, request_counter


tracer, meter, request_counter = setup_opentelemetry()
# Custom httpx client with header propagation
class BusinessContextHTTPClient:
    """Async HTTP client that adds business-context headers to every request.

    Wraps one ``httpx.AsyncClient``. On each request, the current non-empty
    values of ``business_unit_var`` and ``client_id_var`` are sent as the
    ``BUSINESS_UNIT_HEADER`` / ``CLIENT_ID_HEADER`` headers. These override
    any caller-supplied header of the same name.

    May be used as an async context manager; the wrapped client is closed on
    exit.
    """

    def __init__(self):
        self.client = httpx.AsyncClient()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.close()

    async def request(self, method, url, **kwargs):
        """Send a request with business-context headers merged in.

        Args:
            method: HTTP method name, e.g. ``"GET"``.
            url: Target URL.
            **kwargs: Forwarded to ``httpx.AsyncClient.request``. ``headers``
                may be omitted, ``None``, a mapping, a list of
                ``(name, value)`` pairs, or ``httpx.Headers``. It is copied,
                never mutated.

        Returns:
            The ``httpx.Response`` from the wrapped client.
        """
        # Build a fresh, case-insensitive httpx.Headers instead of editing
        # kwargs["headers"] in place. The caller's object must not be modified
        # (it may be reused under a different context), and dict-style item
        # assignment fails for None or a list of pairs.
        headers = httpx.Headers(kwargs.get("headers"))
        business_unit = business_unit_var.get()
        client_id = client_id_var.get()
        if business_unit:
            headers[BUSINESS_UNIT_HEADER] = business_unit
        if client_id:
            headers[CLIENT_ID_HEADER] = client_id
        kwargs["headers"] = headers
        return await self.client.request(method, url, **kwargs)

    async def get(self, url, **kwargs):
        return await self.request("GET", url, **kwargs)

    async def post(self, url, **kwargs):
        return await self.request("POST", url, **kwargs)

    async def put(self, url, **kwargs):
        return await self.request("PUT", url, **kwargs)

    async def delete(self, url, **kwargs):
        return await self.request("DELETE", url, **kwargs)

    async def patch(self, url, **kwargs):
        return await self.request("PATCH", url, **kwargs)

    async def close(self):
        """Close the wrapped ``httpx.AsyncClient``."""
        await self.client.aclose()
# FastAPI dependency providing a per-request HTTP client
async def get_http_client():
    """Yield a new BusinessContextHTTPClient and close it afterwards.

    The client is closed in a ``finally`` clause, so cleanup happens when the
    generator is finalized, whether or not the consumer raised.
    """
    http_client = BusinessContextHTTPClient()
    try:
        yield http_client
    finally:
        # Close the wrapped httpx.AsyncClient.
        await http_client.close()
# Custom request hook for FastAPIInstrumentor, called with (span, scope)
def request_hook(span: trace.Span, scope: Dict[str, Any]):
    """Extract business context from request headers and store it in context vars.

    For each non-empty ``X-Business-Unit`` / ``X-Client-ID`` request header,
    the value is:

    * stored in ``business_unit_var`` / ``client_id_var`` for later readers in
      the same context,
    * set as the ``business_unit`` / ``client_id`` attribute on ``span``,
    * included in the attributes of ``request_counter``.

    The request is then logged with those values.

    ``span`` is presumably the request's server span, already started before
    this hook runs, so span-start processing could not see these values.
    Hence the explicit attribute setting here.

    Args:
        span: Span to annotate; may be ``None``.
        scope: ASGI scope dict; ``scope["headers"]`` is iterated as
            ``(name, value)`` byte pairs.
    """
    if not scope:
        return

    # Header names/values are bytes. Decode as latin-1, which maps every byte,
    # so a malformed non-UTF-8 header cannot raise UnicodeDecodeError inside
    # the hook. Later duplicates win, as with dict() construction.
    str_headers = {
        name.decode("latin-1").lower(): value.decode("latin-1")
        for name, value in scope.get("headers", [])
    }

    business_unit = str_headers.get(BUSINESS_UNIT_HEADER.lower())
    client_id = str_headers.get(CLIENT_ID_HEADER.lower())

    # Context propagation must not depend on tracing, so the context vars are
    # set even when no span is available.
    attributes: Dict[str, str] = {}
    if business_unit:
        business_unit_var.set(business_unit)
        attributes["business_unit"] = business_unit
    if client_id:
        client_id_var.set(client_id)
        attributes["client_id"] = client_id

    if span is not None:
        span.set_attributes(attributes)

    # NOTE(review): client_id as a metric attribute may be high-cardinality —
    # confirm this is acceptable for the metrics backend.
    request_counter.add(1, attributes)

    logger.info(
        "Received request: %s %s",
        scope.get("method", ""),
        scope.get("path", ""),
        extra={"business_unit": business_unit, "client_id": client_id},
    )
def response_hook(span: trace.Span, message: Dict[str, Any]):
    """No-op response hook, kept as an extension point.

    Accepts a span and a message dict, ignores both and returns None.
    """
    return None
# Instrument FastAPI with the custom request hook.
# FastAPIInstrumentor.instrument_app accepts `server_request_hook`,
# `client_request_hook` and `client_response_hook`. It has no
# `request_hook`/`response_hook` keywords, so passing those fails with a
# TypeError at import time.
# The response hook is a no-op and is not registered.
# NOTE(review): the current client_response_hook signature is
# (span, scope, message), which does not match response_hook(span, message).
FastAPIInstrumentor.instrument_app(
    app,
    server_request_hook=request_hook,
    tracer_provider=trace.get_tracer_provider(),
)
# Instrument HTTPX for general tracing. Business headers are added separately
# by BusinessContextHTTPClient.
# `instrument` is an instance method, so the instrumentor must be instantiated;
# calling it on the class raises TypeError for the missing `self`.
HTTPXInstrumentor().instrument(tracer_provider=trace.get_tracer_provider())
# Instrument logging
LoggingInstrumentor().instrument()
# Example endpoint
@app.get("/api/resource")
async def get_resource(
    request: Request,
    http_client: BusinessContextHTTPClient = Depends(get_http_client),
):
    """Fetch data from the downstream service and wrap it in a success envelope.

    Business-context headers are added to the outgoing call by
    ``BusinessContextHTTPClient``.

    Returns:
        dict: ``{"status": "success", "data": <downstream JSON>}``.

    Raises:
        HTTPException: 500 when the downstream call fails, returns a non-2xx
            status, or returns a body that is not valid JSON.
    """
    with tracer.start_as_current_span("get_resource"):
        logger.info("Processing resource request")
        try:
            response = await http_client.get(
                "https://downstream-service.example.com/api/data",
                timeout=10.0,
            )
            # Without this, a downstream 4xx/5xx error body would be returned
            # to the caller as "success".
            response.raise_for_status()
            data = response.json()
        except (httpx.HTTPError, ValueError) as e:
            # httpx.HTTPError covers transport errors, timeouts and the
            # HTTPStatusError from raise_for_status(); ValueError covers an
            # undecodable JSON body.
            logger.exception("Error calling downstream service: %s", e)
            raise HTTPException(status_code=500, detail="Error processing request") from e
        return {"status": "success", "data": data}
# Health check endpoint
@app.get("/health")
async def health_check():
    """Report a constant healthy status; performs no dependency checks."""
    status_payload = {"status": "healthy"}
    return status_payload
if __name__ == "__main__":
    # When executed as a script, serve the app with uvicorn on all
    # interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment