@vwxyzjn
Created December 18, 2024 22:27
# Taken and modified from https://github.com/huggingface/trl
# Copyright 2024 The AllenAI Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file is copied from https://github.com/OpenRLHF/OpenRLHF"""
from datetime import timedelta
from typing import Any, Optional, Union
import ray
import torch
import torch.distributed
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from torch.distributed.distributed_c10d import (
Backend,
PrefixStore,
Store,
_new_process_group_helper,
_world,
default_pg_timeout,
rendezvous,
)
from vllm.worker.worker import Worker
# Copy from pytorch to allow creating multiple main groups.
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
def init_process_group(
    backend: Union[str, Backend] = None,
    init_method: Optional[str] = None,
    timeout: Optional[timedelta] = None,
    world_size: int = -1,
    rank: int = -1,
    store: Optional[Store] = None,
    group_name: str = None,
    pg_options: Optional[Any] = None,
):
    assert (store is None) or (init_method is None), "Cannot specify both init_method and store."

    if store is not None:
        assert world_size > 0, "world_size must be positive if using store"
        assert rank >= 0, "rank must be non-negative if using store"
    elif init_method is None:
        init_method = "env://"

    if backend:
        backend = Backend(backend)
    else:
        backend = Backend("undefined")

    if timeout is None:
        timeout = default_pg_timeout

    # backward compatible API
    if store is None:
        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
        store, rank, world_size = next(rendezvous_iterator)
        store.set_timeout(timeout)

        # Use a PrefixStore to avoid accidental overrides of keys used by
        # different systems (e.g. RPC) in case the store is multi-tenant.
        store = PrefixStore(group_name, store)

    pg, _ = _new_process_group_helper(
        world_size,
        rank,
        [],
        backend,
        store,
        group_name=group_name,
        pg_options=pg_options,
        timeout=timeout,
    )

    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}

    return pg

class WorkerWrap(Worker):
    def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend="nccl"):
        """Init torch process group for model weights update"""
        assert torch.distributed.is_initialized(), "default torch process group must be initialized"
        assert group_name != "", "group name must not be empty"

        rank = torch.distributed.get_rank() + rank_offset
        self._model_update_group = init_process_group(
            backend=backend,
            init_method=f"tcp://{master_address}:{master_port}",
            world_size=world_size,
            rank=rank,
            group_name=group_name,
        )
        print(
            f"init_process_group: master_address={master_address}, master_port={master_port}, ",
            f"rank={rank}, world_size={world_size}, group_name={group_name}",
        )

    def update_weight(self, name, dtype, shape, empty_cache=False):
        """Broadcast weight to all vllm workers from source rank 0 (actor model)"""
        # print(f"update_weight: {name}, dtype: {dtype}, shape: {shape}, rank: {torch.distributed.get_rank()}, world_size: {torch.distributed.get_world_size()}")
        # if torch.distributed.get_rank() == 0:
        #     print(f"update weight: {name}, dtype: {dtype}, shape: {shape}")
        assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}"
        weight = torch.empty(shape, dtype=dtype, device="cuda")
        torch.distributed.broadcast(weight, 0, group=self._model_update_group)

        self.model_runner.model.load_weights(weights=[(name, weight)])

        del weight
        # TODO: should we empty cache if all weights have updated?
        # if empty_cache:
        #     torch.cuda.empty_cache()

@ray.remote
class LLMRayActor:
    def __init__(self, *args, **kwargs):
        import vllm

        self.__version__ = vllm.__version__
        assert self.__version__ >= "0.4.1", "OpenRLHF only supports vLLM >= 0.4.1"

        self.use_gpu_executor = kwargs["tensor_parallel_size"] == 1

        # See https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py
        if self.use_gpu_executor:
            vllm.worker.worker.Worker = WorkerWrap
        else:
            # RayGPUExecutor
            # See the patch https://github.com/vllm-project/vllm/commit/479d69fad0538f04cb22bf13e76ff91cfeb8a4e5
            kwargs["worker_use_ray"] = True

            if vllm.__version__ > "0.4.1":
                RayWorkerWrapperPath = vllm.executor.ray_utils
            else:
                RayWorkerWrapperPath = vllm.engine.ray_utils

            class RayWorkerWrapper(RayWorkerWrapperPath.RayWorkerWrapper):
                def __init__(self, *args, **kwargs) -> None:
                    kwargs["worker_module_name"] = "open_instruct.vllm_utils2"
                    kwargs["worker_class_name"] = "WorkerWrap"
                    super().__init__(*args, **kwargs)

            RayWorkerWrapperPath.RayWorkerWrapper = RayWorkerWrapper

        self.llm = vllm.LLM(*args, **kwargs)

    def generate(self, *args, **kwargs):
        return self.llm.generate(*args, **kwargs)

    def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend):
        if self.use_gpu_executor:
            return self.llm.llm_engine.model_executor.driver_worker.init_process_group(
                master_address, master_port, rank_offset, world_size, group_name, backend
            )
        else:
            return self.llm.llm_engine.model_executor._run_workers(
                "init_process_group", master_address, master_port, rank_offset, world_size, group_name, backend
            )

    def update_weight(self, name, dtype, shape, empty_cache=False):
        self.stop_remote_worker_execution_loop()

        if self.use_gpu_executor:
            return self.llm.llm_engine.model_executor.driver_worker.update_weight(name, dtype, shape, empty_cache)
        else:
            return self.llm.llm_engine.model_executor._run_workers("update_weight", name, dtype, shape, empty_cache)

    def stop_remote_worker_execution_loop(self):
        # Fix error when using 2 communication groups
        # https://github.com/vllm-project/vllm/commit/eb6d3c264d0cd8e44dec16bca7947fbe96415ce9#diff-e1ad69e38e033accddfa5480ec808c4740eb39244d1ef51cc3407e20dde8cfd4
        if self.__version__ > "0.4.2":
            self.llm.llm_engine.model_executor.stop_remote_worker_execution_loop()

def create_vllm_engines(
    num_engines: int,
    tensor_parallel_size: int,
    pretrain: str,
    revision: str,
    seed: int,
    enable_prefix_caching: bool,
    max_model_len: int,
):
    vllm_engines = []
    for i in range(num_engines):
        # When tensor_parallel_size=1, vLLM initializes the model in LLMEngine directly, so assign 1 GPU to it.
        num_gpus = int(tensor_parallel_size == 1)

        scheduling_strategy = None
        if tensor_parallel_size > 1:
            bundles = [{"GPU": 1, "CPU": 1}] * tensor_parallel_size
            pg = placement_group(bundles)
            ray.get(pg.ready())

            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=pg, placement_group_capture_child_tasks=True, placement_group_bundle_index=0
            )

        print(f"vllm: {num_gpus=}, {num_engines=}")
        vllm_engines.append(
            LLMRayActor.options(
                num_cpus=1,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
            ).remote(
                pretrain,
                revision=revision,
                tokenizer_revision=revision,
                trust_remote_code=True,
                tensor_parallel_size=tensor_parallel_size,
                dtype="bfloat16",
                seed=seed + i,
                enable_prefix_caching=enable_prefix_caching,
                max_model_len=max_model_len,
            )
        )

    return vllm_engines

if __name__ == "__main__":
    llm = LLMRayActor.remote("meta-llama/Llama-3.1-8B-Instruct", tensor_parallel_size=2)
    output = ray.get(llm.generate.remote("San Francisco is a"))
    print(f"output: {output}")

# Taken and modified from https://github.com/huggingface/trl
# Copyright 2024 The AllenAI Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file is copied from https://github.com/OpenRLHF/OpenRLHF"""

import socket

import ray
import torch
import torch.distributed
from transformers import (
    AutoModelForCausalLM,
)

from open_instruct.vllm_utils2 import create_vllm_engines, init_process_group

if __name__ == "__main__":
    vllm_tensor_parallel_size = 2
    vllm_num_engines = 1
    vllm_sync_backend = "nccl"
    model_name_or_path = "allenai/Llama-3.1-Tulu-3-8B-DPO"
    model_name_or_path2 = "allenai/Llama-3.1-Tulu-3-8B"

    # llm = LLMRayActor.remote("meta-llama/Llama-3.1-8B-Instruct", tensor_parallel_size=2)
    # output = ray.get(llm.generate.remote("San Francisco is a"))
    # print(f"output: {output}")

    vllm_engines = create_vllm_engines(
        vllm_num_engines,
        vllm_tensor_parallel_size,
        model_name_or_path,
        None,
        1,
        False,
        4096,
    )

    master_address = ray._private.services.get_node_ip_address()
    with socket.socket() as sock:
        sock.bind(("", 0))
        master_port = sock.getsockname()[1]

    vllm_num_engines, vllm_tensor_parallel_size = (
        vllm_num_engines,
        vllm_tensor_parallel_size,
    )
    world_size = vllm_num_engines * vllm_tensor_parallel_size + 1
    backend = vllm_sync_backend

    # https://github.com/OpenRLHF/OpenRLHF/issues/313
    # if vllm.__version__ > "0.4.2" and os.getenv("NCCL_P2P_DISABLE", "0") == "0":
    #     backend = "gloo"
    #     print(
    #         "Warning: using --vllm_sync_backend=gloo for vLLM version > 0.4.2 (or export NCCL_P2P_DISABLE=1)"
    #     )
    refs = [
        engine.init_process_group.remote(
            master_address,
            master_port,
            i * vllm_tensor_parallel_size + 1,
            world_size,
            "openrlhf",
            backend=backend,
        )
        for i, engine in enumerate(vllm_engines)
    ]
    model_update_group = init_process_group(
        backend=backend,
        init_method=f"tcp://{master_address}:{master_port}",
        world_size=world_size,
        rank=0,
        group_name="openrlhf",
    )
    ray.get(refs)

    torch.set_default_device("cuda:7")
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path2, torch_dtype=torch.bfloat16)
    model = model.to("cuda:7")

    def broadcast_to_vllm():
        # avoid OOM
        torch.cuda.empty_cache()
        count, num_params = 0, len(list(model.named_parameters()))
        refss = []
        for name, param in model.named_parameters():
            count += 1
            shape = param.shape
            # Ask every vLLM engine to receive this parameter, then broadcast it
            # from rank 0 (this process) over the model-update group.
            refs = [
                engine.update_weight.remote(
                    name, dtype=param.dtype, shape=shape, empty_cache=count == num_params
                )
                for engine in vllm_engines
            ]
            refss.extend(refs)
            torch.distributed.broadcast(param.data, 0, group=model_update_group)
        ray.get(refss)

    broadcast_to_vllm()
    print("broadcasted model to vllm")
@vwxyzjn (Author) commented Dec 19, 2024
Tested with the following config and it hangs:

    vllm_tensor_parallel_size = 1
    vllm_num_engines = 1
    vllm_sync_backend = "nccl"


but the gloo backend works fine:

    vllm_tensor_parallel_size = 1
    vllm_num_engines = 1
    vllm_sync_backend = "gloo"


@youkaichao figured it out in the end: we just need to export NCCL_CUMEM_ENABLE=0 and the nccl backend works.
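
In case it helps anyone else hitting the same hang: NCCL_CUMEM_ENABLE=0 presumably needs to be visible to the vLLM Ray actors as well as the driver process. A minimal sketch of one way to do that (an assumption about this setup, not what was actually run here) is to export it in the launching shell, or to propagate it explicitly before the engines are created:

    import os
    import ray

    # Assumed approach: make NCCL_CUMEM_ENABLE=0 visible to the driver and to every
    # Ray worker before any NCCL communicator is created.
    os.environ["NCCL_CUMEM_ENABLE"] = "0"
    ray.init(runtime_env={"env_vars": {"NCCL_CUMEM_ENABLE": "0"}})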
