@flishwang
Created July 4, 2025 04:14
vllm_monkey_patch
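
This gist monkey-patches vLLM's v1 GPU Worker so that sleep-mode memory pools are namespaced by a per-model tag prefix. With the stock implementation, every worker in a process shares the same "weights"/"kv_cache" tags on the singleton CuMemAllocator, so putting one vLLM instance to sleep unmaps the allocations of all of them; with the prefix, each instance can sleep and wake independently. The patch also adjusts determine_available_memory so that memory freed inside custom pools, which torch's allocator statistics miss, is folded into the profiled peak when sizing the KV cache. Note that pointer_to_data, unmap_and_release, and libcudart are private vLLM internals and may change across versions.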
import gc
from contextlib import nullcontext
from typing import Optional

import torch

from vllm.device_allocator.cumem import (CuMemAllocator, libcudart,
                                         is_pin_memory_available,
                                         unmap_and_release)
from vllm.utils import GiB_bytes
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu_worker import Worker, logger


def custom_load_model(self) -> None:
    """Load the model weights inside a per-model tagged memory pool."""
    if self.vllm_config.model_config.enable_sleep_mode:
        allocator = CuMemAllocator.get_instance()
        # Hand out a unique integer id per worker so that several vLLM
        # instances sharing one CuMemAllocator use disjoint pool tags.
        if not hasattr(self, 'model_id'):
            model_id = getattr(allocator, 'allocated_model_ids', 0) + 1
            allocator.allocated_model_ids = model_id
            self.model_id = model_id
        tag_prefix = str(getattr(self, 'model_id', self.model_config.model)) + ':'
        # print(f'{self.model_config.model}: using memory pool ' + tag_prefix + "weights")
        context = allocator.use_memory_pool(tag=tag_prefix + "weights")
    else:
        context = nullcontext()
    with context:
        self.model_runner.load_model()


def custom_initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
    """Allocate the GPU KV cache inside the same per-model tagged pool."""
    if self.vllm_config.model_config.enable_sleep_mode:
        allocator = CuMemAllocator.get_instance()
        if not hasattr(self, 'model_id'):
            model_id = getattr(allocator, 'allocated_model_ids', 0) + 1
            allocator.allocated_model_ids = model_id
            self.model_id = model_id
        tag_prefix = str(getattr(self, 'model_id', self.model_config.model)) + ':'
        # print(f'{self.model_config.model}: using memory pool ' + tag_prefix + "kv_cache")
        context = allocator.use_memory_pool(tag=tag_prefix + "kv_cache")
    else:
        context = nullcontext()
    with context:
        self.model_runner.initialize_kv_cache(kv_cache_config)


def custom_sleep(self, level: int = 1) -> None:
    """Release only this model's pools: level 1 offloads the weights to
    pinned CPU memory and discards the KV cache; level 2 discards both."""
    tag_prefix = str(getattr(self, 'model_id', self.model_config.model)) + ':'
    free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
    # Save the model buffers before a level 2 sleep, since the weights
    # pool is discarded rather than offloaded.
    if level == 2:
        model = self.model_runner.model
        self._sleep_saved_buffers = {
            name: buffer.cpu().clone()
            for name, buffer in model.named_buffers()
        }
    allocator = CuMemAllocator.get_instance()
    # print(f'{self.model_config.model}: {level=}, offload memory pool ' + tag_prefix + "weights")
    offload_tags = (tag_prefix + "weights",) if level == 1 else tuple()
    for ptr, data in allocator.pointer_to_data.items():
        handle = data.handle
        if data.tag in offload_tags:
            size_in_bytes = handle[1]
            cpu_backup_tensor = torch.empty(
                size_in_bytes,
                dtype=torch.uint8,
                device='cpu',
                pin_memory=is_pin_memory_available())
            cpu_ptr = cpu_backup_tensor.data_ptr()
            libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes)
            data.cpu_backup_tensor = cpu_backup_tensor
        # Only unmap allocations that belong to this model's pools.
        if data.tag.startswith(tag_prefix):
            unmap_and_release(handle)
    gc.collect()
    torch.cuda.empty_cache()
    free_bytes_after_sleep, total = torch.cuda.mem_get_info()
    freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
    used_bytes = total - free_bytes_after_sleep
    assert freed_bytes >= 0, "Memory usage increased after sleeping."
    logger.info(
        "tag_prefix=%s: sleep mode freed %.2f GiB memory, "
        "%.2f GiB memory is still in use.", tag_prefix,
        freed_bytes / GiB_bytes, used_bytes / GiB_bytes)


def custom_wake_up(self, tags: Optional[list[str]] = None) -> None:
    """Re-map this model's pools; requested tags are namespaced with the
    per-model prefix before being passed to the allocator."""
    tag_prefix = str(getattr(self, 'model_id', self.model_config.model)) + ':'
    allocator = CuMemAllocator.get_instance()
    if tags is None:
        # Wake everything that belongs to this model.
        tags = []
        for ptr, data in allocator.pointer_to_data.items():
            tag = data.tag
            if tag.startswith(tag_prefix):
                tags.append(tag)
    else:
        tags = [tag_prefix + tag for tag in tags]
    # print(f'{self.model_config.model}: waking {tags}')
    allocator.wake_up(tags)
    # Restore the buffers saved by a level 2 sleep.
    if len(self._sleep_saved_buffers):
        model = self.model_runner.model
        for name, buffer in model.named_buffers():
            if name in self._sleep_saved_buffers:
                buffer.data.copy_(self._sleep_saved_buffers[name].data)
        self._sleep_saved_buffers = {}


@torch.inference_mode()
def custom_determine_available_memory(self) -> int:
    """Profile a forward pass and size the KV cache from the observed peak."""
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    _, total_gpu_memory = torch.cuda.mem_get_info()
    # Execute a forward pass with dummy inputs to profile the memory usage
    # of the model.
    self.model_runner.profile_run()
    free_gpu_memory, _ = torch.cuda.mem_get_info()
    # NOTE(woosuk): Here we assume that the other processes using the same
    # GPU did not change their memory usage during the profiling.
    assert self.init_gpu_memory > free_gpu_memory, (
        "Error in memory profiling. "
        f"Initial free memory {self.init_gpu_memory}, current free memory"
        f" {free_gpu_memory}. This happens when the GPU memory was "
        "not properly cleaned up before initializing the vLLM instance.")
    # Get the peak memory allocation recorded by torch.
    peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
    # Check for any memory left around that may have been allocated on the
    # GPU outside of `torch`. NCCL operations, for example, can use a few
    # GB during a forward pass.
    torch.cuda.empty_cache()
    torch_allocated_bytes = torch.cuda.memory_stats(
    )["allocated_bytes.all.current"]
    total_allocated_bytes = torch.cuda.mem_get_info(
    )[1] - torch.cuda.mem_get_info()[0]
    non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
    # Due to a torch/vLLM accounting gap, memory freed inside custom memory
    # pools is not reflected in torch_allocated_bytes, so
    # non_torch_allocations can be negative here; adding it unconditionally
    # subtracts that already-freed memory from the peak.
    logger.debug(
        "total_gpu_memory=%.2f GiB, peak_memory=%.2f GiB, "
        "non_torch_allocations=%.2f GiB", total_gpu_memory / GiB_bytes,
        peak_memory / GiB_bytes, non_torch_allocations / GiB_bytes)
    peak_memory += non_torch_allocations
    available_kv_cache_memory = (
        total_gpu_memory * self.cache_config.gpu_memory_utilization -
        peak_memory)
    return int(available_kv_cache_memory)
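
# Worked example of the sizing formula above (illustrative numbers, not
# measured): on an 80 GiB GPU with gpu_memory_utilization=0.9 and an
# adjusted peak of 20 GiB, the KV cache budget is
#   80 GiB * 0.9 - 20 GiB = 52 GiB.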


# Install the patched methods on vLLM's v1 GPU Worker; this must run
# before any Worker is constructed.
Worker.load_model = custom_load_model
Worker.sleep = custom_sleep
Worker.wake_up = custom_wake_up
Worker.initialize_from_config = custom_initialize_from_config
Worker.determine_available_memory = custom_determine_available_memory
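
A minimal usage sketch, assuming this file is saved as vllm_monkey_patch.py on the import path and that sleep mode is available on the installed vLLM build; the model name and memory fractions below are illustrative, not part of the gist:

import vllm_monkey_patch  # noqa: F401 -- importing applies the patches
from vllm import LLM

# Two engines in one process; each worker receives its own tag prefix,
# so sleeping one does not unmap the other's weights or KV cache.
llm_a = LLM(model="facebook/opt-125m", enable_sleep_mode=True,
            gpu_memory_utilization=0.3)
llm_b = LLM(model="facebook/opt-125m", enable_sleep_mode=True,
            gpu_memory_utilization=0.3)

llm_a.sleep(level=1)  # offload llm_a's weights, drop its KV cache
print(llm_b.generate(["Hello"])[0].outputs[0].text)  # llm_b still serves
llm_a.wake_up()       # re-map llm_a's pools and restore its weights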