@ahmadia
Created February 20, 2014 03:24
""" Group Tutorial
Tutorial code for assigning tasks to subcommunicators and relaying
results back.
In this tutorial, no intercommunicator is used. Instead,
intracommunicators are created with coordination done by rank 0. It
is assumed that the second process will be responsible for Task A, and
that all remaining processes coordinate on Task B.
Task A is a fake task that simply adds 1 to any input array.
Task B is a fake task that takes the collective sum of all input
arrays in its communicator (using reduce), then returns this sum as a
single float.
"""
from __future__ import print_function
from mpi4py import MPI
import numpy as np
# switch to unbuffered output (Don't do this for lots of prints!)
import sys
import os
sys.stdout = os.fdopen(sys.stdout.fileno(), 'w', 0)
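# Note: on Python 3, reopening stdout unbuffered in text mode raises a
# ValueError; one alternative there is print(..., flush=True) wherever
# immediate output is needed.
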
world = MPI.COMM_WORLD
num_procs = world.size
my_rank = world.rank
if num_procs < 3:
    raise Exception("This demonstration only works with at least 3 "
                    "processors (mpirun -n 3 python group_demo.py)!")


def sprint(s):
    """Synchronized printing, from process 0 only.

    Also uses whitespace and '=' to make the output more visually
    pleasing.
    """
    world.Barrier()
    if my_rank == 0:
        print('\n' + '='*len(s) + '\n' + s + '\n' + '='*len(s) + '\n')
    world.Barrier()


def rprint(s):
    """Prepend the process rank to a printed message."""
    print("[{:d}] ".format(my_rank) + s)


def taskA(in_array, out_array):
    """Fake Task A: add 1 to the input array."""
    rprint("task A called with input " + str(in_array))
    # note the use of slice assignment here to write the result into
    # out_array in place rather than creating a new array
    out_array[:] = in_array + 1
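# e.g. plain assignment (out_array = in_array + 1) would only rebind
# the local name to a fresh array and the caller would never see the
# result; the slice assignment above mutates the caller's array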


def taskB(my_comm, in_array):
    """Fake Task B: sum the input arrays across my_comm."""
    rprint("task B called with input " + str(in_array))
    # wrap the local partial sum in a 1-element array so it can be used
    # directly as a send buffer
    in_buf = np.array([np.sum(in_array)])
    out_buf = np.zeros(1)
    # Reduce(self, sendbuf, recvbuf, Op op=SUM, int root=0)
    if my_comm.rank == 0:
        my_comm.Reduce(in_buf, out_buf, op=MPI.SUM, root=0)
    else:
        my_comm.Reduce(in_buf, None, op=MPI.SUM, root=0)
    return out_buf
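
# For comparison only: the lowercase, pickle-based collective would
# reduce Python objects directly, e.g. (a sketch, not used here):
#   total = my_comm.reduce(float(np.sum(in_array)), op=MPI.SUM, root=0)
# It returns the sum on the root and None on the other ranks.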


def split_teams():
    """Create three non-overlapping subcommunicators from world.

    Note that world is still usable after we perform the split.

    Comm 0: contains process 0
    Comm 1: contains process 1
    Comm 2: contains all remaining processes

    The process that calls this function will return the appropriate
    subcommunicator based on its rank.
    """
    # Comm.Split(self, int color=0, int key=0)
    # Processes with the same color end up in the same new
    # communicator; key determines rank in the new communicator, but
    # isn't used in this exercise.
    if my_rank == 0:
        my_color = 0
    elif my_rank == 1:
        my_color = 1
    else:
        my_color = 2
    return world.Split(color=my_color), my_color


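# Example: with "mpirun -n 5", world ranks 0, 1, 2, 3, 4 get colors
# 0, 1, 2, 2, 2, so the color-2 communicator contains world ranks
# 2, 3, 4 with local ranks 0, 1, 2 (equal keys preserve world order).
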
sprint("Startup Phase")
rprint("MPI.COMM_WORLD contains {:d} processes".format(num_procs))
sprint("Splitting Teams")
my_comm, my_color = split_teams()
rprint("Assigned to {:d}".format(my_color))
sprint("Local subcommunicator information")
rprint("My subcommunicator with color {:d} contains {:d} \
processes".format(my_color, my_comm.size))
sprint("Send A input from process 0 to process 1")
if my_rank == 0:
    a_data = np.ones(5)
    world.send(a_data, dest=1)
    # note that rank 0 may post this receive before 1 sends the
    # results back
    a_result = world.recv(source=1)
elif my_rank == 1:
    a_data = world.recv(source=0)
    # preallocate the result array
    a_result = np.empty(5)
    # Task A is carried out by process 1
    taskA(a_data, a_result)
    # send the results back
    world.send(a_result, dest=0)
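
# The lowercase send/recv calls used above pickle their arguments,
# which is convenient for small messages. For large arrays, the
# buffer-based calls would look roughly like this (a sketch, not used
# here):
#   world.Send([a_result, MPI.DOUBLE], dest=0)    # on rank 1
#   world.Recv([a_result, MPI.DOUBLE], source=1)  # on rank 0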
sprint("Broadcast A results from 0 to all processes")
# This is usually more efficient than a sub-communicator broadcast
# Since this is a global operation on the communicator, EVERY process
# in world must invoke the bcast function call or this will hang
# indefinitely.
if my_rank == 0:
    rprint("Broadcasting " + str(a_result))
    b_data = world.bcast(a_result, root=0)
else:
    # we only need to define the variable to be broadcast on the root
    # process
    b_data = world.bcast(None, root=0)
    rprint("Received " + str(b_data))
sprint("Task B is operated on by remaining processes")
if my_color == 2:
    b_result = taskB(my_comm, b_data)
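
# Processes with colors 0 and 1 skip this block entirely; the Reduce
# inside taskB is collective only over my_comm, so they are not
# required to participate.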
sprint("Get results back from root on Comm B")
# The root of the color-2 communicator is guaranteed to be world rank 2
# (the lowest world rank with that color, since equal keys preserve
# world ordering)
if my_rank == 0:
    b_result = world.recv(source=2)
    rprint("Received " + str(b_result))
elif my_rank == 2:
    rprint("Sending " + str(b_result))
    world.send(b_result, dest=0)