Ray Data LLM

High-performance batch inference for large language models, powered by Ray Data.

Overview

Ray Data LLM provides an efficient, scalable solution for batch processing LLM inference workloads with:

  • High Throughput: Optimized performance using vLLM's paged attention and continuous batching
  • Distributed Processing: Scale across multiple GPUs and machines using Ray Data
  • Smart Batching: Automatic request batching and scheduling for optimal throughput
  • Prefix Caching: Memory-efficient processing of prompts with shared prefixes
  • Flexible Workloads: Support for raw prompts, chat completions, and multimodal inputs
  • LoRA Support: Dynamic loading of LoRA adapters from HuggingFace, local paths, or S3

Quick Start

import ray
from ray.data.llm import (
    VLLMBatchInferencer,
    LLMConfig,
    SamplingParams
)

# Initialize processor with model/infrastructure config
processor = VLLMBatchInferencer(
    model_config=LLMConfig(
        model_id="meta-llama/Llama-2-70b-chat-hf",
        num_workers=4,
        tensor_parallel_size=2,
        gpu_memory_utilization=0.95,
        enable_prefix_caching=True
    )
)

# Process dataset
ds = ray.data.read_parquet("s3://my-bucket/questions.parquet")
ds = processor.transform(
    ds,
    input_column="question",
    output_column="answer",
    sampling_params=SamplingParams(
        temperature=0.7,
        max_tokens=512,
        top_p=0.95
    )
)
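
Ray Data pipelines are lazy, so the transform above is typically only executed when a downstream step consumes or writes the dataset. A minimal continuation of the sketch, assuming you want to persist the results (the output path is illustrative):

# Trigger execution and write the generated answers back to storage
ds.write_parquet("s3://my-bucket/answers.parquet")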

Configuration Classes

Model Configuration

from dataclasses import dataclass
from typing import Optional

@dataclass
class LLMConfig:
    """Configuration for LLM model and infrastructure."""
    model_id: str
    num_workers: int = 1
    tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1
    gpu_memory_utilization: float = 0.95
    enable_prefix_caching: bool = False
    enforce_eager: bool = False
    lora_config: Optional[LoRAConfig] = None
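
LoRAConfig is imported in the Type Hints section but not defined in this gist; below is a hedged sketch of what such a dataclass could contain, based on the LoRA bullet in the Overview (all field names are assumptions):

@dataclass
class LoRAConfig:
    """Hypothetical LoRA adapter configuration (field names are assumptions, not part of the gist)."""
    # Adapter source: a HuggingFace repo id, a local path, or an s3:// URI
    adapter_id_or_path: str
    # Maximum number of adapters kept loaded per worker (assumed knob)
    max_loras: int = 1
    # Upper bound on adapter rank accepted at load time (assumed knob)
    max_lora_rank: int = 16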

Sampling Parameters

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class SamplingParams:
    """Parameters for controlling text generation."""
    temperature: float = 1.0
    max_tokens: int = 512
    top_p: float = 1.0
    top_k: Optional[int] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    stop: Optional[List[str]] = None
    ignore_eos: bool = False
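
As a usage example, a near-deterministic configuration for extraction-style prompts might turn sampling off and cap output length (the values below are illustrative, not recommendations):

# Greedy decoding: temperature 0 with a small output budget and an explicit stop sequence
greedy_params = SamplingParams(
    temperature=0.0,
    max_tokens=128,
    stop=["\n\n"]
)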

Advanced Usage

Chat Completions

Process chat-style conversations:

from ray.data.llm import ChatConfig

ds = processor.transform(
    ds,
    input_column="query",
    output_column="response",
    prompt_config=ChatConfig(
        system_prompt="You are a helpful customer service agent.",
        template="Customer: {input}\nAgent:"
    ),
    sampling_params=SamplingParams(
        temperature=0.9,
        max_tokens=1024
    )
)

Shared Prefix Optimization

Optimize throughput for prompts with common prefixes:

from ray.data.llm import SharedPrefixConfig

ds = processor.transform(
    ds,
    input_column="problem",
    output_column="solution",
    prompt_config=SharedPrefixConfig(
        prefix="Solve step by step:\n\n"
    ),
    sampling_params=SamplingParams(max_tokens=512)
)

Multimodal Processing

Process datasets with images (requires compatible models):

from ray.data.llm import MultimodalConfig

ds = processor.transform(
    ds,
    input_column="image_path",
    output_column="description",
    prompt_config=MultimodalConfig(
        template="Describe this image in detail:\n{image}",
        image_size=(512, 512)  # Optional resizing
    ),
    sampling_params=SamplingParams(max_tokens=256)
)

Checkpointing (Anyscale Only)

Enable fault-tolerant processing with checkpointing:

from ray.data.llm import CheckpointConfig

ds = processor.transform(
    ds,
    input_column="article",
    output_column="summary",
    sampling_params=SamplingParams(max_tokens=200),
    checkpoint_config=CheckpointConfig(
        path="s3://my-bucket/checkpoints/cnn-summary",
        checkpoint_frequency=100,
        resume_from_checkpoint=True,
        cleanup_checkpoint=True
    )
)

Performance Tips

GPU Configuration

  • Increase tensor_parallel_size when the model's weights do not fit in a single GPU's memory
  • Set num_workers so that num_workers × tensor_parallel_size matches the number of available GPUs
  • Use pipeline_parallel_size for very large models that still do not fit with tensor parallelism alone
  • Monitor gpu_memory_utilization (default 0.95) and lower it if you hit out-of-memory errors; see the example configurations after this list
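
As a concrete illustration of the sizing rule above, here are two sketches for a single 8-GPU node; the model ids and splits are assumptions, not benchmarked recommendations:

# Small model: one GPU per replica, eight replicas fill the node
small_config = LLMConfig(
    model_id="meta-llama/Llama-2-7b-chat-hf",
    num_workers=8,
    tensor_parallel_size=1
)

# Large model: two GPUs per replica, four replicas (matches the Quick Start example)
large_config = LLMConfig(
    model_id="meta-llama/Llama-2-70b-chat-hf",
    num_workers=4,
    tensor_parallel_size=2,
    enable_prefix_caching=True
)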

Batch Processing

  • Enable prefix_caching for workloads with common prefixes
  • Set enforce_eager=True for more predictable latency
  • Adjust max_tokens based on your use case
  • Use checkpointing for fault tolerance on long-running jobs

Requirements

  • Ray >= 2.6.0
  • vLLM >= 0.2.0
  • Python >= 3.8
  • CUDA >= 11.8

Type Hints

All configuration classes provide complete type hints for better IDE support:

from ray.data.llm import (
    LLMConfig,
    SamplingParams,
    ChatConfig,
    SharedPrefixConfig,
    MultimodalConfig,
    CheckpointConfig,
    LoRAConfig
)

Known Issues and Limitations

  1. Checkpointing support requires Ray Turbo 2.39+ and is Anyscale-only
  2. Image processing may require additional memory, especially at larger image sizes
  3. Some vLLM features may not be available with certain attention backends

License

MIT
