Created
February 20, 2024 22:09
-
-
Save adrn/d983bc02dcf2b57ef28163213d38cc5c to your computer and use it in GitHub Desktop.
MPI pool example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pathlib | |
import h5py | |
import numpy as np | |
def worker(task):
    """Compute the output for a single task.

    Parameters
    ----------
    task : tuple
        ``(i, x, cache_file)``: the index of the task, its input value, and
        the cache file path. The index and path are not used here; they are
        passed through unchanged so the callback knows where to store the
        result.

    Returns
    -------
    tuple
        ``(i, x**2 + 2, cache_file)``.
    """
    # Do something with the task and return output data!
    i, x, cache_file = task
    return i, x**2 + 2, cache_file
def callback(result):
    """Write one worker result into the cache file.

    Parameters
    ----------
    result : tuple
        ``(i, value, cache_file)`` as returned by ``worker``: the index of
        the finished task, its computed value, and the path of the HDF5 cache
        file to update.

    Notes
    -----
    NOTE(review): presumably the pool invokes this on the master process
    only, so a single process writes to the file -- confirm against the
    pool implementation in use.
    """
    i, value, cache_file = result

    # "r+" opens the existing file for reading and writing without truncating
    # it, so results stored by earlier calls are preserved.
    with h5py.File(cache_file, "r+") as f:
        f["data"][i] = value
def main(pool, overwrite):
    """Run every not-yet-completed task through ``pool``, caching results.

    Parameters
    ----------
    pool
        Pool object whose ``map(func, tasks, callback=...)`` method is used
        to distribute the tasks.
    overwrite : bool
        If true, re-create the cache file even when it already exists, which
        discards any previously stored results.
    """
    cache_file = pathlib.Path("test-cache-file.hdf5")

    # One thing you need to know off the bat is how many total tasks there
    # will be:
    N_tasks_total = 10_000

    # Set up the cache file to make sure it exists. You'll have to modify this
    # so it creates the data structure you will be using:
    if not cache_file.exists() or overwrite:
        # Note: this message is also printed when an existing file is being
        # overwritten.
        print("Cache file doesn't exist")
        with h5py.File(cache_file, "w") as f:
            # Fill the cached data with nan values:
            f["data"] = np.full(N_tasks_total, np.nan)

    with h5py.File(cache_file, "r") as f:
        # Read the dataset into memory explicitly rather than relying on
        # implicit Dataset -> array conversion inside numpy.
        cached = f["data"][:]

    # This is an array of indices of tasks that have not yet been completed
    # (their cache entry is still non-finite):
    idx = np.where(np.logical_not(np.isfinite(cached)))[0]

    # Make some fake input data from random numbers:
    input_data = np.random.uniform(0, 100, size=N_tasks_total)
    index_array = np.arange(N_tasks_total)

    # Each task carries its own index and the cache path so that the callback
    # can store the result in the right slot.
    tasks = [
        (i, x, cache_file)
        for i, x in zip(index_array[idx], input_data[idx])
    ]

    # The return values are not needed here: results are persisted by the
    # callback as they arrive.
    for _ in pool.map(worker, tasks, callback=callback):
        pass

    with h5py.File(cache_file, "r") as f:
        print(f["data"][:])
if __name__ == "__main__":
    import sys
    from argparse import ArgumentParser

    from schwimmbad.mpi import MPIPool

    # Define parser object
    parser = ArgumentParser()
    # --overwrite is a flag: absent -> False, present -> True.
    parser.add_argument("--overwrite", default=False, action="store_true")
    args = parser.parse_args()

    # The context manager takes care of releasing the pool once main() has
    # returned or raised.
    with MPIPool() as pool:
        main(pool=pool, overwrite=args.overwrite)

    sys.exit(0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment