Skip to content

Instantly share code, notes, and snippets.

@definitelyuncertain
Created January 23, 2019 07:10
Show Gist options
  • Save definitelyuncertain/3e340ea8fa4907e3ce3a0f0b42a454c8 to your computer and use it in GitHub Desktop.
Save definitelyuncertain/3e340ea8fa4907e3ce3a0f0b42a454c8 to your computer and use it in GitHub Desktop.
Using Ray with MPI: a simple counter actor whose value is incremented in succession by each MPI process (e.g. 4 processes)
# Run as: mpirun -np <number of MPI processes> python test-ray.py
from mpi4py import MPI
import os, sys, time
import numpy as np
import ray
@ray.remote
class Counter(object):
    """Ray actor wrapping a single integer counter that starts at zero."""

    def __init__(self):
        self.val = 0

    def increment(self):
        """Add one to the stored value. Returns None."""
        self.val = self.val + 1

    def get_val(self):
        """Return the counter's current value."""
        current = self.val
        return current
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

# Rank 0 starts Ray; the address returned by ray.init() is broadcast so the
# other ranks can connect to the same Ray instance.
# NOTE(review): ray.init()['redis_address'] and the redis_address= keyword are
# the legacy (pre-1.0) Ray API; newer Ray releases use address= — confirm
# against the installed Ray version.
redis_add, root_redis_add = None, None
if rank == 0:
    root_redis_add = ray.init()['redis_address']
comm.Barrier()
redis_add = comm.bcast(root_redis_add, root=0)
if rank != 0:
    print(ray.init(redis_address=redis_add))
comm.Barrier()

# Only rank 0 creates the Counter actor; its handle is then broadcast so every
# rank talks to the same actor instance.
rc_id = None
if rank == 0:
    rc_id = Counter.remote()
    print('R0 ID', rc_id)
print('R%d Barrier' % rank)
comm.Barrier()
rc_id_remote = comm.bcast(rc_id, root=0)
print('R%d ID' % rank, rc_id_remote)
rc = rc_id_remote

# Take turns in rank order. Rank r touches the counter only on turn r, and the
# barrier after every turn ensures the previous rank's increment has finished
# (ray.get blocks until the remote call completes). A rank-proportional sleep
# would only make this ordering likely; the barrier makes it guaranteed.
for turn in range(size):
    if turn == rank:
        print('R%d val before' % rank, ray.get(rc.get_val.remote()))
        ray.get(rc.increment.remote())
        print('R%d val after' % rank, ray.get(rc.get_val.remote()))
    comm.Barrier()
print('R%d Barrier2' % rank)
comm.Barrier()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment