@grahama1970 · Created November 21, 2024
# arango_utils.py
import asyncio
from loguru import logger

from verifaix.arangodb_helper.arango_client import connect_to_arango_client

async def truncate_cache_collection(arango_config, db=None):
    logger.info(f"Attempting to truncate cache collection '{arango_config['cache_collection_name']}'")

    if db is None:
        logger.info(f"Connecting to ArangoDB at {arango_config['host']}")
        db = await asyncio.to_thread(connect_to_arango_client, arango_config)

    collection_name = arango_config['cache_collection_name']

    # Check if the collection exists before attempting to truncate
    if db.has_collection(collection_name):
        collection = db.collection(collection_name)
        await asyncio.to_thread(collection.truncate)
        logger.info(f"Truncated cache collection '{collection_name}'")
    else:
        logger.warning(f"Collection '{collection_name}' does not exist. Skipping truncation.")
# config.py
import os
from dotenv import load_dotenv

# Load environment variables from the .env file
load_dotenv('../.env')

HF_TOKEN = os.getenv("HF_TOKEN")
RUNPOD_API_KEY = os.getenv("RUNPOD_API_KEY")
if not RUNPOD_API_KEY:
    raise ValueError("RUNPOD_API_KEY is missing. Please set it in your environment variables.")

class ModelCapabilities:
    TEXT = "text"
    TEXT_AND_IMAGE = "text_and_image"

# Define model configurations
MODEL_CONFIGS = {
    # 0.5B Qwen model (the key/name still says 1.5B, but docker_args serve Qwen/Qwen2.5-0.5B)
    "SGLang-Qwen/Qwen2.5-1.5B-Instruct": {
        "name": "SGLang-Qwen/Qwen2.5-1.5B-Instruct",
        "image_name": "lmsysorg/sglang:latest",
        "docker_args": (
            "python3 -m sglang.launch_server "
            "--model-path Qwen/Qwen2.5-0.5B "
            "--mem-fraction-static 0.95 "
            "--host 0.0.0.0 "
            "--port 8000"
        ),
        "cloud_type": "SECURE",
        "volume_in_gb": 5,
        "ports": "8000/http",
        "container_disk_in_gb": 10,
        "volume_mount_path": "/root/.cache/huggingface",
        "env": {"HF_TOKEN": HF_TOKEN, "HF_HUB_ENABLE_HF_TRANSFER": "1"},
        "preferred_gpu_names": ["RTX 4090", "RTX 4080", "RTX 6000 Ada", "RTX A6000"],
    },
    # 32B Qwen Coder Instruct
    "SGLang-Qwen/Qwen2.5-Coder-32B-Instruct": {
        "name": "SGLang-Qwen/Qwen2.5-Coder-32B-Instruct",
        "image_name": "lmsysorg/sglang:latest",
        "docker_args": (
            "python3 -m sglang.launch_server "
            "--model-path Qwen/Qwen2.5-Coder-32B-Instruct "
            "--mem-fraction-static 0.95 "
            "--host 0.0.0.0 "
            "--port 8000"
        ),
        "cloud_type": "SECURE",
        "volume_in_gb": 100,
        "ports": "8000/http",
        "container_disk_in_gb": 50,
        "volume_mount_path": "/root/.cache/huggingface",
        "env": {"HF_TOKEN": HF_TOKEN, "HF_HUB_ENABLE_HF_TRANSFER": "1"},
        "preferred_gpu_names": ["H100 PCIe", "H100 NVL", "H100 SXM", "RTX A6000"],
    },
}

# Default settings for pods
DEFAULT_POD_SETTINGS = {
    "image_name": "lmsysorg/sglang:latest",
    "cloud_type": "SECURE",
    "ports": "8000/http",
    "container_disk_in_gb": 10,
    "volume_in_gb": 100,
    "volume_mount_path": "/root/.cache/huggingface",
    "env": {
        "HF_HUB_ENABLE_HF_TRANSFER": "1",
        "HF_TOKEN": HF_TOKEN
    },
    "scale_cooldown": 180,
    "metrics_window": 60,
    "monitor_interval": 15
}

# Configuration for the pipeline
pipeline_config = {
    "arango_config": {
        "host": "http://localhost:8529",
        "username": "root",
        "password": "openSesame",
        "db_name": "verifaix",
        "collection_name": "test_documents",
        "cache_collection_name": "litellm_cache",  # Store LiteLLM responses
        "truncate_cache": True  # Truncate the cache collection before starting
    },
    "llm_config": {
        "model": "openai/Qwen/Qwen2.5-0.5B",
        "max_tokens": 50,
        "temperature": 0.7,
        "api_base": None  # Set dynamically once the pod is running
    }
}
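# A minimal sanity check of the merge behaviour pipeline.py relies on:
# dict.setdefault keeps per-model values and fills gaps from DEFAULT_POD_SETTINGS.
if __name__ == "__main__":
    merged = dict(MODEL_CONFIGS["SGLang-Qwen/Qwen2.5-1.5B-Instruct"])
    for key, value in DEFAULT_POD_SETTINGS.items():
        merged.setdefault(key, value)
    assert merged["volume_in_gb"] == 5        # per-model value wins
    assert merged["scale_cooldown"] == 180    # filled in from the defaults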
# llm_utils.py
import asyncio
from loguru import logger

from verifaix.arangodb_helper.arango_client import generate_safe_key
from verifaix.llm_client.get_litellm_response import get_litellm_response

def add_hashes_to_requests(requests):
    """
    Add hashes to each conversation in the list of requests.

    Args:
        requests (list): A list of conversations (each a list of messages).

    Returns:
        list: A list of dictionaries, each with '_hash' and 'messages' keys.
    """
    def compute_hash(messages):
        # Generate a hash using the content of the messages.
        raw_key = " ".join(msg.get("content", "") for msg in messages)
        return generate_safe_key(raw_key)

    results = [
        {
            "_hash": compute_hash(messages),
            "messages": messages
        }
        for messages in requests
    ]
    return results
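# Illustrative call (generate_safe_key lives in the verifaix package, so the
# exact hash format is an assumption -- shown here as an opaque string):
#
#     reqs = add_hashes_to_requests([[{"role": "user", "content": "hi"}]])
#     # -> [{"_hash": "<safe-key-derived-from-content>", "messages": [...]}]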
async def make_request(llm_params, request, db=None):
    """
    Make a request to the LLM API with the given parameters and messages.
    Passes the hash directly into the LLM request.

    Args:
        llm_params (dict): LLM configuration parameters.
        request (dict): A dictionary containing '_hash' and 'messages'.
        db (object, optional): Database connection for additional context.

    Returns:
        dict: The response from the LLM API, including the hash.
    """
    try:
        # Extract hash and messages from the request
        request_hash = request["_hash"]
        messages = request["messages"]

        # Pass the hash into the LLM request
        response = await get_litellm_response(messages, llm_params, request_id=request_hash, db=db)
        return response
    except Exception as e:
        logger.exception(f"Failed to process request with hash {request['_hash']}: {e}")
        return {"_hidden_params": {"request_id": request["_hash"]}, "error": str(e)}

def merge_responses_with_requests(requests, responses):
    """
    Merge the response objects back into the original requests list using hashes.

    Args:
        requests (list): The original list of requests with '_hash' fields.
        responses (list): The list of response objects containing '_hidden_params' with hashes.

    Returns:
        list: The enriched list of requests, each with its corresponding response.
    """
    # Map request_id (hash) as key to response as value
    response_map = {
        response["_hidden_params"]["request_id"]: response
        for response in responses
    }

    # Merge the responses back into the original requests
    for request in requests:
        request_hash = request["_hash"]
        if request_hash in response_map:
            request["response"] = response_map[request_hash]
    return requests
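# A minimal round-trip sketch with a stubbed response (no network, no database;
# the response shape mirrors the '_hidden_params' convention used above):
#
#     reqs = add_hashes_to_requests([[{"role": "user", "content": "hi"}]])
#     fake = [{"_hidden_params": {"request_id": reqs[0]["_hash"]},
#              "choices": [{"message": {"content": "hello"}}]}]
#     merged = merge_responses_with_requests(reqs, fake)
#     assert merged[0]["response"]["choices"][0]["message"]["content"] == "hello"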
# pipeline.py
import asyncio
from datetime import datetime, timezone
import uuid

import jsonpickle
from loguru import logger
import runpod

from verifaix.arangodb_helper.arango_client import connect_to_arango_client, upsert_document
from config import pipeline_config, MODEL_CONFIGS, DEFAULT_POD_SETTINGS
from verifaix.runpod.utils.arango_utils import truncate_cache_collection
from verifaix.runpod.utils.runpod_ops import (
    start_runpod_container,
    wait_for_pod_to_run,
    stop_runpod_container
)
from verifaix.runpod.utils.llm_utils import (
    add_hashes_to_requests,
    make_request,
    merge_responses_with_requests
)

async def main():
    """
    Main function to start the container, wait for it to be ready,
    send queries, and stop the container.
    """
    # Step 1: Initialize Model Configuration
    logger.info("Step 1: Initialize Model Configuration")
    model_name = "SGLang-Qwen/Qwen2.5-1.5B-Instruct"
    model_config = MODEL_CONFIGS[model_name]

    # Merge default pod settings (per-model keys take precedence)
    for key, value in DEFAULT_POD_SETTINGS.items():
        model_config.setdefault(key, value)

    try:
        # Step 2: Start or Reuse RunPod Container
        logger.info("Step 2: Start or Reuse RunPod Container")

        # Check for existing pods
        existing_pods = runpod.get_pods()
        matching_pod = next(
            (pod for pod in existing_pods if pod["name"] == model_config["name"]),
            None
        )

        if matching_pod:
            if matching_pod["desiredStatus"] == "RUNNING":
                logger.info(f"Using existing running pod: {matching_pod['id']}")
                pod = matching_pod
                api_base = f"https://{pod['id']}-8000.proxy.runpod.net/v1"
            elif matching_pod["desiredStatus"] in ["EXITED", "STOPPED"]:
                logger.info(f"Cleaning up existing exited pod: {matching_pod['id']}")
                runpod.terminate_pod(matching_pod["id"])
                logger.info(f"Terminated pod: {matching_pod['id']}")
                logger.info("Recreating a new pod...")
                pod = await start_runpod_container(model_config)
                pod, api_base = await wait_for_pod_to_run(pod)
            else:
                logger.warning(f"Pod found with unexpected status: {matching_pod['desiredStatus']}. Starting a new pod...")
                pod = await start_runpod_container(model_config)
                pod, api_base = await wait_for_pod_to_run(pod)
        else:
            logger.info("No existing pod found. Starting a new one...")
            pod = await start_runpod_container(model_config)
            pod, api_base = await wait_for_pod_to_run(pod)

        # Step 3: Prepare Requests
        logger.info("Step 3: Prepare Requests")
        requests = [
            [
                {"role": "system", "content": "You are a knowledgeable historian who provides concise responses."},
                {"role": "user", "content": "Tell me about ancient Rome"},
                {"role": "assistant", "content": "Ancient Rome was a civilization centered in Italy."},
                {"role": "user", "content": "What were their major achievements?"}
            ],
            [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "What is the capital of France?"}
            ],
            [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "What is the most common color of an apple?"}
            ]
        ]
        requests_with_hashes = add_hashes_to_requests(requests)

        # Step 4: Connect to Database
        logger.info("Step 4: Connect to Database")
        arango_config = pipeline_config['arango_config']
        db = await asyncio.to_thread(connect_to_arango_client, arango_config)

        # Truncate the litellm_cache collection (if configured)
        if arango_config.get('truncate_cache', False):
            await truncate_cache_collection(arango_config, db)

        # Step 5: Prepare LLM Parameters
        logger.info("Step 5: Prepare LLM Parameters")
        llm_params = pipeline_config['llm_config']
        llm_params['api_base'] = api_base  # Point LiteLLM at the running pod

        # Step 6: Make Requests to LLM
        logger.info("Step 6: Make Requests to LLM")
        tasks = [make_request(llm_params, request, db=db) for request in requests_with_hashes]
        responses = await asyncio.gather(*tasks)

        # Step 7: Merge Responses with Requests
        logger.info("Step 7: Merge Responses with Requests")
        connected_requests = merge_responses_with_requests(requests_with_hashes, responses)

        # Step 8: Process and Store Results
        logger.info("Step 8: Process and Store Results")
        collection_name = arango_config.get('collection_name', 'default_collection')
        for request in connected_requests:
            try:
                response = request.get("response")
                completion = response["choices"][0]["message"]["content"] if response else None
                response_object = jsonpickle.encode(response) if response else None

                # Log the request for debugging
                logger.info(f"Request ID: {request['_hash']}")
                logger.info(f"Messages Object: {request['messages']}")
                logger.info(f"Completion: {completion}")

                # Prepare the document
                document = {
                    "_key": str(uuid.uuid4()),
                    "request_id": request["_hash"],  # Store the hash for traceability
                    "messages": request["messages"],
                    "completion": completion,  # Pulled out of the response object
                    "response_object": response_object,
                    "_last_updated": datetime.now(timezone.utc).isoformat()
                }

                # Upsert the document using the generated `_key`
                await asyncio.to_thread(upsert_document, db, collection_name, document)
                logger.info(f"Upserted document with request_id: {request['_hash']} into collection '{collection_name}'")
            except Exception as e:
                logger.error(f"Failed to process request with request_id {request['_hash']}: {e}")

    # Step 9: Clean Up
    finally:
        logger.info("Step 9: Clean Up")
        if 'pod' in locals():
            try:
                logger.info(f"Stopping the runpod container: {pod['id']}")
                await stop_runpod_container(pod["id"])
            except Exception as cleanup_error:
                logger.error(f"Failed to stop the container: {cleanup_error}")

if __name__ == "__main__":
    asyncio.run(main())
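# To run the pipeline end-to-end (assumes RUNPOD_API_KEY and HF_TOKEN are set in
# the .env file loaded by config.py, and ArangoDB is reachable at localhost:8529):
#
#     python pipeline.py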
# runpod_ops.py
import asyncio
from datetime import datetime, timezone

import httpx
import runpod
from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type
from loguru import logger
from yaspin import yaspin
from yaspin.spinners import Spinners

from config import RUNPOD_API_KEY

# Set RunPod API key
runpod.api_key = RUNPOD_API_KEY

async def start_runpod_container(model_config):
    """
    Start a RunPod container with retry logic, prioritizing a list of preferred GPUs.
    Logs the time taken and handles pod recreation if needed.
    """
    available_gpus = runpod.get_gpus()  # Retrieve all available GPUs
    if not available_gpus:
        raise RuntimeError("No available GPUs found for pod creation.")

    # Prioritize GPUs in the order given by `preferred_gpu_names`
    preferred_gpus = sorted(
        [gpu for gpu in available_gpus if gpu["displayName"] in model_config["preferred_gpu_names"]],
        key=lambda gpu: model_config["preferred_gpu_names"].index(gpu["displayName"])
    )
    fallback_gpus = [gpu for gpu in available_gpus if gpu not in preferred_gpus]
    if not preferred_gpus and not fallback_gpus:
        raise RuntimeError("No suitable GPUs available for pod creation.")

    gpus_to_try = preferred_gpus + fallback_gpus
    for gpu in gpus_to_try:
        try:
            # Filter model_config down to the keys create_pod accepts
            # NOTE: This is a workaround to avoid errors from invalid keys in the pod_config
            valid_keys = [
                'name', 'image_name', 'gpu_type_id', 'cloud_type', 'support_public_ip',
                'start_ssh', 'data_center_id', 'country_code', 'gpu_count', 'volume_in_gb',
                'container_disk_in_gb', 'min_vcpu_count', 'min_memory_in_gb', 'docker_args',
                'ports', 'volume_mount_path', 'env', 'template_id', 'network_volume_id',
                'allowed_cuda_versions', 'min_download', 'min_upload'
            ]
            pod_config = {key: value for key, value in model_config.items() if key in valid_keys}
            pod_config["gpu_type_id"] = gpu["id"]  # Use the GPU type ID

            logger.info(f"Attempting to start pod with GPU: {gpu['displayName']} (ID: {gpu['id']})")
            start_time = datetime.now(timezone.utc)

            pod = runpod.create_pod(**pod_config)
            logger.info(f"Successfully started pod with GPU: {gpu['displayName']} (ID: {gpu['id']}). Pod ID: {pod['id']}")

            end_time = datetime.now(timezone.utc)
            startup_duration = (end_time - start_time).total_seconds()
            logger.info(f"Container startup took {startup_duration:.2f} seconds.")
            return pod
        except Exception as e:
            logger.warning(f"Failed to start pod with GPU {gpu['displayName']} (ID: {gpu['id']}): {e}")

    logger.error("Failed to start a pod with any preferred or fallback GPU types.")
    raise RuntimeError("No GPUs could be used to start the RunPod container.")
async def wait_for_pod_to_run(pod):
    """
    Wait until the container is fully ready by checking its status and API readiness.
    Logs if the pod was recreated due to cleanup or error handling.
    """
    max_wait_time = 900  # 15 minutes
    start_time = datetime.now(timezone.utc)
    spinner = yaspin(Spinners.dots, text="Waiting for pod to initialize...")
    spinner.start()
    try:
        while True:
            elapsed_time = (datetime.now(timezone.utc) - start_time).total_seconds()
            if elapsed_time > max_wait_time:
                raise TimeoutError("Pod startup timed out.")

            pod = runpod.get_pod(pod["id"])
            logger.info(f"Pod details: {pod}")
            if pod.get("desiredStatus") == "RUNNING":
                logger.info("Pod has reached 'RUNNING' status.")
                break
            await asyncio.sleep(10)

        # Construct the API base using the pod ID
        api_base = f"https://{pod['id']}-8000.proxy.runpod.net/v1"
        logger.info(f"API base URL: {api_base}")

        # Perform readiness check
        if await check_api_readiness(api_base):
            return pod, api_base
        logger.warning("Pod failed its readiness check.")
        raise RuntimeError("API readiness check failed.")
    except Exception:
        logger.exception("Error while waiting for the pod to be ready.")
        raise
    finally:
        spinner.stop()
@retry(
    stop=stop_after_attempt(30),               # Retry up to 30 times
    wait=wait_fixed(10),                       # Wait 10 seconds between retries
    retry=retry_if_exception_type(Exception),  # Retry only on exceptions
)
async def check_api_readiness(api_base):
    """
    Check if the API is ready by pinging its endpoints.
    Retries until a response with status code 200 is received or retries are exhausted.

    Args:
        api_base (str): The base URL of the API.

    Returns:
        bool: True if the API is ready; raises an exception otherwise.
    """
    endpoints = ["/models", "/health"]
    async with httpx.AsyncClient() as client:
        for endpoint in endpoints:
            response = await client.get(f"{api_base}{endpoint}", timeout=10)
            if response.status_code == 200:
                logger.info(f"API readiness confirmed at endpoint: {endpoint}.")
                return True
    raise RuntimeError("API readiness check failed.")
@retry(
    stop=stop_after_attempt(30),               # Retry up to 30 times
    wait=wait_fixed(5),                        # Wait 5 seconds between retries
    retry=retry_if_exception_type(Exception),  # Retry on any exception
)
async def stop_runpod_container(pod_id, terminate_flag=False):
    """
    Stop the RunPod container and confirm it is fully stopped.
    If stopping fails and terminate_flag is True, attempt to terminate the container.

    Args:
        pod_id (str): The ID of the pod to stop.
        terminate_flag (bool): If True, attempt to terminate the pod if stopping fails.

    Raises:
        RuntimeError: If the pod fails to stop or terminate successfully.
    """
    logger.info(f"Initiating shutdown for pod with ID: {pod_id}")
    start_time = datetime.now(timezone.utc)

    # Attempt to stop the pod
    try:
        response = runpod.stop_pod(pod_id)
        logger.debug(f"RunPod stop response: {response}")

        # Ensure the response matches the expected pod ID
        if response.get("id") != pod_id:
            logger.warning(f"Unexpected pod ID in stop response: {response.get('id')}")

        # Check if the pod has stopped
        pod_details = runpod.get_pod(pod_id)
        current_status = pod_details.get("desiredStatus")
        logger.debug(f"Pod status is '{current_status}'.")
        if current_status not in ["EXITED", "STOPPED"]:
            raise RuntimeError(f"Pod with ID {pod_id} is still in state '{current_status}'. Retrying...")

        end_time = datetime.now(timezone.utc)
        shutdown_duration = (end_time - start_time).total_seconds()
        logger.info(f"Stopped pod with ID: {pod_id}. Container shutdown took {shutdown_duration:.2f} seconds.")
        return  # Exit the function if the pod is successfully stopped
    except Exception as stop_exception:
        logger.warning(f"Failed to stop pod {pod_id}: {stop_exception}")

        # Check if termination is enabled
        if terminate_flag:
            logger.info(f"Terminate flag is set. Attempting to terminate pod {pod_id}...")
            try:
                terminate_response = runpod.terminate_pod(pod_id)
                logger.info(f"RunPod terminate response: {terminate_response}")

                # Verify termination status
                terminated_pod = runpod.get_pod(pod_id)
                if terminated_pod.get("desiredStatus") == "TERMINATED":
                    logger.info(f"Pod with ID {pod_id} has been successfully terminated.")
                    return
                raise RuntimeError(f"Pod with ID {pod_id} did not terminate successfully.")
            except Exception as terminate_exception:
                logger.exception(f"Failed to terminate pod {pod_id}. Termination error: {terminate_exception}")
                raise RuntimeError(f"Failed to stop or terminate pod {pod_id}: {stop_exception}, {terminate_exception}")
        else:
            logger.error(f"Terminate flag is not set. Pod {pod_id} remains unresolved; the retry decorator will re-attempt the stop.")
            raise RuntimeError(f"Failed to stop pod {pod_id}, and termination was not attempted.")