Example of how to proxy from outer → inner in modal (not secure)
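To try it (assuming the Modal CLI is configured; the filename proxy_test.py below is arbitrary), the end-to-end check can be kicked off with "modal run proxy_test.py::run_test", or the web endpoints can be iterated on with "modal serve proxy_test.py".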
import modal

app = modal.App("test-proxy-to-modal-web-server")

# Amount of time the fake vLLM server takes to start
STARTUP_DELAY = 10
# Amount of time within which the proxy / vLLM should become responsive
START_TIMEOUT = 20
PROXY_TO_VLLM_TIMEOUT = START_TIMEOUT
VLLM_TIMEOUT = START_TIMEOUT + 10
VLLM_SCALEDOWN = 10
VLLM_PORT = 8088


@app.cls(
    image=modal.Image.debian_slim().pip_install("fastapi", "uvicorn"),
    scaledown_window=VLLM_SCALEDOWN,
    timeout=VLLM_TIMEOUT,
)
@modal.concurrent(max_inputs=32)
class VLLMSimulator:
    # A separate server is started for each model
    model_name: str = modal.parameter()

    # NOTE: This is still public on the internet, which is not intended. I think I need
    # to use proxy auth tokens to fix that and pass them from the proxy (see the sketch
    # after the code below).
    @modal.web_server(port=VLLM_PORT, startup_timeout=START_TIMEOUT)
    def serve(self):
        from fastapi import FastAPI, Request, HTTPException
        import uvicorn
        import time
        from multiprocessing import Process
        import asyncio

        def fake_vllm(model_name: str, port: int):
            app = FastAPI()

            @app.on_event("startup")
            async def startup_delay():
                print(f"sleeping for {STARTUP_DELAY}s to simulate model load")
                await asyncio.sleep(STARTUP_DELAY)  # simulate model load

            @app.post("/v1/chat/completions")
            async def chat(request: Request):
                body = await request.json()
                model = body.get("model")
                if not model:
                    raise HTTPException(400, "Model name required")
                if model != model_name:
                    raise HTTPException(
                        400,
                        f"Model name {model} does not match {model_name}",
                    )
                return {"choices": [{"message": {"content": f"Hello from {model}!"}}]}

            config = uvicorn.Config(
                app=app, host="0.0.0.0", port=port, log_level="info"
            )
            server = uvicorn.Server(config)
            # Important: when uvicorn is not running in the main thread/process, signals
            # land differently; skipping its signal handler install avoids issues.
            server.install_signal_handlers = False
            server.run()

        # Alternative: block here in serve() instead of in the FastAPI startup event:
        # time.sleep(STARTUP_DELAY)
        print("starting fake vllm server process")
        proc = Process(target=fake_vllm, args=(self.model_name, VLLM_PORT), daemon=True)
        proc.start()
        print("called start on process")


# Proxy server that calls the vLLM simulator (like your actual proxy)
@app.function(
    image=modal.Image.debian_slim().pip_install("fastapi", "uvicorn", "httpx")
)
@modal.asgi_app()
def proxy_server():
    from fastapi import FastAPI, Request, HTTPException
    from fastapi.responses import StreamingResponse
    import httpx
    import json

    proxy = FastAPI()

    # Get the vLLM simulator URL (hard-coded to the "test-model" instance for this demo)
    base_url = VLLMSimulator(model_name="test-model").serve.get_web_url()
    print(f"Proxy to base URL: {base_url}")

    @proxy.post("/v1/chat/completions")
    async def chat_completions(request: Request):
        """Proxy chat completions to the vLLM simulator, always streaming the response."""
        print("Proxy: request received, reading body.")
        raw_body = await request.body()
        print("Proxy: raw body:", raw_body)
        try:
            body = json.loads(raw_body)
        except json.JSONDecodeError:
            raise HTTPException(status_code=400, detail="Invalid JSON body")
        model_name = body.get("model")
        if not model_name:
            raise HTTPException(status_code=400, detail="Model name required")

        target_url = f"{base_url}/v1/chat/completions?model_name={model_name}"
        req_headers = {"content-type": "application/json"}
        print(f"Proxy: POST {target_url}")
        async with httpx.AsyncClient(timeout=PROXY_TO_VLLM_TIMEOUT) as client:
            try:
                response = await client.post(
                    target_url,
                    content=raw_body,
                    headers=req_headers,
                )
                return StreamingResponse(
                    iter([response.content]),
                    status_code=response.status_code,
                    media_type=response.headers.get("content-type") or None,
                    headers={
                        "content-type": response.headers.get(
                            "content-type", "application/json"
                        )
                    },
                )
            except httpx.TimeoutException:
                raise HTTPException(
                    status_code=503,
                    detail="vLLM server timed out; it may still be starting up, please try again in a few minutes",
                )
            except httpx.ConnectError:
                raise HTTPException(
                    status_code=503,
                    detail="vLLM server could not be contacted, check the logs for more details",
                )

    @proxy.get("/health")
    async def health():
        """Proxy health check."""
        return {"status": "healthy", "type": "proxy"}

    return proxy


# Test function
@app.function(image=modal.Image.debian_slim().pip_install("httpx", "tenacity"))
async def run_test():
    import httpx

    # vllm_url = VLLMSimulator(model_name="test-model").serve.get_web_url()
    proxy_url = proxy_server.get_web_url()
    print(f"Testing proxy: {proxy_url}")

    # Test /health endpoint (hits proxy)
    async with httpx.AsyncClient() as client:
        resp = await client.get(f"{proxy_url}/health")
        print(f"/health status {resp.status_code}, response: {resp.json()}")

    # Trigger a chat completion for model-a (hits proxy -> starts vllm for that model)
    async with httpx.AsyncClient(timeout=START_TIMEOUT) as client:
        chat_a = await client.post(
            f"{proxy_url}/v1/chat/completions",
            json={"model": "model-a", "messages": [{"role": "user", "content": "Hi"}]},
        )
        print(f"✅ Success: {chat_a.status_code}")
        print(f"Response: {chat_a.json()}")

        chat_b = await client.post(
            f"{proxy_url}/v1/chat/completions",
            json={"model": "model-b", "messages": [{"role": "user", "content": "Hi"}]},
        )
        print(f"✅ Success: {chat_b.status_code}")
        print(f"Response: {chat_b.json()}")

    print("Test completed!")


if __name__ == "__main__":
    # Run the test remotely. Calling .remote() from a plain script needs an active
    # app context (or use "modal run" instead of executing this file directly).
    with modal.enable_output(), app.run():
        run_test.remote()
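About the NOTE in VLLMSimulator.serve (the inner server is still publicly reachable): here is a minimal sketch, not part of the original gist, of one way to close that gap using Modal's proxy auth tokens. The idea is to add requires_proxy_auth=True to the inner web_server and have the proxy forward a token in the Modal-Key / Modal-Secret headers. The secret name "vllm-proxy-auth" and the TOKEN_ID / TOKEN_SECRET field names are invented for illustration; verify the exact parameter and header names against the current Modal docs.

# Sketch only; assumes Modal proxy auth tokens, hypothetical names noted below.

# 1) Inner server: require proxy auth so the web URL is no longer open to the internet.
#    Only the decorator changes; the body of serve() stays as above.
@modal.web_server(port=VLLM_PORT, startup_timeout=START_TIMEOUT, requires_proxy_auth=True)
def serve(self):
    ...

# 2) Outer proxy: attach a modal.Secret holding the token and forward it on each request.
@app.function(
    image=modal.Image.debian_slim().pip_install("fastapi", "uvicorn", "httpx"),
    secrets=[modal.Secret.from_name("vllm-proxy-auth")],  # hypothetical secret name
)
@modal.asgi_app()
def proxy_server():
    import os
    ...
    req_headers = {
        "content-type": "application/json",
        "Modal-Key": os.environ["TOKEN_ID"],         # hypothetical env var names,
        "Modal-Secret": os.environ["TOKEN_SECRET"],  # populated from the secret above
    }
    ...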