@richardliaw
Last active December 21, 2024 00:31

Ray Data LLM

High-performance batch inference for large language models, powered by Ray Data.

Overview

Ray Data LLM provides efficient, scalable batch inference for LLM workloads, with:

  • High Throughput: Optimized performance using vLLM's paged attention and continuous batching
  • Distributed Processing: Scale across multiple GPUs and machines using Ray Data
  • Smart Batching: Automatic request batching and scheduling for optimal throughput
  • Prefix Caching: Memory-efficient processing of prompts with shared prefixes
  • Flexible Workloads: Support for raw prompts, chat completions, and multimodal inputs
  • LoRA Support: Dynamic loading of LoRA adapters from HuggingFace, local paths, or S3
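The LoRA bullet names three adapter sources. Below is a minimal sketch of how an adapter location might be told apart, assuming S3 locations use `s3://` URIs and that anything else that does not look like a filesystem path is a HuggingFace repo id. `classify_lora_source` is a hypothetical helper for illustration, not part of the API described here.

```python
import os


def classify_lora_source(location: str) -> str:
    """Guess where a LoRA adapter lives (hypothetical helper, not library API)."""
    if location.startswith("s3://"):
        return "s3"
    # Explicit filesystem-looking paths, or anything that already exists on disk.
    if location.startswith(("/", "./", "../", "~")) or os.path.exists(location):
        return "local"
    # Assumption: everything else is a HuggingFace Hub repo id such as "org/name".
    return "huggingface"


sources = [
    classify_lora_source("s3://my-bucket/adapters/sql-lora"),
    classify_lora_source("/mnt/adapters/sql-lora"),
    classify_lora_source("some-org/some-adapter"),
]
```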

Quick Start

import ray
from vllm import SamplingParams
from ray.data.llm import VLLMBatchInferencer

# Initialize processor with model/infrastructure config
processor = VLLMBatchInferencer(
    model="meta-llama/Llama-2-70b-chat-hf",
    vllm_kwargs=dict(
        num_workers=4,
        tensor_parallel_size=2,
        gpu_memory_utilization=0.95,
        enable_prefix_caching=True,
    )
)

# Process dataset
ds = ray.data.read_parquet("s3://my-bucket/questions.parquet")
ds = processor.transform(
    ds,
    input_column="question",
    output_column="answer",
    accelerator_type="L40S",
    sampling_params=SamplingParams(
        temperature=0.7,
        max_tokens=512,
        top_p=0.95
    )
)
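Conceptually, `transform` reads one column, generates text for each value, and writes the result to a new column. The plain-Python sketch below mimics that contract on a list of dict rows. `transform_rows` and `fake_generate` are hypothetical stand-ins; no model, vLLM engine, or Ray cluster is involved, and the real batching logic is likely more involved.

```python
from typing import Callable, Dict, Iterator, List

Row = Dict[str, str]


def batched(rows: List[Row], batch_size: int) -> Iterator[List[Row]]:
    """Yield consecutive slices of at most batch_size rows."""
    for start in range(0, len(rows), batch_size):
        yield rows[start:start + batch_size]


def fake_generate(prompts: List[str]) -> List[str]:
    """Stand-in for the model call: returns one string per prompt."""
    return [f"<answer to: {p}>" for p in prompts]


def transform_rows(
    rows: List[Row],
    input_column: str,
    output_column: str,
    generate: Callable[[List[str]], List[str]] = fake_generate,
    batch_size: int = 2,
) -> List[Row]:
    """Return copies of rows with output_column computed from input_column."""
    out: List[Row] = []
    for batch in batched(rows, batch_size):
        answers = generate([row[input_column] for row in batch])
        for row, answer in zip(batch, answers):
            out.append({**row, output_column: answer})
    return out


rows = [
    {"question": "What is Ray?"},
    {"question": "What is vLLM?"},
    {"question": "What is LoRA?"},
]
result = transform_rows(rows, "question", "answer")
```

The input rows are left untouched and every output row keeps its original columns, which is the behavior the `input_column`/`output_column` pair suggests.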

Configuration Classes

Model Configuration

from dataclasses import dataclass, field
from typing import Any, Dict, Optional

# TODO: Are there vLLM engine args we should not expose?
@dataclass
class vLLMEngineConfig:
    """Configuration for vLLM engine."""
    model: str
    lora_config: Optional[LoRAConfig] = None
    vllm_kwargs: Dict[str, Any] = field(default_factory=dict)
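A short, self-contained sketch of constructing this config. The class body mirrors the definition above; `LoRAConfig` is replaced by a hypothetical one-field stand-in because its real fields are not given here.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class LoRAConfig:
    """Hypothetical stand-in; the real fields are not specified in this document."""
    path: str


@dataclass
class vLLMEngineConfig:
    """Same shape as the class above."""
    model: str
    lora_config: Optional[LoRAConfig] = None
    vllm_kwargs: Dict[str, Any] = field(default_factory=dict)


a = vLLMEngineConfig(model="meta-llama/Llama-2-70b-chat-hf")
b = vLLMEngineConfig(
    model="meta-llama/Llama-2-70b-chat-hf",
    vllm_kwargs={"tensor_parallel_size": 2},
)

# default_factory gives each instance its own dict, so mutating one
# config's kwargs does not leak into another.
a.vllm_kwargs["enable_prefix_caching"] = True
```

Using `field(default_factory=dict)` rather than a bare `{}` default matters here: dataclasses reject mutable defaults such as `{}` outright, and a shared dict would let one config silently change another.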

Advanced Usage

Chat Completions

Process chat-style conversations:

from ray.data.llm import ChatConfig

ds = processor.transform(
    ds,
    input_column="query",
    output_column="response",
    prompt_config=ChatConfig(
        system_prompt="You are a helpful customer service agent.",
        template="Customer: {input}\nAgent:"
    ),
    sampling_params=SamplingParams(
        temperature=0.9,
        max_tokens=1024
    )
)
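How `ChatConfig` turns a row into a prompt is not spelled out. One plausible reading, sketched below, is that the system prompt becomes a system message and the template is filled with the input column's value. `render_chat` is a hypothetical helper, and the role/content message layout is an assumption.

```python
from typing import Dict, List


def render_chat(system_prompt: str, template: str, value: str) -> List[Dict[str, str]]:
    """Build role/content messages from one input value (assumed behavior)."""
    return [
        {"role": "system", "content": system_prompt},
        # The template's {input} placeholder is replaced by the column value.
        {"role": "user", "content": template.format(input=value)},
    ]


messages = render_chat(
    "You are a helpful customer service agent.",
    "Customer: {input}\nAgent:",
    "Where is my order?",
)
```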

Shared Prefix Optimization

Optimize throughput for prompts with common prefixes:

from ray.data.llm import SharedPrefixConfig

ds = processor.transform(
    ds,
    input_column="problem",
    output_column="solution",
    prompt_config=SharedPrefixConfig(
        prefix="Solve step by step:\n\n"
    ),
    sampling_params=SamplingParams(max_tokens=512)
)
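To see why a shared prefix helps, the sketch below builds the prompts by hand and measures how much of the total text is identical across rows. It assumes the prefix is prepended verbatim to each input. Prefix caching operates on tokens rather than characters, so the character fraction is only a rough proxy.

```python
import os

prefix = "Solve step by step:\n\n"
problems = ["2 + 2 = ?", "Factor x^2 - 1.", "Integrate sin(x)."]

# Assumed behavior: the prefix is prepended verbatim to every input.
prompts = [prefix + p for p in problems]

# os.path.commonprefix compares character by character and works on any strings.
shared = os.path.commonprefix(prompts)

# Fraction of all prompt characters that belong to the shared prefix.
shared_fraction = len(shared) * len(prompts) / sum(len(p) for p in prompts)
```

The longer the prefix relative to each row's own text, the larger this fraction, and the more prefill work a prefix cache can skip.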

Multimodal Processing

Process datasets with images (requires compatible models):

from ray.data.llm import MultimodalConfig

ds = processor.transform(
    ds,
    input_column="image_path",
    output_column="description",
    prompt_config=MultimodalConfig(
        template="Describe this image in detail:\n{image}",
        image_size=(512, 512)  # Optional resizing
    ),
    sampling_params=SamplingParams(max_tokens=256)
)
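The exact resizing behavior behind `image_size` is not specified. As one possible interpretation, the sketch below computes target dimensions that fit inside the given box while preserving aspect ratio and never upscaling. `fit_within` is a hypothetical helper, and the library may instead resize to the exact size.

```python
from typing import Tuple


def fit_within(width: int, height: int, max_size: Tuple[int, int]) -> Tuple[int, int]:
    """Scale (width, height) down to fit inside max_size, keeping aspect ratio."""
    max_w, max_h = max_size
    # Use the tighter of the two constraints; cap at 1.0 so small images are unchanged.
    scale = min(max_w / width, max_h / height, 1.0)
    return max(1, round(width * scale)), max(1, round(height * scale))


landscape = fit_within(1024, 768, (512, 512))
small = fit_within(300, 200, (512, 512))
```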

Checkpointing (Anyscale Only)

Enable fault-tolerant processing with checkpointing:

from ray.data.llm import CheckpointConfig

ds = processor.transform(
    ds,
    input_column="article",
    output_column="summary",
    sampling_params=SamplingParams(max_tokens=200),
    checkpoint_config=CheckpointConfig(
        path="s3://my-bucket/checkpoints/cnn-summary",
        checkpoint_frequency=100,
        resume_from_checkpoint=True,
        cleanup_checkpoint=True
    )
)
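The four options suggest a save-every-N, resume, then clean-up cycle. The sketch below shows that cycle on a local JSON file with a simulated failure. It is a simplified illustration under that assumption, not the Anyscale implementation: `process_with_checkpoints` is hypothetical, and a real version would write to object storage and write atomically.

```python
import json
import os
import tempfile
from typing import Callable, List


def process_with_checkpoints(
    items: List[str],
    fn: Callable[[str], str],
    path: str,
    checkpoint_frequency: int = 100,
    resume_from_checkpoint: bool = True,
    cleanup_checkpoint: bool = True,
) -> List[str]:
    """Apply fn to items, saving results so far every checkpoint_frequency items."""
    results: List[str] = []
    if resume_from_checkpoint and os.path.exists(path):
        with open(path) as f:
            results = json.load(f)  # outputs already computed by a previous run
    for i in range(len(results), len(items)):
        results.append(fn(items[i]))
        if len(results) % checkpoint_frequency == 0:
            with open(path, "w") as f:
                json.dump(results, f)
    if cleanup_checkpoint and os.path.exists(path):
        os.remove(path)  # the run finished, so the checkpoint is no longer needed
    return results


ckpt = os.path.join(tempfile.mkdtemp(), "progress.json")
articles = [f"article {i}" for i in range(5)]
calls: List[str] = []


def flaky_summarize(text: str) -> str:
    """Fails the first time it sees 'article 3' to simulate a worker crash."""
    calls.append(text)
    if text == "article 3" and calls.count(text) == 1:
        raise RuntimeError("simulated worker failure")
    return text.upper()


try:
    process_with_checkpoints(articles, flaky_summarize, ckpt, checkpoint_frequency=2)
except RuntimeError:
    pass  # the first run dies after a checkpoint covering articles 0 and 1

# The second run resumes from the checkpoint instead of starting over.
summaries = process_with_checkpoints(articles, flaky_summarize, ckpt, checkpoint_frequency=2)
```

Work done after the last checkpoint is repeated on resume (here, article 2), so a smaller `checkpoint_frequency` trades more writes for less recomputation.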

Requirements

  • Ray >= 2.6.0
  • vLLM >= 0.2.0
  • Python >= 3.8
  • CUDA >= 11.8
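A small sketch for checking version strings against these minimums. Comparing versions as integer tuples avoids the string-comparison trap where "2.10.0" sorts before "2.6.0". `parse_version` and `meets_minimum` are hypothetical helpers that handle only simple dotted versions.

```python
import sys
from typing import Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    """Turn '2.6.0' into (2, 6, 0), ignoring suffixes such as 'rc1' or '+cu118'."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    """True if the installed version is at least the minimum."""
    return parse_version(installed) >= parse_version(minimum)


# The installed strings could come from importlib.metadata.version("ray")
# and importlib.metadata.version("vllm"); literals are used here instead.
ray_ok = meets_minimum("2.10.0", "2.6.0")
vllm_ok = meets_minimum("0.1.7", "0.2.0")
python_ok = sys.version_info >= (3, 8)
```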

Type Hints

All configuration classes provide complete type hints for better IDE support:

from ray.data.llm import (
    vLLMEngineConfig,
    SamplingParams,
    ChatConfig,
    SharedPrefixConfig,
    MultimodalConfig,
    CheckpointConfig,
    LoRAConfig
)

License

MIT
