Created
January 23, 2019 07:10
-
-
Save definitelyuncertain/3e340ea8fa4907e3ce3a0f0b42a454c8 to your computer and use it in GitHub Desktop.
Using Ray with MPI: a simple counter object whose value is incremented in succession by each of 4 MPI processes.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Run as: mpirun -np <number of MPI processes> python test-ray.py
import os
import sys
import time

from mpi4py import MPI
import numpy as np
import ray
@ray.remote
class Counter(object):
    """Ray actor wrapping one integer counter.

    The counter starts at 0. ``increment`` bumps it by one and returns
    nothing; ``get_val`` reports the current value.
    """

    def __init__(self):
        # Initial counter value.
        self.val = 0

    def increment(self):
        """Add one to the counter (no return value)."""
        self.val = self.val + 1

    def get_val(self):
        """Return the counter's current value."""
        current = self.val
        return current
# Every process started by mpirun shares this communicator.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

# Rank 0 starts Ray and publishes its Redis address; every other rank then
# attaches to that same Ray instance via the broadcast address.
# NOTE(review): `ray.init()['redis_address']` / `redis_address=` match the
# Ray version this was written against; newer Ray releases renamed this
# keyword -- confirm against the installed Ray version.
redis_add, root_redis_add = None, None
if rank == 0:
    root_redis_add = ray.init()['redis_address']
comm.Barrier()
redis_add = comm.bcast(root_redis_add, root=0)
if rank != 0:
    print(ray.init(redis_address=redis_add))
comm.Barrier()

# Only rank 0 creates the actor; its handle is then broadcast (pickled by
# mpi4py) so that all ranks talk to the same Counter instance.
rc_id = None
if rank == 0:
    rc_id = Counter.remote()
    print('R0 ID', rc_id)
print('R%d Barrier' % (rank))
comm.Barrier()
rc_id_remote = comm.bcast(rc_id, root=0)
print('R%d ID' % (rank), rc_id_remote)
rc = rc_id_remote

# Take turns deterministically instead of sleeping rank-dependent amounts:
# in round `turn`, only that rank touches the counter. ray.get() blocks until
# each actor call has finished, and the barrier closing the round keeps the
# next rank from starting early, so increments happen strictly in rank order.
for turn in range(size):
    if turn == rank:
        print('R%d val before' % (rank), ray.get(rc.get_val.remote()))
        ray.get(rc.increment.remote())
        print('R%d val after' % (rank), ray.get(rc.get_val.remote()))
        # stdout is a pipe under mpirun (block-buffered); flush so this
        # rank's lines are emitted during its turn rather than at exit.
        sys.stdout.flush()
    comm.Barrier()

print('R%d Barrier2' % (rank))
comm.Barrier()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment