@norabelrose
Last active September 28, 2023 08:18
from dataclasses import dataclass
from functools import partial
from itertools import cycle
import logging
import multiprocessing as std_mp
import socket
import warnings
import dill
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from datasets import Dataset
from torch.distributed.fsdp import (
    CPUOffload, FullyShardedDataParallel as FSDP
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from typing import Any, Callable, Iterable, Type, cast
from elk.utils import (
    instantiate_model, pytree_map, select_usable_devices
)


@dataclass
class InferenceServer:
    """High-level interface for running inference on a model on multiple GPUs.

    This is basically a glorified `multiprocessing.Pool`. The only difference is that
    each worker maintains a copy of the model on a dedicated GPU.
    """

    model_str: str
    num_workers: int = -1
    cpu_offload: bool = False
    fsdp: bool = False

    def __post_init__(self):
        self._current_id = 0
        self._process_ctx: mp.ProcessContext | None = None
        self._result_queues = []
        self._task_queues = []

    @property
    def running(self) -> bool:
        """Whether the server is running."""
        return self._process_ctx is not None

    def start(self) -> None:
        """Spin up the workers."""
        if self._process_ctx is not None:
            raise RuntimeError("The server is already running")

        # Load the model on the main process, then zero-copy share it with the workers.
        # This ensures that we don't copy the model num_workers times on the CPU and
        # run out of RAM for large models.
        print("Loading model...")
        model = instantiate_model(self.model_str, torch_dtype="auto")
        model.share_memory()
        model_size = sum(
            p.numel() * p.element_size() for p in model.parameters()
        )

        # Determine which GPUs we can use
        devices = select_usable_devices(
            self.num_workers, min_memory=model_size if not self.fsdp else None
        )
        self.num_workers = len(devices)  # This may have been -1 before

        fsdp_port, wrap_policy = None, None
        if self.fsdp:
            fsdp_port = find_available_port()
            msg = f"Fully Sharded Data Parallel running on port {fsdp_port}"

            layer_cls = get_transformer_layer_cls(model)
            if layer_cls is not None:
                msg += f" with '{layer_cls.__name__}' wrapping policy"
                wrap_policy = partial(
                    transformer_auto_wrap_policy,
                    transformer_layer_cls={layer_cls},
                )

            print(msg)

        self._manager = mp.Manager()
        self._result_queues = [
            self._manager.Queue() for _ in range(self.num_workers)
        ]
        self._task_queues = [
            self._manager.Queue() for _ in range(self.num_workers)
        ]
        self._process_ctx = mp.spawn(
            _worker_wrapper,
            args=(
                devices,
                model,
                self._task_queues,
                self._result_queues,
                self.cpu_offload,
                fsdp_port,
                wrap_policy,
            ),
            join=False,
            nprocs=self.num_workers,
        )

    def shutdown(self) -> bool:
        """Shut down all the workers, returning `True` if successful."""
        if self._process_ctx is None:
            raise RuntimeError("Can't shut down a server that isn't running")

        # Let the workers know that they should shut down
        for q in self._task_queues:
            try:
                q.put_nowait(None)
            except std_mp.queues.Full:  # type: ignore[attr-defined]
                pass

        self._manager.shutdown()
        return self._process_ctx.join()

    # Support use as a context manager, just like mp.Pool
    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.shutdown()

    def map(
        self,
        closure: Callable[[ModelOutput], Any],
        dataset: Dataset,
    ) -> list:
        """Run inference on `dataset`, applying `closure` to each model output.

        Returns all results as a list. Results arrive in round-robin order across
        the workers, not necessarily in the original dataset order.
        """
        return list(self.imap(closure, dataset))

    def imap(
        self,
        closure: Callable[[ModelOutput], Any] | None,
        dataset: Dataset,
    ) -> Iterable:
        """Run inference on `dataset`, lazily yielding the closure's results.

        The dataset is sharded across the workers, and results are yielded in
        round-robin order across the workers as they become available.
        """
        if self._process_ctx is None:
            raise RuntimeError("Can't run inference on a server that isn't running")

        # We need PyTorch tensors
        dataset = dataset.with_format("torch")

        # Pickle the closure and send it to the workers
        closure_pkl = dill.dumps(closure)
        shards = [dataset.shard(self.num_workers, i) for i in range(self.num_workers)]
        for q, shard in zip(self._task_queues, shards):
            q.put((closure_pkl, shard))

        yield from round_robin(self._result_queues)  # type: ignore[return-value]


def get_transformer_layer_cls(model: torch.nn.Module) -> Type[torch.nn.Module] | None:
    """Get the class of the transformer layer used by the given model."""
    total_params = sum(p.numel() for p in model.parameters())
    for module in model.modules():
        if isinstance(module, torch.nn.ModuleList):
            module_params = sum(p.numel() for p in module.parameters())
            if module_params > total_params / 2:
                return type(module[0])

    return None
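
# Note: `get_transformer_layer_cls` is a heuristic. It looks for a ModuleList that
# holds more than half of the model's parameters and assumes its first element is
# the repeated transformer block. Illustrative (unverified) example: for a Hugging
# Face GPT-2 checkpoint the largest ModuleList is `transformer.h`, so this would
# typically return `GPT2Block`:
#
#   model = instantiate_model("gpt2", torch_dtype="auto")
#   get_transformer_layer_cls(model)  # -> GPT2Block, assuming the standard HF layout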


def get_socket_with_port() -> socket.socket:
    addrs = socket.getaddrinfo(
        host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
    )
    for addr in addrs:
        family, sock_type, proto, _, _ = addr
        s = socket.socket(family, sock_type, proto)
        try:
            s.bind(("localhost", 0))
            s.listen(0)
            return s
        except OSError:
            s.close()

    raise RuntimeError("Failed to create a socket")


def find_available_port() -> int:
    s = get_socket_with_port()
    _, port, *_ = s.getsockname()
    s.close()
    return port
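
# Note: as with most "find a free port" helpers, there is a small race window here;
# the socket is closed before the FSDP process group binds to the port, so another
# process could grab it in the meantime. This is usually acceptable on a single host.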


def round_robin(queues: list[mp.Queue], sentinel: Any = None) -> Iterable[Any]:
    """Yield items from the given queues in round-robin order."""
    exhausted = set()

    for idx, q in cycle(enumerate(queues)):
        if len(exhausted) == len(queues):
            break
        if idx in exhausted:
            continue

        try:
            item = q.get(timeout=0.01)
        except std_mp.queues.Empty:  # type: ignore[attr-defined]
            pass
        else:
            # Compare by identity so that tensor-valued items are never compared
            # elementwise against the sentinel
            if item is sentinel:
                exhausted.add(idx)
            else:
                yield item
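
# Illustrative example (not from the original gist): given two queues holding
# [1, 3, None] and [2, None], `round_robin` yields 1, 2, 3 and then stops once
# both queues have produced the sentinel.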


@torch.inference_mode()
def _worker(
    rank: int,
    devices: list[str],
    model: PreTrainedModel,
    qs: list[mp.Queue],
    out_qs: list[mp.Queue],
    cpu_offload: bool = False,
    fsdp_port: int | None = None,
    wrap_policy: partial[bool] | None = None,
):
    """Worker process that maintains a copy of the model on a dedicated GPU."""
    # Prevent duplicate logging messages
    if rank != 0:
        logging.disable(logging.CRITICAL)
        warnings.filterwarnings("ignore")

    closure: Callable[[ModelOutput], Any] | None = None
    dataset: Dataset | None = None
    device = devices[rank]

    # Fully Sharded Data Parallel for large models
    if fsdp_port is not None:
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(fsdp_port)
        dist.init_process_group("nccl", rank=rank, world_size=len(devices))
        torch.cuda.set_device(device)

        wrapped = FSDP(
            model,
            auto_wrap_policy=wrap_policy,
            # Offload sharded parameters to CPU when requested
            cpu_offload=CPUOffload(offload_params=cpu_offload),
            device_id=torch.device(device),
            forward_prefetch=True,
        )
        model = cast(PreTrainedModel, wrapped)
    else:
        model.to(device)

    # The loop breaks when msg is None, the sentinel value indicating we should
    # shut down
    in_queue = qs[rank]
    out_queue = out_qs[rank]
    while msg := in_queue.get():
        # Someone called map() giving us a new closure and dataset to use
        assert isinstance(msg, tuple) and len(msg) == 2
        closure_pkl, dataset = msg
        closure = dill.loads(closure_pkl)

        assert dataset is not None
        for record in dataset:
            assert isinstance(record, dict)
            inputs_cuda = pytree_map(lambda v: v.to(device).unsqueeze(0), record)

            # We always want to return the hidden states
            outputs = model(**inputs_cuda, output_hidden_states=True)

            if callable(closure):
                outputs = closure(outputs)
            else:
                # Move the outputs back to the CPU
                outputs_cls = type(outputs)
                outputs_dict = pytree_map(
                    lambda x: x.cpu().share_memory_(), outputs
                )
                outputs = outputs_cls(**outputs_dict)

            # Send the outputs back to the main process
            out_queue.put(outputs)

        # Indicate we're done with this dataset
        out_queue.put(None)

    # Clean up the FSDP process group
    if fsdp_port is not None:
        dist.destroy_process_group()


def _worker_wrapper(rank: int, *args):
    try:
        return _worker(rank, *args)
    except Exception as e:
        print(f"Exception in worker {rank}: {e}")
        raise e
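

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original gist): a minimal, hedged example of how
# the pieces above fit together. The checkpoint name, token ids, and closure are
# illustrative assumptions, not values from the original code.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # A toy single-example dataset; real usage would tokenize a corpus instead.
    # The token ids below are placeholders for a bert-base-uncased input.
    toy_data = Dataset.from_dict(
        {
            "input_ids": [[101, 7592, 2088, 102]],
            "attention_mask": [[1, 1, 1, 1]],
        }
    )

    # The closure runs inside each worker; it should move whatever it keeps back
    # to the CPU so the result can travel through the result queue.
    def last_hidden_state(outputs: ModelOutput):
        return outputs.hidden_states[-1].cpu()

    # num_workers=-1 lets select_usable_devices pick every usable GPU.
    with InferenceServer("bert-base-uncased", num_workers=-1) as server:
        for hiddens in server.imap(last_hidden_state, toy_data):
            print(hiddens.shape)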