Created
March 8, 2025 17:35
-
-
Save dasiths/c85f74458b57aeb45188782b186de471 to your computer and use it in GitHub Desktop.
How to handle custom, business-specific context propagation in FastAPI and httpx using OpenTelemetry
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| from typing import Optional, Dict, Any | |
| from contextvars import ContextVar | |
| from fastapi import FastAPI, Request, Depends, HTTPException | |
| import httpx | |
| import logging | |
| import uvicorn | |
| from functools import wraps | |
| # OpenTelemetry imports | |
| from opentelemetry import trace, metrics | |
| from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter | |
| from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter | |
| from opentelemetry.sdk.trace import TracerProvider, ReadableSpan | |
| from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanProcessor | |
| from opentelemetry.sdk.metrics import MeterProvider | |
| from opentelemetry.sdk.resources import Resource | |
| from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor | |
| from opentelemetry.instrumentation.httpx import HTTPXInstrumentor | |
| from opentelemetry.instrumentation.logging import LoggingInstrumentor | |
# Business context carried alongside a request. Both variables default to
# None until a value is stored in them.
client_id_var: ContextVar[Optional[str]] = ContextVar("client_id", default=None)
business_unit_var: ContextVar[Optional[str]] = ContextVar("business_unit", default=None)

# Names of the HTTP headers used to transport the business context
CLIENT_ID_HEADER = "X-Client-ID"
BUSINESS_UNIT_HEADER = "X-Business-Unit"
# Logging: a module-level logger, with the root logger configured at INFO
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# The FastAPI application that the routes and instrumentation attach to
app = FastAPI(title="OpenTelemetry Example with contextvars")
# Define a custom span processor that adds business context attributes to spans
class BusinessContextSpanProcessor(SpanProcessor):
    """Span processor that stamps the current business context onto new spans.

    When a span starts, the values held in ``business_unit_var`` and
    ``client_id_var`` are read and, when set, written onto the span as the
    ``business_unit`` and ``client_id`` attributes.
    """

    def on_start(self, span: trace.Span, parent_context: Optional[Any] = None) -> None:
        """Copy the business context variables onto ``span`` as attributes.

        Args:
            span: The span being started. Annotated as the writable
                ``trace.Span`` (not ``ReadableSpan``) because
                ``set_attribute`` is called on it.
            parent_context: Accepted to match the processor interface;
                not used by this implementation.
        """
        context_attributes = {
            "business_unit": business_unit_var.get(),
            "client_id": client_id_var.get(),
        }
        for attribute_name, attribute_value in context_attributes.items():
            # Unset (None) and empty values are skipped rather than recorded
            if attribute_value:
                span.set_attribute(attribute_name, attribute_value)

    def on_end(self, span: ReadableSpan) -> None:
        """Do nothing when a span ends."""
# Initialize OpenTelemetry
def setup_opentelemetry():
    """Configure the global tracer and meter providers and build the instruments.

    Tracing: a ``TracerProvider`` with the ``BusinessContextSpanProcessor``
    followed by a ``BatchSpanProcessor`` exporting over OTLP/gRPC.

    Metrics: a ``MeterProvider`` whose OTLP/gRPC exporter is attached through
    a ``PeriodicExportingMetricReader``. Without a reader the exporter is
    never invoked and recorded metrics are not exported.

    The collector endpoint is read from the ``OTLP_ENDPOINT`` environment
    variable and falls back to ``localhost:4317``.

    Returns:
        A ``(tracer, meter, request_counter)`` tuple, where
        ``request_counter`` is the ``http.requests`` counter.
    """
    # Imported here so this function brings its own dependency into scope
    from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader

    # Read the endpoint once so traces and metrics always target the same collector
    otlp_endpoint = os.getenv("OTLP_ENDPOINT", "localhost:4317")

    # Create a resource with service information
    resource = Resource.create({
        "service.name": "platform-api",
        "service.version": "1.0.0",
    })

    # Configure trace provider: custom processor first, then the OTLP exporter
    trace_provider = TracerProvider(resource=resource)
    trace_provider.add_span_processor(BusinessContextSpanProcessor())
    span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint)
    trace_provider.add_span_processor(BatchSpanProcessor(span_exporter))
    trace.set_tracer_provider(trace_provider)

    # Configure metrics: the reader is what connects the exporter to the provider
    metric_exporter = OTLPMetricExporter(endpoint=otlp_endpoint)
    metric_reader = PeriodicExportingMetricReader(metric_exporter)
    meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
    metrics.set_meter_provider(meter_provider)

    tracer = trace.get_tracer("platform.api")
    meter = metrics.get_meter("platform.api")

    # Create a counter as an example metric
    request_counter = meter.create_counter(
        name="http.requests",
        description="Counts HTTP requests",
        unit="1",
    )
    return tracer, meter, request_counter


tracer, meter, request_counter = setup_opentelemetry()
# Custom httpx client with header propagation
class BusinessContextHTTPClient:
    """Custom HTTP client that automatically adds business context headers.

    Wraps an ``httpx.AsyncClient`` and, on every request, copies the values
    of ``business_unit_var`` and ``client_id_var`` (when set) into the
    ``X-Business-Unit`` and ``X-Client-ID`` request headers.
    """

    def __init__(self):
        self.client = httpx.AsyncClient()

    async def request(self, method, url, **kwargs):
        """Send a request with the business context headers attached.

        Args:
            method: HTTP method name, e.g. ``"GET"``.
            url: Target URL.
            **kwargs: Forwarded to ``httpx.AsyncClient.request``. ``headers``
                may be omitted, ``None``, a mapping, or a list of pairs.

        Returns:
            Whatever ``httpx.AsyncClient.request`` returns.
        """
        # Build a fresh Headers object instead of editing the caller's value:
        # this leaves a dict passed in by the caller untouched and also
        # accepts None or a list of pairs, which item assignment would not.
        headers = httpx.Headers(kwargs.get("headers"))

        # Add business context headers from context vars
        business_unit = business_unit_var.get()
        client_id = client_id_var.get()
        if business_unit:
            headers[BUSINESS_UNIT_HEADER] = business_unit
        if client_id:
            headers[CLIENT_ID_HEADER] = client_id

        kwargs["headers"] = headers
        return await self.client.request(method, url, **kwargs)

    async def get(self, url, **kwargs):
        """Send a GET request through ``request``."""
        return await self.request("GET", url, **kwargs)

    async def post(self, url, **kwargs):
        """Send a POST request through ``request``."""
        return await self.request("POST", url, **kwargs)

    async def put(self, url, **kwargs):
        """Send a PUT request through ``request``."""
        return await self.request("PUT", url, **kwargs)

    async def delete(self, url, **kwargs):
        """Send a DELETE request through ``request``."""
        return await self.request("DELETE", url, **kwargs)

    async def patch(self, url, **kwargs):
        """Send a PATCH request through ``request``."""
        return await self.request("PATCH", url, **kwargs)

    async def close(self):
        """Close the underlying ``httpx.AsyncClient``."""
        await self.client.aclose()
# Dependency to get HTTP client
async def get_http_client():
    """Yield a new ``BusinessContextHTTPClient`` and close it afterwards.

    The client is closed in a ``finally`` clause, so it is released even
    when the consumer of the generator raises.
    """
    http_client = BusinessContextHTTPClient()
    try:
        yield http_client
    finally:
        await http_client.close()
# Custom request hooks for FastAPIInstrumentor
def _read_header(scope: Dict[str, Any], name: str) -> Optional[str]:
    """Return the value of header ``name`` from ``scope``, or None if absent."""
    wanted = name.lower().encode("latin-1")
    found = None
    for raw_name, raw_value in scope.get("headers") or ():
        if raw_name.lower() == wanted:
            # latin-1 maps every byte to a character, so decoding cannot raise.
            # No early exit: for a repeated header the last occurrence wins,
            # as it did with the previous dict-based lookup.
            found = raw_value.decode("latin-1")
    return found


def request_hook(span: trace.Span, scope: Dict[str, Any]):
    """Extract business context from request headers and store in context vars.

    Reads the ``X-Business-Unit`` and ``X-Client-ID`` headers from ``scope``.
    Each value that is present is stored in its context variable and set as a
    span attribute. The request counter is then incremented and the request
    is logged with both values attached.

    Only the two headers of interest are decoded, so malformed bytes in an
    unrelated header cannot make this hook raise.

    Args:
        span: Span to annotate; the hook does nothing when it is falsy.
        scope: Request scope mapping; the hook does nothing when it is falsy.
    """
    if not span or not scope:
        return

    business_unit = _read_header(scope, BUSINESS_UNIT_HEADER)
    client_id = _read_header(scope, CLIENT_ID_HEADER)

    # Attributes shared by the request metric; absent values are left out
    attributes = {}
    if business_unit:
        business_unit_var.set(business_unit)
        span.set_attribute("business_unit", business_unit)
        attributes["business_unit"] = business_unit
    if client_id:
        client_id_var.set(client_id)
        span.set_attribute("client_id", client_id)
        attributes["client_id"] = client_id

    # Record metric with business attributes
    request_counter.add(1, attributes)

    # Log with context; %-style arguments defer formatting to the logging module
    logger.info(
        "Received request: %s %s",
        scope.get("method", ""),
        scope.get("path", ""),
        extra={"business_unit": business_unit, "client_id": client_id},
    )
def response_hook(span: trace.Span, message: Dict[str, Any]):
    """Placeholder response hook.

    Performs no work and returns None; it exists as an extension point for
    additional response processing.
    """
    return None
# Instrument FastAPI with custom hooks.
# instrument_app() takes its hooks as server_request_hook /
# client_response_hook, and the client response hook is invoked with
# (span, scope, message), so the two-argument response_hook is adapted here.
# NOTE(review): confirm the hook names against the installed
# opentelemetry-instrumentation-fastapi version.
FastAPIInstrumentor.instrument_app(
    app,
    server_request_hook=request_hook,
    client_response_hook=lambda span, scope, message: response_hook(span, message),
    tracer_provider=trace.get_tracer_provider(),
)

# Instrument HTTPX for general tracing (but we'll handle header propagation separately).
# instrument() is an instance method, so the instrumentor has to be
# instantiated first; calling it on the class raises TypeError (missing self).
HTTPXInstrumentor().instrument(tracer_provider=trace.get_tracer_provider())

# Instrument logging
LoggingInstrumentor().instrument()
# Example endpoint
@app.get("/api/resource")
async def get_resource(
    request: Request,
    http_client: BusinessContextHTTPClient = Depends(get_http_client),
):
    """Fetch data from the downstream service and return it to the caller.

    The call runs inside a ``get_resource`` span. Business context headers
    are added by the injected ``BusinessContextHTTPClient``.

    Returns:
        ``{"status": "success", "data": <downstream JSON>}`` on success.

    Raises:
        HTTPException: With status 500 when the downstream call fails, the
            downstream service answers with an error status, or its body
            cannot be parsed.
    """
    with tracer.start_as_current_span("get_resource"):
        logger.info("Processing resource request")
        # Business context headers will be automatically added by our custom HTTP client
        try:
            response = await http_client.get(
                "https://downstream-service.example.com/api/data",
                timeout=10.0,
            )
            # Without this check a 4xx/5xx downstream reply would be
            # reported to the caller as "success"
            response.raise_for_status()
            return {"status": "success", "data": response.json()}
        except Exception as e:
            # logger.exception keeps the traceback; "from e" keeps the cause chained
            logger.exception("Error calling downstream service: %s", e)
            raise HTTPException(status_code=500, detail="Error processing request") from e
# Health check endpoint
@app.get("/health")
async def health_check():
    """Report the service as healthy with a constant payload."""
    payload = {"status": "healthy"}
    return payload
# Script entry point: serve the app with uvicorn on all interfaces, port 8000
if __name__ == "__main__":
    uvicorn.run(
        app,
        port=8000,
        host="0.0.0.0",
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment