
@anj1
Created January 6, 2025 03:19
import torch
import torch.distributed as dist


def main(rank, world_size):
    # Initialize the distributed environment
    dist.init_process_group(backend='gloo', init_method='env://', rank=rank, world_size=world_size)

    # Determine neighbor ranks in the ring
    prev_rank = (rank - 1) % world_size
    next_rank = (rank + 1) % world_size

    # Create some data to send
    data_to_send = torch.tensor([rank], dtype=torch.float32)

    # Initialize a tensor for the received data
    received_from_prev = torch.empty_like(data_to_send)

    # Non-blocking send and recv
    send_req = dist.isend(tensor=data_to_send, dst=next_rank)
    recv_req = dist.irecv(tensor=received_from_prev, src=prev_rank)

    # Wait for both operations to complete
    send_req.wait()
    recv_req.wait()

    # Now process the received data
    print(f"Rank {rank}: Received {received_from_prev} from rank {prev_rank}")

    # Cleanup
    dist.destroy_process_group()


if __name__ == "__main__":
    import os

    from torch.multiprocessing import spawn

    # Set environment variables for the env:// rendezvous
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'

    # Number of processes to spawn
    world_size = 4  # Can be modified based on needs

    # Spawn processes; spawn passes the process index to main as rank
    spawn(main, args=(world_size,), nprocs=world_size)
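
Running the file directly spawns four processes on localhost; each rank sends its own rank value to the next rank in the ring and should print the value it received from its predecessor (rank 2 receives 1.0 from rank 1, for example).

As a minimal sketch of an alternative, the same exchange can be posted as a batch of point-to-point operations with dist.batch_isend_irecv, which issues a group of sends and receives together and returns their requests. This assumes a recent PyTorch where batch_isend_irecv supports the gloo backend, and it reuses data_to_send, received_from_prev, prev_rank and next_rank from main above, as a drop-in replacement for the isend/irecv/wait lines:

    # Post both point-to-point operations together, then wait on the returned requests
    send_op = dist.P2POp(dist.isend, data_to_send, next_rank)
    recv_op = dist.P2POp(dist.irecv, received_from_prev, prev_rank)
    reqs = dist.batch_isend_irecv([send_op, recv_op])
    for req in reqs:
        req.wait()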