High-performance batch inference for large language models, powered by Ray Data.
Ray Data LLM provides an efficient, scalable solution for batch processing LLM inference workloads with:
- High Throughput: Optimized performance using vLLM's paged attention and continuous batching
- Distributed Processing: Scale across multiple GPUs and machines using Ray Data
- Smart Batching: Automatic request batching and scheduling for optimal throughput
- Prefix Caching: Memory-efficient processing of prompts with shared prefixes
- Flexible Workloads: Support for raw prompts, chat completions, and multimodal inputs
- LoRA Support: Dynamic loading of LoRA adapters from HuggingFace, local paths, or S3
```python
import ray
from vllm import SamplingParams

from ray.data.llm import VLLMBatchInferencer

# Initialize the processor with model and infrastructure configuration.
processor = VLLMBatchInferencer(
    model="meta-llama/Llama-2-70b-chat-hf",
    vllm_kwargs=dict(
        num_workers=4,
        tensor_parallel_size=2,
        gpu_memory_utilization=0.95,
        enable_prefix_caching=True,
    ),
)

# Read the input dataset and run batch inference over it.
ds = ray.data.read_parquet("s3://my-bucket/questions.parquet")
ds = processor.transform(
    ds,
    input_column="question",
    output_column="answer",
    accelerator_type="L40S",
    sampling_params=SamplingParams(
        temperature=0.7,
        max_tokens=512,
        top_p=0.95,
    ),
)
```
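Since `transform` returns a regular Ray Data dataset (as the examples here assume), results can be written out with any standard Ray Data sink:

```python
# Persist the generated answers with a standard Ray Data sink.
ds.write_parquet("s3://my-bucket/answers/")
```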
The engine itself is configured through a small dataclass:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

# TODO: Are there vLLM engine args we should not expose?
@dataclass
class vLLMEngineConfig:
    """Configuration for the vLLM engine."""

    model: str
    lora_config: Optional[LoRAConfig] = None
    vllm_kwargs: Dict[str, Any] = field(default_factory=dict)
```
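The `lora_config` field references the `LoRAConfig` listed under the typed imports below; its exact fields are not pinned down here. A minimal sketch, assuming a single adapter source (HuggingFace repo, local path, or S3 URI) plus an optional rank cap, might look like this:

```python
from dataclasses import dataclass
from typing import Optional

# Sketch only: the field names (adapter_source, max_lora_rank) are
# illustrative assumptions, not a finalized API.
@dataclass
class LoRAConfig:
    """Where to load a LoRA adapter from: HF repo, local path, or S3 URI."""

    adapter_source: str
    max_lora_rank: Optional[int] = None

config = vLLMEngineConfig(
    model="meta-llama/Llama-2-70b-chat-hf",
    lora_config=LoRAConfig(adapter_source="s3://my-bucket/adapters/customer-support"),
)
```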
Process chat-style conversations:
```python
from ray.data.llm import ChatConfig

ds = processor.transform(
    ds,
    input_column="query",
    output_column="response",
    prompt_config=ChatConfig(
        system_prompt="You are a helpful customer service agent.",
        template="Customer: {input}\nAgent:",
    ),
    sampling_params=SamplingParams(
        temperature=0.9,
        max_tokens=1024,
    ),
)
```
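To make the template semantics concrete, this is roughly how a single row is expected to be rendered into a prompt (a sketch of the assumed behavior, not part of the public API):

```python
# Sketch of the assumed rendering for one row; the actual formatting applied
# by ChatConfig may differ (e.g., model-specific chat templates).
system_prompt = "You are a helpful customer service agent."
template = "Customer: {input}\nAgent:"

def render_prompt(row: dict) -> str:
    # System prompt first, then the template filled with the input column.
    return f"{system_prompt}\n\n{template.format(input=row['query'])}"

print(render_prompt({"query": "Where is my order?"}))
```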
Optimize throughput for prompts with common prefixes:
```python
from ray.data.llm import SharedPrefixConfig

ds = processor.transform(
    ds,
    input_column="problem",
    output_column="solution",
    prompt_config=SharedPrefixConfig(
        prefix="Solve step by step:\n\n",
    ),
    sampling_params=SamplingParams(max_tokens=512),
)
```
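Functionally, this is assumed to be equivalent to prepending the prefix to each prompt yourself; the difference is that the engine can reuse the prefix's KV cache across rows instead of recomputing it:

```python
# Equivalent "manual" formulation (sketch): every prompt starts with the same
# prefix, which is what makes its KV cache reusable across rows.
prefix = "Solve step by step:\n\n"
problems = ["What is 12 * 17?", "Factor x^2 - 5x + 6."]
prompts = [prefix + p for p in problems]
```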
Process datasets with images (requires compatible models):
```python
from ray.data.llm import MultimodalConfig

ds = processor.transform(
    ds,
    input_column="image_path",
    output_column="description",
    prompt_config=MultimodalConfig(
        template="Describe this image in detail:\n{image}",
        image_size=(512, 512),  # Optional resizing
    ),
    sampling_params=SamplingParams(max_tokens=256),
)
```
Enable fault-tolerant processing with checkpointing:
```python
from ray.data.llm import CheckpointConfig

ds = processor.transform(
    ds,
    input_column="article",
    output_column="summary",
    sampling_params=SamplingParams(max_tokens=200),
    checkpoint_config=CheckpointConfig(
        path="s3://my-bucket/checkpoints/cnn-summary",
        checkpoint_frequency=100,
        resume_from_checkpoint=True,
        cleanup_checkpoint=True,
    ),
)
```
- Ray >= 2.6.0
- vLLM >= 0.2.0
- Python >= 3.8
- CUDA >= 11.8
All configuration classes provide complete type hints for better IDE support:
```python
from ray.data.llm import (
    LLMConfig,
    SamplingParams,
    ChatConfig,
    SharedPrefixConfig,
    MultimodalConfig,
    CheckpointConfig,
    LoRAConfig,
)
```
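For example, a small helper that builds a prompt configuration can be annotated and verified by mypy, pyright, or an IDE (the helper itself is only an illustration):

```python
# Illustration only: a typed helper that a type checker can verify against
# the config classes imported above.
def build_chat_config(agent_name: str) -> ChatConfig:
    return ChatConfig(
        system_prompt=f"You are {agent_name}.",
        template="Customer: {input}\nAgent:",
    )

chat_config: ChatConfig = build_chat_config("a helpful customer service agent")
```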
MIT