Created: January 17, 2025 23:51
llama_cpp_server_model_swapping_proxy_middleware.py (set variables as needed. MIT license.)
# Minimalist OpenAI-API-compatible llama.cpp model-swapping middleware / proxy server that manages model loading and auto-shutdown.
# Provides seamless model hot-swapping and idle shutdown while exposing llama.cpp's advanced features like speculative decoding.
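# Run with: python llama_cpp_server_model_swapping_proxy_middleware.py
# (the only third-party dependency is aiohttp; the llama-server binary path is set in LLAMA_SERVER_CMD below)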
import asyncio
import json
from datetime import datetime, timedelta
import subprocess
import aiohttp
from aiohttp import web

# Constants
PROXY_PORT = 8000  # Intended as an external port
SERVER_PORT = 8312  # Intended as an internal port
IDLE_TIMEOUT_SECONDS = 60 * 100  # 100 minutes
SERVER_READY_TIMEOUT_SECONDS = 90
DEFAULT_MODEL = "/fast/Meta-Llama-3.1-8B-Instruct-IQ4_XS.gguf"
DEFAULT_DRAFT_MODEL = "/fast/Llama-3.2-1B-Instruct-IQ4_XS.gguf"
CHUNK_SIZE = 4096
MAX_GPU_LAYERS = 99
DEFAULT_CTX_SIZE = 4096
DEFAULT_THREADS = 8
LLAMA_SERVER_CMD = """
/fast/llama_gpu/bin/llama-server
--model {model}
--model-draft {draft_model}
--ctx-size {ctx_size}
--threads {threads}
--port {port}
-fa
-ngl {gpu_layers}
--gpu-layers-draft {gpu_layers}
"""
KEY_MODEL = "model"
KEY_DRAFT_MODEL = "draft_model"
KEY_PROCESS = "process"
KEY_LAST_REQUEST = "last_request"

# Shared State
state = {
    KEY_PROCESS: None,
    KEY_LAST_REQUEST: None,
    KEY_MODEL: DEFAULT_MODEL,
    KEY_DRAFT_MODEL: DEFAULT_DRAFT_MODEL,
}


async def is_server_ready():
    """Check if the Llama server is ready."""
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(f"http://127.0.0.1:{SERVER_PORT}/health") as resp:
                return resp.status == 200
    except aiohttp.ClientConnectorError:
        return False


async def wait_for_server_ready():
    """Wait for the Llama server to be ready."""
    start_time = datetime.now()
    while not await is_server_ready():
        if datetime.now() - start_time > timedelta(
            seconds=SERVER_READY_TIMEOUT_SECONDS
        ):
            raise Exception("Server failed to become ready")
        await asyncio.sleep(1)


async def start_server(
    model=None, draft_model=None, ctx_size=DEFAULT_CTX_SIZE, threads=DEFAULT_THREADS
):
    """Start the Llama server."""
    model = model or state[KEY_MODEL]
    draft_model = draft_model or state[KEY_DRAFT_MODEL]
    if state[KEY_PROCESS] and (
        state[KEY_MODEL] != model or state[KEY_DRAFT_MODEL] != draft_model
    ):
        print("Stopping server to switch models...")
        await stop_server()
    if state[KEY_PROCESS] is None:
        cmd = LLAMA_SERVER_CMD.format(
            model=model,
            draft_model=draft_model,
            ctx_size=ctx_size,
            threads=threads,
            port=SERVER_PORT,
            gpu_layers=MAX_GPU_LAYERS,
        ).split()
        state[KEY_PROCESS] = await asyncio.create_subprocess_exec(
            *cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        asyncio.create_task(stream_subprocess_output(state[KEY_PROCESS]))
        state[KEY_MODEL] = model
        state[KEY_DRAFT_MODEL] = draft_model
        state[KEY_LAST_REQUEST] = datetime.now()
        print(
            f"Llama server started on port {SERVER_PORT} with models: {model}, {draft_model}"
        )
    await wait_for_server_ready()


async def stop_server():
    """Stop the Llama server."""
    if state[KEY_PROCESS]:
        print("Terminating Llama server process...")
        state[KEY_PROCESS].terminate()
        await state[KEY_PROCESS].wait()
        print("Llama server process terminated.")
        state[KEY_PROCESS] = None


async def stream_subprocess_output(process):
    """Stream subprocess output."""

    async def stream_pipe(pipe):
        while line := await pipe.readline():
            print(line.decode(), end="")

    await asyncio.gather(stream_pipe(process.stdout), stream_pipe(process.stderr))


async def monitor_idle_timeout():
    """Monitor server idle timeout."""
    while True:
        await asyncio.sleep(5)
        if state[KEY_LAST_REQUEST] and datetime.now() - state[
            KEY_LAST_REQUEST
        ] > timedelta(seconds=IDLE_TIMEOUT_SECONDS):
            print("Stopping server due to inactivity...")
            await stop_server()
            state[KEY_LAST_REQUEST] = None


def adapt_request_for_llama(headers, body):
    """Adapt request for Llama.cpp server."""
    try:
        request_json = json.loads(body)
        model = request_json.pop("model", None)
        draft_model = request_json.pop("draft_model", None)
        body = json.dumps(request_json).encode()
        # Drop stale variants of these headers first (dict keys are case-sensitive,
        # so e.g. the client's "Content-Length" would otherwise survive alongside
        # the new value and be forwarded twice).
        for key in list(headers):
            if key.lower() in ("content-type", "content-length"):
                del headers[key]
        headers["content-type"] = "application/json"
        headers["content-length"] = str(len(body))
        return headers, body, model, draft_model
    except json.JSONDecodeError:
        return headers, body, None, None
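# Example (illustrative payload): an incoming body such as
#   {"model": "/fast/Meta-Llama-3.1-8B-Instruct-IQ4_XS.gguf", "messages": [...]}
# is forwarded to llama.cpp as {"messages": [...]}; the popped "model" (and
# optional "draft_model") decide whether start_server() needs to relaunch
# llama-server with a different model pair.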


async def proxy_request(request):
    """Handle incoming requests using aiohttp.web."""
    try:
        # Read request body
        body = await request.read()
        headers = dict(request.headers)
        # Adapt request for Llama
        headers, body, model, draft_model = adapt_request_for_llama(headers, body)
        # Start the Llama Server (if needed)
        await start_server(model, draft_model)
        state[KEY_LAST_REQUEST] = datetime.now()
        # Proxy to Llama.cpp Server
        async with aiohttp.ClientSession() as session:
            async with session.request(
                method=request.method,
                url=f"http://127.0.0.1:{SERVER_PORT}{request.path}",
                headers=headers,
                data=body,
            ) as llama_response:
                # Stream response from Llama server
                response = web.StreamResponse(
                    status=llama_response.status,
                    headers=llama_response.headers,
                )
                await response.prepare(request)
                async for chunk in llama_response.content.iter_any():
                    await response.write(chunk)
                await response.write_eof()
                return response
    except Exception as e:
        print(f"An error occurred: {e}")
        return web.Response(status=500, text="Internal Server Error")


async def main():
    """Main function."""
    await start_server()
    asyncio.create_task(monitor_idle_timeout())
    app = web.Application()
    app.router.add_route('*', '/{path:.*}', proxy_request)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, '127.0.0.1', PROXY_PORT)
    print(f"Serving on http://127.0.0.1:{PROXY_PORT}")
    await site.start()
    try:
        await asyncio.Future()  # run forever
    finally:
        await runner.cleanup()


asyncio.run(main())
""" | |
Why not just use the Llama server directly? | |
-That's fine until you want to switch models, or save resources when not in use. | |
-It's bleeding edge and supports the latest models with speculative decoding to speed up completions. | |
-It also can support JSON schema enforcement and other great features. | |
Ollama is a similar project but it is more complicated and lags behind llama.cpp. | |
Python-llama-cpp is a Python wrapper around llama.cpp but it is not as feature rich as llama.cpp. | |
This is a minimalist server that is easy to understand and modify. | |
The standard OpenAI API compatibility is implemented by Llama.cpp Server. | |
The OpenAI style API is even offered by Google's Gemini API. | |
""" | |
""" | |
This works but how could it be more minimalist while preserving the primary functionality? | |
httpx may help. FastAPI and Flask would not. | |
This is middleware. | |
FastAPI is a nice wrapper around Starlette, which is a nice wrapper around anyio IIRC. | |
config managment is not a big chunk of the code. Celery would be overkill. | |
We don't need logging libraries. | |
I want to keep everything in one file. | |
""" |