Created October 24, 2025 16:56
import multiprocessing.managers
import multiprocessing.pool
import queue
from contextlib import contextmanager

import torch

_was_init = False


class QueueWithMPStubs(queue.Queue):
    """Same as queue.Queue but with stubs that make it look like multiprocessing.Queue.

    multiprocessing.Queue cannot be passed to Pool.apply_async.
    Instead we use Manager.Queue, but that has an interface like queue.Queue.
    So we stub some functions that are missing. Since we manage the process pool
    ourselves, I don't think they should matter.
    """

    def cancel_join_thread(self):
        pass

    def close(self):
        pass
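# Illustration (not part of the original gist): once QueueWithMPStubs has been
# registered on SyncManager (reusable_pool_context below does this lazily, before
# the manager is created), a proxy is obtained from the manager and the stubbed
# methods are harmless no-ops:
#
#     manager = torch.multiprocessing.Manager()
#     q = manager.QueueWithMPStubs()
#     q.put(1)                # behaves like a regular manager queue
#     q.close()               # no-op stub
#     q.cancel_join_thread()  # no-op stub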
def patch_attr(obj, patch: dict) -> dict:
    """Monkeypatch an object's attributes and return the original values so the
    patch can be reversed.

    :param obj: the object to be monkeypatched
    :param patch: new attribute values map
    :return: old attribute values map
    """
    prev_values = {}
    for k, v in patch.items():
        prev_values[k] = getattr(obj, k)
        setattr(obj, k, v)
    return prev_values
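# Example (sketch, not part of the original gist): patching and then restoring
# attributes on an arbitrary object.
#
#     import types
#     cfg = types.SimpleNamespace(debug=False)
#     old = patch_attr(cfg, {"debug": True})   # cfg.debug is now True
#     ...                                      # do work with the patched object
#     patch_attr(cfg, old)                     # restore the original value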
@contextmanager
def reusable_pool_context(pool: multiprocessing.pool.Pool):
    """Obtain a multiprocessing_context for a DataLoader that reuses the existing pool processes.

    Use as a context manager (with ``with``). When providing this context to the DataLoader,
    you must set num_workers to at most the number of processes in the pool.

    >>> with torch.multiprocessing.Pool() as pool, \
    ...         reusable_pool_context(pool) as multiprocessing_context:
    ...     # ...
    ...     dataloader = DataLoader(
    ...         # ...
    ...         multiprocessing_context=multiprocessing_context,
    ...         num_workers=pool._processes,  # or fewer
    ...     )

    The multiprocessing_context is the default context, monkeypatched to use the given pool's
    processes rather than letting the DataLoader create its own. This saves some time in the
    setup/teardown of DataLoaders.

    NOTE: this implementation is fragile and relies on the DataLoader's inner workings behaving
    in a very specific manner. It was tested with PyTorch 2.4.1. As long as you don't iterate
    over multiple DataLoaders at once, you should be okay.

    :param pool: a multiprocessing.Pool instance
    """
    global _was_init
    if not pool._pool:
        raise RuntimeError("Pool is not running")
    if not _was_init:
        multiprocessing.managers.SyncManager.register("QueueWithMPStubs", QueueWithMPStubs)
        _was_init = True
    proc_idx = 0
    multiprocessing_context = torch.multiprocessing.get_context()
    with torch.multiprocessing.Manager() as manager:
        # The DataLoader iterator calls this upon construction, so we use it to reset
        # the process-slot assignment.
        def EventWrapper():
            nonlocal proc_idx
            proc_idx = 0
            return manager.Event()
        class ProcessWrapper:
            """Minimal facade that looks like multiprocessing.Process but routes
            .start() to pool.apply_async(target, args) on a fixed worker slot.
            """

            def __init__(self, target, args):
                nonlocal proc_idx, processes
                if proc_idx >= len(pool._pool):
                    raise RuntimeError(
                        f"Trying to start more processes ({proc_idx=}) than are in the pool "
                        f"({len(pool._pool)}). This can happen if you specify an incorrect "
                        "value for num_workers in your DataLoader or if you iterate over "
                        "more than one DataLoader at the same time."
                    )
                self._pool = pool
                self._idx = proc_idx
                self._target = target
                self._args = args
                self._async_results = []
                processes[proc_idx] = self
                proc_idx += 1

            def start(self):
                async_result = self._pool.apply_async(self._target, self._args)
                self._async_results.append(async_result)

            @property
            def name(self):
                return self._pool._pool[self._idx].name

            @property
            def pid(self):
                return self._pool._pool[self._idx].pid

            def is_alive(self):
                # The pool itself must still be running...
                if self._pool._state != multiprocessing.pool.RUN:
                    return False
                # ...and every submitted call must be running, complete, or have errored.
                for async_result in self._async_results:
                    if async_result.ready():
                        # If the worker function errored, this re-raises the exception here.
                        async_result.get()
                # ...and the underlying worker process must be alive.
                p = self._pool._pool[self._idx]
                return p.is_alive()

            def join(self, timeout=None):
                # Treat "no timeout" as a long but finite wait (10,000 s).
                if timeout is None or timeout == 0:
                    timeout = 10_000
                for async_result in self._async_results:
                    async_result.wait(timeout)

            def terminate(self):
                # There is no worker loop of our own to kill; just forget the outstanding calls.
                self._async_results.clear()

            @property
            def daemon(self):
                return True

            @daemon.setter
            def daemon(self, _value: bool):
                pass
        processes: list[ProcessWrapper | None] = [None] * len(pool._pool)
        old_values = patch_attr(
            multiprocessing_context,
            {"Process": ProcessWrapper, "Queue": manager.QueueWithMPStubs, "Event": EventWrapper},
        )
        try:
            yield multiprocessing_context
        finally:
            # Undo the monkeypatch so the default context behaves normally again.
            patch_attr(multiprocessing_context, old_values)
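# --- Usage sketch (not part of the original gist) ---
# A minimal, hedged end-to-end example mirroring the docstring above. RangeDataset is a
# made-up toy dataset; any picklable Dataset works the same way. It is defined at module
# level so it can also be pickled under the spawn start method.
from torch.utils.data import DataLoader, Dataset


class RangeDataset(Dataset):
    """Toy map-style dataset that returns each index as a tensor."""

    def __init__(self, n: int):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return torch.tensor(idx)


if __name__ == "__main__":
    with torch.multiprocessing.Pool(2) as pool, \
            reusable_pool_context(pool) as multiprocessing_context:
        dataloader = DataLoader(
            RangeDataset(16),
            batch_size=4,
            num_workers=pool._processes,
            multiprocessing_context=multiprocessing_context,
        )
        # Each pass re-creates the DataLoader workers, but they run inside the same
        # two pool processes instead of forking/spawning fresh ones.
        for epoch in range(2):
            for batch in dataloader:
                print(epoch, batch.tolist())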