concurrency.py
import multiprocessing as mp

from tqdm import tqdm

from .device import *

# Use the 'spawn' start method so workers start from a clean interpreter
# (required for CUDA); ignore the error if a start method was already set.
try:
    mp.set_start_method('spawn')
except RuntimeError:
    pass

N_CPU = cpu_count()


def parallel_map(task, items, process_count: int = N_CPU, show_progress=False, debug=False):
    """Runs a map across cores.

    Args:
        task (function): The function to apply to each item.
        items (list): Items to map over.
        process_count (int): Number of worker processes to spawn.
        show_progress (bool): Show a tqdm progress bar.
        debug (bool): Run sequentially in the current process instead of a pool.

    Returns:
        List of the return values of task, in input order.
    """
    process_count = max(1, min(process_count, len(items)))
    if debug:
        # Sequential fallback; collects results so callers get the same
        # return value as the parallel path.
        return [task(item) for item in tqdm(items)]
    with mp.Pool(process_count) as p:
        if show_progress:
            return list(tqdm(p.imap(task, items), total=len(items)))
        return p.map(task, items)
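As a quick illustration of the two flags, here is a minimal sketch; the `square` task and the input range are made up for this example and are not part of the gist:

from luma_utils.concurrency import parallel_map

def square(x):
    # Hypothetical task; must be a module-level function so it can be
    # pickled under the 'spawn' start method.
    return x * x

if __name__ == '__main__':
    # Parallel map with a tqdm progress bar.
    results = parallel_map(square, list(range(100)), show_progress=True)
    # Sequential in-process run, handy for stepping through with a debugger.
    results_debug = parallel_map(square, list(range(100)), debug=True)
    assert results == results_debug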
device.py
import multiprocessing as mp
import torch

def cpu_count():
    """Number of logical CPU cores on this machine."""
    return mp.cpu_count()

def gpu_count():
    """Number of CUDA devices visible to PyTorch."""
    return torch.cuda.device_count()
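One plausible use of these helpers is spreading shards across whatever GPUs are present. The `shard_device` helper below is a hypothetical sketch, not part of the gist, and assumes the module is importable as `luma_utils.device`:

import torch

from luma_utils.device import gpu_count

def shard_device(shard_index):
    # Hypothetical helper: assign shards to CUDA devices round-robin,
    # falling back to the CPU when no GPU is visible.
    n = gpu_count()
    return torch.device(f'cuda:{shard_index % n}') if n > 0 else torch.device('cpu')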
Usage example
import numpy as np

from luma_utils.concurrency import *

def task(numbers):
    return np.sum(numbers)

def run():
    num_shards = [np.arange(i, i + 10) for i in range(1000)]
    result_parallel = parallel_map(task, num_shards)
    result_seq = [task(n) for n in num_shards]
    assert result_seq == result_parallel

if __name__ == '__main__':
    run()
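Note that because the pool uses the 'spawn' start method, `task` must be a picklable module-level function, and the `if __name__ == '__main__':` guard is required so that freshly spawned workers do not re-execute `run()` when they import the script.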