Working out how to `mpirun` Dask with CUDA
from dask_mpi import initialize
from dask import distributed


def dask_info(client):
    distributed.print("woah i'm running!")
    distributed.print("ncores:", client.ncores())
    distributed.print()
    distributed.print(client.scheduler_info())


def square(x):
    return x ** 2


def neg(x):
    return -x


if __name__ == "__main__":
    # MPI rank 0 runs the Dask scheduler and ranks 2 and above run the
    # dask_cuda.CUDAWorker workers; those ranks do not progress beyond
    # this initialization call.
    initialize(
        worker_class="dask_cuda.CUDAWorker",
        worker_options={
            "enable_tcp_over_ucx": False,
            "enable_infiniband": False,
            "enable_nvlink": False,
        },
    )

    # Only MPI rank 1 continues executing the script once the scheduler
    # has started.
    from dask.distributed import Client

    client = Client()  # The scheduler address is found automatically via MPI
    client.wait_for_workers(2)

    dask_info(client)

    A = client.map(square, range(10))
    B = client.map(neg, A)
    total = client.submit(sum, B)
    distributed.print("total:", total.result())