LLaMA 3.1 405B Instruct FP8 - vLLM - OpenAI-compatible server
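This gist contains three files: a Modal app that serves the model behind vLLM's OpenAI-compatible API, a command-line client, and a script that downloads the weights into a Modal Volume. Once the server is deployed, any OpenAI SDK can talk to it. The sketch below is a minimal example; the workspace name in the URL and the bearer token are placeholders you would substitute with your own.

# minimal sketch: query the deployed server with the openai client
# (workspace name and token are placeholders -- see the full client script further down)
from openai import OpenAI

client = OpenAI(
    api_key="super-secret-token",  # must match TOKEN in the server file
    base_url="https://YOUR-WORKSPACE--vllm-openai-compatible-405b-serve.modal.run/v1",
)
response = client.chat.completions.create(
    model="/models/meta-llama/Meta-Llama-3.1-405B-Instruct-FP8",  # served model name is the model directory
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.choices[0].message.content)

The server definition follows.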
import modal

vllm_image = modal.Image.debian_slim(python_version="3.10").pip_install(
    [
        "vllm==0.5.3post1",  # LLM serving
        "huggingface_hub==0.24.1",  # download models from the Hugging Face Hub
        "hf-transfer==0.1.8",  # download models faster
    ]
)

MODEL_NAME = "meta-llama/Meta-Llama-3.1-405B-Instruct-FP8"
MODEL_REVISION = "d8e5bf570eac69f7dfc596cfaaebe6acbf95ca2e"
MODEL_DIR = f"/models/{MODEL_NAME}"

MINUTES = 60  # seconds
HOURS = 60 * MINUTES
GIGABYTES = 1024  # megabytes

app = modal.App("vllm-openai-compatible-405b")

N_GPU = 8
TOKEN = (
    "super-secret-token"  # auth token. for production use, replace with a modal.Secret
)

volume = modal.Volume.from_name("llama3-405b-fp8", create_if_missing=True)


@app.function(
    image=vllm_image,
    gpu=modal.gpu.A100(count=N_GPU, size="80GB"),
    memory=336 * GIGABYTES,  # max
    container_idle_timeout=20 * MINUTES,
    timeout=1 * HOURS,
    allow_concurrent_inputs=100,
    volumes={MODEL_DIR: volume},
)
@modal.asgi_app()
def serve():
    import asyncio

    import fastapi
    import vllm.entrypoints.openai.api_server as api_server
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine
    from vllm.entrypoints.logger import RequestLogger
    from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
    from vllm.entrypoints.openai.serving_completion import (
        OpenAIServingCompletion,
    )
    from vllm.usage.usage_lib import UsageContext

    volume.reload()  # ensure the latest model weights are visible

    # create a fastAPI app that uses vLLM's OpenAI-compatible router
    app = fastapi.FastAPI(
        title=f"OpenAI-compatible {MODEL_NAME} server",
        description="Run an OpenAI-compatible LLM server with vLLM on modal.com",
        version="0.0.1",
        docs_url="/docs",
    )

    # security: require a bearer token on authed routes
    http_bearer = fastapi.security.HTTPBearer(
        scheme_name="Bearer Token", description="See code for authentication details."
    )

    # security: CORS middleware for external requests
    app.add_middleware(
        fastapi.middleware.cors.CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # security: inject dependency on authed routes
    async def is_authenticated(api_key: str = fastapi.Security(http_bearer)):
        if api_key.credentials != TOKEN:
            raise fastapi.HTTPException(
                status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
                detail="Invalid authentication credentials",
            )
        return {"username": "authenticated_user"}

    router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)])

    # wrap vLLM's OpenAI-compatible router in the authed router and mount it
    router.include_router(api_server.router)
    app.include_router(router)

    engine_args = AsyncEngineArgs(
        model=MODEL_DIR,
        tensor_parallel_size=N_GPU,
        gpu_memory_utilization=0.90,
        max_model_len=1024 + 128,
        enforce_eager=True,
    )

    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.OPENAI_API_SERVER
    )

    try:  # copied from vLLM -- https://github.com/vllm-project/vllm/blob/507ef787d85dec24490069ffceacbd6b161f4f72/vllm/entrypoints/openai/api_server.py#L235C1-L247C1
        event_loop = asyncio.get_running_loop()
    except RuntimeError:
        event_loop = None

    if event_loop is not None and event_loop.is_running():
        # if the current process is instanced by Ray Serve,
        # there is already a running event loop
        model_config = event_loop.run_until_complete(engine.get_model_config())
    else:
        # when using a single vLLM instance without engine_use_ray
        model_config = asyncio.run(engine.get_model_config())

    request_logger = RequestLogger(max_log_len=2048)

    api_server.openai_serving_chat = OpenAIServingChat(
        engine,
        model_config=model_config,
        served_model_names=[MODEL_DIR],
        chat_template=None,
        response_role="assistant",
        lora_modules=[],
        prompt_adapters=[],
        request_logger=request_logger,
    )
    api_server.openai_serving_completion = OpenAIServingCompletion(
        engine,
        model_config=model_config,
        served_model_names=[MODEL_DIR],
        lora_modules=[],
        prompt_adapters=[],
        request_logger=request_logger,
    )

    return app
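Every route is wrapped in the bearer-token dependency above, so requests without the token receive a 401. A quick way to confirm a deployment is up is to hit the /v1/models route directly; the snippet below is a sketch assuming the requests package is installed and, again, placeholder workspace and token values.

# sketch: check the deployment and auth with a plain HTTP request (assumes `requests` is installed)
import requests

url = "https://YOUR-WORKSPACE--vllm-openai-compatible-405b-serve.modal.run/v1/models"
resp = requests.get(url, headers={"Authorization": "Bearer super-secret-token"})
resp.raise_for_status()
print(resp.json())  # lists the served model, named after MODEL_DIR

The full interactive client script follows.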
"""This simple script shows how to interact with an OpenAI-compatible server from a client.""" | |
import argparse | |
import modal | |
from openai import OpenAI | |
class Colors: | |
"""ANSI color codes""" | |
GREEN = "\033[0;32m" | |
RED = "\033[0;31m" | |
BLUE = "\033[0;34m" | |
GRAY = "\033[0;90m" | |
BOLD = "\033[1m" | |
END = "\033[0m" | |
def get_completion(client, model_id, messages, args): | |
completion_args = { | |
"model": model_id, | |
"messages": messages, | |
"frequency_penalty": args.frequency_penalty, | |
"max_tokens": args.max_tokens, | |
"n": args.n, | |
"presence_penalty": args.presence_penalty, | |
"seed": args.seed, | |
"stop": args.stop, | |
"stream": args.stream, | |
"temperature": args.temperature, | |
"top_p": args.top_p, | |
} | |
completion_args = {k: v for k, v in completion_args.items() if v is not None} | |
try: | |
response = client.chat.completions.create(**completion_args) | |
return response | |
except Exception as e: | |
print(Colors.RED, f"Error during API call: {e}", Colors.END, sep="") | |
return None | |
def main(): | |
parser = argparse.ArgumentParser(description="OpenAI Client CLI") | |
parser.add_argument( | |
"--model", | |
type=str, | |
default=None, | |
help="The model to use for completion, defaults to the first available model", | |
) | |
parser.add_argument( | |
"--api-key", | |
type=str, | |
default="super-secret-token", | |
help="The API key to use for authentication, set in your api.py", | |
) | |
# Completion parameters | |
parser.add_argument("--max-tokens", type=int, default=None) | |
parser.add_argument("--temperature", type=float, default=0.7) | |
parser.add_argument("--top-p", type=float, default=0.9) | |
parser.add_argument("--top-k", type=int, default=0) | |
parser.add_argument("--frequency-penalty", type=float, default=0) | |
parser.add_argument("--presence-penalty", type=float, default=0) | |
parser.add_argument( | |
"--n", | |
type=int, | |
default=1, | |
help="Number of completions to generate. Streaming and chat mode only support n=1.", | |
) | |
parser.add_argument("--stop", type=str, default=None) | |
parser.add_argument("--seed", type=int, default=None) | |
# Prompting | |
parser.add_argument( | |
"--prompt", | |
type=str, | |
default="Compose a limerick about baboons and racoons.", | |
help="The user prompt for the chat completion", | |
) | |
parser.add_argument( | |
"--system-prompt", | |
type=str, | |
default="You are a poetic assistant, skilled in writing satirical doggerel with creative flair.", | |
help="The system prompt for the chat completion", | |
) | |
# UI options | |
parser.add_argument( | |
"--no-stream", | |
dest="stream", | |
action="store_false", | |
help="Disable streaming of response chunks", | |
) | |
parser.add_argument( | |
"--chat", action="store_true", help="Enable interactive chat mode" | |
) | |
args = parser.parse_args() | |
client = OpenAI(api_key=args.api_key) | |
WORKSPACE = modal.config._profile | |
client.base_url = ( | |
f"https://{WORKSPACE}--vllm-openai-compatible-405b-serve.modal.run/v1" | |
) | |
if args.model: | |
model_id = args.model | |
print( | |
Colors.BOLD, | |
f"🧠: Using model {model_id}. This may trigger a boot on first call!", | |
Colors.END, | |
sep="", | |
) | |
else: | |
print( | |
Colors.BOLD, | |
f"🔎: Looking up available models on server at {client.base_url}. This may trigger a boot!", | |
Colors.END, | |
sep="", | |
) | |
model = client.models.list().data[0] | |
model_id = model.id | |
print( | |
Colors.BOLD, | |
f"🧠: Using {model_id}", | |
Colors.END, | |
sep="", | |
) | |
messages = [ | |
{ | |
"role": "system", | |
"content": args.system_prompt, | |
} | |
] | |
print(Colors.BOLD + "🧠: Using system prompt: " + args.system_prompt + Colors.END) | |
if args.chat: | |
print( | |
Colors.GREEN | |
+ Colors.BOLD | |
+ "\nEntering chat mode. Type 'bye' to end the conversation." | |
+ Colors.END | |
) | |
while True: | |
user_input = input("\nYou: ") | |
if user_input.lower() in ["bye"]: | |
break | |
MAX_HISTORY = 10 | |
if len(messages) > MAX_HISTORY: | |
messages = messages[:1] + messages[-MAX_HISTORY + 1 :] | |
messages.append({"role": "user", "content": user_input}) | |
response = get_completion(client, model_id, messages, args) | |
if response: | |
if args.stream: | |
# only stream assuming n=1 | |
print(Colors.BLUE + "\n🤖: ", end="") | |
assistant_message = "" | |
for chunk in response: | |
if chunk.choices[0].delta.content: | |
content = chunk.choices[0].delta.content | |
print(content, end="") | |
assistant_message += content | |
print(Colors.END) | |
else: | |
assistant_message = response.choices[0].message.content | |
print( | |
Colors.BLUE + "\n🤖:" + assistant_message + Colors.END, | |
sep="", | |
) | |
messages.append({"role": "assistant", "content": assistant_message}) | |
else: | |
messages.append({"role": "user", "content": args.prompt}) | |
print(Colors.GREEN + f"\nYou: {args.prompt}" + Colors.END) | |
response = get_completion(client, model_id, messages, args) | |
if response: | |
if args.stream: | |
print(Colors.BLUE + "\n🤖:", end="") | |
for chunk in response: | |
if chunk.choices[0].delta.content: | |
print(chunk.choices[0].delta.content, end="") | |
print(Colors.END) | |
else: | |
# only case where multiple completions are returned | |
for i, response in enumerate(response.choices): | |
print( | |
Colors.BLUE | |
+ f"\n🤖 Choice {i+1}:{response.message.content}" | |
+ Colors.END, | |
sep="", | |
) | |
if __name__ == "__main__": | |
main() |
import modal

MODEL_NAME = "meta-llama/Meta-Llama-3.1-405B-Instruct-FP8"
MODEL_REVISION = "a8f01524ffd5c05a7de914a51fae0b5afe738d3b"
MODEL_DIR = f"/models/{MODEL_NAME}"

volume = modal.Volume.from_name("llama3-405b-fp8", create_if_missing=True)

image = (
    modal.Image.debian_slim(python_version="3.10")
    .pip_install(
        [
            "vllm==0.5.3post1",  # LLM serving
            "huggingface_hub",  # download models from the Hugging Face Hub
            "hf-transfer",  # download models faster
        ]
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
)

MINUTES = 60
HOURS = 60 * MINUTES

app = modal.App(image=image, secrets=[modal.Secret.from_name("huggingface")])


# should take about 30 minutes
@app.function(volumes={MODEL_DIR: volume}, timeout=4 * HOURS)
def download_model(model_dir, model_name, model_revision):
    import os

    from huggingface_hub import snapshot_download

    volume.reload()

    os.makedirs(model_dir, exist_ok=True)
    snapshot_download(
        model_name,
        local_dir=model_dir,
        ignore_patterns=["*.pt", "*.bin", "*.pth", "original/*"],  # Ensure safetensors
        revision=model_revision,
    )

    volume.commit()


@app.local_entrypoint()
def main():
    download_model.remote(MODEL_DIR, MODEL_NAME, MODEL_REVISION)
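The download runs remotely via download_model.remote(...) and commits the snapshot to the shared llama3-405b-fp8 Volume, which the serving app mounts at MODEL_DIR. If you want to confirm the weights landed, a small extra function in the same app can list the directory; this is an optional sketch that follows the same pattern as download_model, not part of the original script.

# optional sketch: verify the snapshot by listing the model directory on the Volume
@app.function(volumes={MODEL_DIR: volume})
def check_download(model_dir: str = MODEL_DIR):
    import os

    volume.reload()
    for name in sorted(os.listdir(model_dir)):
        print(name)  # expect config.json, tokenizer files, and *.safetensors shards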