Distributed training with broadcasting the updated weights
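In the script below, rank 0 runs a TrainerNode that owns the model, ranks 1 to world_size - 2 run CollectorNodes that generate experience, and rank world_size - 1 hosts a ReplayBufferNode. Control flow goes over torch.rpc, while the updated weights travel from the trainer to the collectors over a separate gloo process group. As a minimal sketch of just that broadcast step (assuming a gloo group is already initialised and every rank holds an identically shaped model; the function name sync_weights is illustrative and not part of the script):

import torch
import torch.distributed as dist

def sync_weights(model: torch.nn.Module, src_rank: int = 0) -> None:
    # Every member of the group calls broadcast on each parameter: the source
    # rank sends its values, the other ranks overwrite theirs in place.
    for param in model.parameters():
        dist.broadcast(param.data, src=src_rank)
    # Block until every rank holds the new weights.
    dist.barrier()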
import logging
import os
import random
import time

import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
import tqdm
from tensordict import TensorDict
from torch import multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

from torchrl._utils import accept_remote_rref_invocation, logger as torchrl_logger
from torchrl.data.replay_buffers import RemoteReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from torchrl.data.replay_buffers.writers import RoundRobinWriter

RETRY_LIMIT = 2
RETRY_DELAY_SECS = 10
REPLAY_BUFFER_NODE = "ReplayBuffer"
TRAINER_NODE = "Trainer"


class ReplayBufferNode(RemoteReplayBuffer):
    """Experience replay buffer node that is capable of accepting remote connections.

    Being a `RemoteReplayBuffer` means that all of its public methods are remotely
    invokable using `torch.rpc`.
    Using a LazyMemmapStorage is highly advised in distributed settings with shared
    storage, due to the lower serialisation cost of MemoryMappedTensors and the
    ability to specify file storage locations, which improves the ability to recover
    from node failures.

    Args:
        capacity (int): the maximum number of elements that can be stored in the replay buffer.
    """

    def __init__(self, capacity: int):
        super().__init__(
            storage=LazyMemmapStorage(
                max_size=capacity, scratch_dir="/tmp/", device=torch.device("cpu")
            ),
            sampler=SliceSampler(num_slices=4),
            writer=RoundRobinWriter(),
            batch_size=32,
        )


class CollectorNode:
    """Data collector node responsible for collecting experiences used for learning.

    Args:
        replay_buffer (rpc.RRef): the RRef associated with the construction of the replay buffer.
        frames_per_batch (int): the ``frames_per_batch`` of the collector. This serves as an example
            of hyperparameters to be passed to the collector.
    """

    def __init__(
        self,
        replay_buffer: rpc.RRef,
        frames_per_batch: int = 128,
    ) -> None:
        self.id = rpc.get_worker_info().id
        self.replay_buffer = replay_buffer
        # Write your collector here
        # self.collector = SyncDataCollector(...)
        assert frames_per_batch > 0
        self.frames_per_batch = frames_per_batch
        self.model = TestModel()
        self.stop_collect_flag = False

    @accept_remote_rref_invocation
    def stop_collect(self):
        # Stop the data collection process
        self.stop_collect_flag = True

    def _submit_item_async(self) -> rpc.RRef:
        """Function that collects data and populates the replay buffer."""
        # Replace this by a call to next() over the data collector
        done = torch.zeros(self.frames_per_batch, 1, dtype=torch.bool)
        done[..., -1, 0] = True
        td = TensorDict(
            {
                "action": torch.randint(100, (self.frames_per_batch, 1)),
                "done": torch.zeros(self.frames_per_batch, dtype=torch.bool),
                "observation": torch.randn(self.frames_per_batch, 4),
                "step_count": torch.arange(self.frames_per_batch),
                "terminated": torch.zeros(self.frames_per_batch, dtype=torch.bool),
                "truncated": torch.zeros(self.frames_per_batch, dtype=torch.bool),
                "next": {
                    "done": done,
                    "observation": torch.randn(self.frames_per_batch, 4),
                    "reward": torch.randn(self.frames_per_batch, 1),
                    "step_count": torch.arange(1, self.frames_per_batch + 1),
                    "terminated": torch.zeros_like(done),
                    "truncated": done,
                },
            },
            [self.frames_per_batch],
        )
        return rpc.remote(
            self.replay_buffer.owner(),
            ReplayBufferNode.extend,
            args=(self.replay_buffer, td),
        )

    @accept_remote_rref_invocation
    def _init_parameter_sharing(self):
        torchrl_logger.info(f"[RANK {self.id}] Data Collector Init process group")
        self.group = dist.init_process_group(
            backend='gloo',
            init_method=f'tcp://{os.environ["MASTER_ADDR"]}:{os.environ["MASTER_PORT_NCCL"]}',
            rank=rpc.get_worker_info().id,
            world_size=int(os.environ["WORLD_SIZE"]) - 1,
        )
        torchrl_logger.info(f"[RANK {self.id}] Finished Collector Init Process Group")

    @accept_remote_rref_invocation
    def _receive_weights(self, block=False):
        # Receive the weights broadcast by the trainer (rank 0 of the gloo group)
        for param in self.model.parameters():
            dist.broadcast(param.data, 0, async_op=True, group=self.group)
        if block:
            dist.barrier(group=self.group)
        # print(f"[RANK {self.id}] Weights:", self.model.fc.weight)

    @accept_remote_rref_invocation
    def collect(self):
        """Method that begins experience collection (we just generate random TensorDicts in this example).

        `accept_remote_rref_invocation` enables this method to be invoked remotely provided the class
        instantiation `rpc.RRef` is provided in place of the object reference.
        """
        self.stop_collect_flag = False
        for _ in range(1000):
            # time.sleep(random.randint(1, 4))
            time.sleep(0.1)
            if self.stop_collect_flag:
                break
            self._submit_item_async()

    @accept_remote_rref_invocation
    def cleanup(self):
        dist.destroy_process_group(self.group)


class TrainerNode:
    """Trainer node responsible for learning from experiences sampled from an experience replay buffer."""

    def __init__(self, replay_buffer_node="ReplayBuffer") -> None:
        self.replay_buffer_node = replay_buffer_node
        self.world_size = int(os.environ["WORLD_SIZE"])
        self.id = rpc.get_worker_info().id
        torchrl_logger.info("TrainerNode")
        self.replay_buffer = self._create_replay_buffer()
        self._create_data_collectors()
        self._init_parameter_sharing()
        self.model = TestModel()
        self.share_weights(block=True)
        self._launch_data_collectors()
        torchrl_logger.info("Initialized Trainer")

    def share_weights(self, block: bool = True):
        # Broadcast the model weights to all data collectors in the group
        for param in self.model.parameters():
            dist.broadcast(param.data, 0, async_op=True, group=self.group)
        # Asynchronously start the weight-receiving process on every collector
        for collector, data_collector_info in zip(
            self.data_collectors, self.data_collector_infos
        ):
            rpc.rpc_async(
                data_collector_info,
                CollectorNode._receive_weights,
                args=(collector, block),
            )
        # If blocking, wait for all data collectors to receive the weights
        if block:
            dist.barrier(group=self.group)

    def train(self, iterations: int) -> None:
        """Write your training loop here."""
        for iteration in tqdm.tqdm(range(iterations)):
            torchrl_logger.info(f"[{self.id}] Training Iteration: {iteration}")
            # Wait until the buffer has elements
            while not rpc.rpc_sync(
                self.replay_buffer.owner(),
                ReplayBufferNode.__len__,
                args=(self.replay_buffer,),
            ):
                continue
            batch = rpc.rpc_sync(
                self.replay_buffer.owner(),
                ReplayBufferNode.sample,
                args=(self.replay_buffer, 16),
            )
            if (iteration + 1) % 10 == 0:
                # Simulate a weight update, then broadcast the new weights to the collectors
                torch.nn.init.uniform_(self.model.fc.weight)
                # print("Trainer", iteration, self.model.fc.weight)
                self.share_weights(block=True)
            # torchrl_logger.info(f"[{self.id}] Sample Obtained Iteration: {iteration}")
            # torchrl_logger.info(f"{batch}")
            # Process the sample here: forward, backward, ...

    def _create_replay_buffer(self) -> rpc.RRef:
        def connect():
            replay_buffer_info = rpc.get_worker_info(self.replay_buffer_node)
            buffer_rref = rpc.remote(replay_buffer_info, ReplayBufferNode, args=(10000,))
            torchrl_logger.info(f"Connected to replay buffer {replay_buffer_info}")
            return buffer_rref

        while True:
            try:
                return connect()
            except Exception as e:
                torchrl_logger.info(f"Failed to connect to replay buffer: {e}")
                time.sleep(RETRY_DELAY_SECS)

    def _create_data_collectors(self) -> None:
        data_collector_number = self.world_size - 2
        self.data_collectors = []
        self.data_collector_infos = []

        # Discover launched data collector nodes (with retry to allow collectors to dynamically join)
        def connect(n, retry):
            data_collector_info = rpc.get_worker_info(
                f"DataCollector{n + 1}"  # DataCollector1, DataCollector2, ...
            )
            torchrl_logger.info(f"Data collector info: {data_collector_info}-retry={retry}")
            dc_ref = rpc.remote(
                data_collector_info,
                CollectorNode,
                args=(self.replay_buffer,),
                timeout=10,
            )
            torchrl_logger.info("Finished connecting")
            self.data_collectors.append(dc_ref)
            self.data_collector_infos.append(data_collector_info)

        for n in range(data_collector_number):
            for retry in range(RETRY_LIMIT):
                torchrl_logger.info(
                    f"Connecting to DataCollector{n + 1}/{data_collector_number}, retry={retry}"
                )
                try:
                    connect(n, retry)
                    break
                except Exception as e:
                    torchrl_logger.info(
                        f"Failed to connect to DataCollector{n + 1} with {retry} retries (err={e})"
                    )
                    time.sleep(RETRY_DELAY_SECS)
            else:
                raise RuntimeError(
                    f"Failed to connect to DataCollector{n + 1} after {RETRY_LIMIT} retries"
                )

    def _init_parameter_sharing(self):
        # Ask every collector to join the gloo group, then join it as rank 0 (the trainer's RPC id)
        for collector, data_collector_info in zip(
            self.data_collectors, self.data_collector_infos
        ):
            rpc.rpc_async(
                data_collector_info,
                CollectorNode._init_parameter_sharing,
                args=(collector,),
            )
        self.group = dist.init_process_group(
            backend='gloo',
            init_method=f'tcp://{os.environ["MASTER_ADDR"]}:{os.environ["MASTER_PORT_NCCL"]}',
            rank=rpc.get_worker_info().id,
            world_size=int(os.environ["WORLD_SIZE"]) - 1,
        )

    def _launch_data_collectors(self) -> None:
        for collector, data_collector_info in zip(
            self.data_collectors, self.data_collector_infos
        ):
            rpc.remote(
                data_collector_info,
                CollectorNode.collect,
                args=(collector,),
            )

    def stop_collect(self):
        for collector, data_collector_info in zip(
            self.data_collectors, self.data_collector_infos
        ):
            rpc.rpc_async(
                data_collector_info,
                CollectorNode.stop_collect,
                args=(collector,),
            )

    def cleanup(self):
        for collector, data_collector_info in zip(
            self.data_collectors, self.data_collector_infos
        ):
            rpc.rpc_async(
                data_collector_info,
                CollectorNode.cleanup,
                args=(collector,),
            )
        dist.destroy_process_group(self.group)


class TestModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(7000, 7000)

    def forward(self, x):
        return self.fc(x)


def main(rank, world_size, **tensorpipe_kwargs):
    """Dispatcher for the distributed workflow.

    Rank 0 is assigned the TRAINER job,
    rank world_size - 1 is assigned the REPLAY BUFFER job, and
    ranks 1 to world_size - 2 are assigned the COLLECTOR jobs.
    """
    torchrl_logger.info(f"Rank: {rank}")
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT_RPC"] = "29500"
    os.environ["MASTER_PORT_NCCL"] = "29501"
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"

    # RPC initialization options
    options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=16, **tensorpipe_kwargs)
    options.init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT_RPC']}"

    if rank == 0:
        # Rank 0 is the trainer
        torchrl_logger.info(f"Init RPC on {TRAINER_NODE}...")
        rpc.init_rpc(
            TRAINER_NODE,
            rank=rank,
            backend=rpc.BackendType.TENSORPIPE,
            rpc_backend_options=options,
            world_size=world_size,
        )
        torchrl_logger.info(f"Initialised {TRAINER_NODE}")
        trainer = TrainerNode(replay_buffer_node=REPLAY_BUFFER_NODE)
        trainer.train(100)
        trainer.stop_collect()
        trainer.cleanup()
        rpc.shutdown()
    elif rank == world_size - 1:
        # Rank world_size - 1 is the replay buffer;
        # the replay buffer waits passively for construction instructions from the trainer node
        torchrl_logger.info(f"Init RPC on {REPLAY_BUFFER_NODE}...")
        rpc.init_rpc(
            REPLAY_BUFFER_NODE,
            rank=rank,
            backend=rpc.BackendType.TENSORPIPE,
            rpc_backend_options=options,
            world_size=world_size,
        )
        torchrl_logger.info(f"Initialised {REPLAY_BUFFER_NODE}")
        rpc.shutdown()
    else:
        # Ranks 1 to world_size - 2 are data collector nodes;
        # data collectors also wait passively for construction instructions from the trainer node
        torchrl_logger.info(f"Init RPC on DataCollector{rank}")
        rpc.init_rpc(
            f"DataCollector{rank}",
            rank=rank,
            backend=rpc.BackendType.TENSORPIPE,
            rpc_backend_options=options,
            world_size=world_size,
        )
        torchrl_logger.info(f"Initialised DataCollector{rank}")
        rpc.shutdown()
    print("exiting", rank)


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    procs = []
    world_size = 5
    os.environ["WORLD_SIZE"] = str(world_size)
    for i in range(world_size):
        procs.append(ctx.Process(target=main, args=(i, world_size)))
        procs[-1].start()
    for p in reversed(procs):
        p.join()
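With world_size = 5 as above, the launcher spawns one Trainer (rank 0), three DataCollectors (ranks 1-3) and one ReplayBuffer (rank 4) on localhost. The gloo group used for weight broadcasting has world_size - 1 members because the replay buffer never joins it, and it is initialised on a second port (MASTER_PORT_NCCL) so that it does not collide with the RPC init_method on MASTER_PORT_RPC.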