import multiprocessing.managers
import multiprocessing.pool
import queue
from contextlib import contextmanager

import torch

_was_init = False


class QueueWithMPStubs(queue.Queue):
    """Same as queue.Queue but with stubs that make it look like multiprocessing.Queue.

    multiprocessing.Queue cannot be passed to Pool.apply_async.
    Instead we use Manager.Queue, but that has an interface like queue.Queue.
    So we stub some functions that are missing. Since we manage the process pool
    ourselves, I don't think they should matter.
    """

    def cancel_join_thread(self):
        pass

    def close(self):
        pass
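

# Illustrative only (a sketch of what the registration done later in reusable_pool_context
# enables; not part of the original code path): once QueueWithMPStubs is registered on
# SyncManager, the manager hands out proxy-backed queues that also accept the
# multiprocessing.Queue-style calls DataLoader makes:
#
#     multiprocessing.managers.SyncManager.register("QueueWithMPStubs", QueueWithMPStubs)
#     with torch.multiprocessing.Manager() as manager:
#         q = manager.QueueWithMPStubs()
#         q.put(1)
#         q.cancel_join_thread()  # no-op stub; plain queue.Queue has no such method
#         q.close()               # no-op stub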


def patch_attr(obj, patch: dict) -> dict:
    """Monkeypatch an object's attributes and return the original values so the
    patch can be reversed.

    :param obj: the object to be monkeypatched
    :param patch: map of new attribute values
    :return: map of old attribute values
    """
    prev_values = {}
    for k, v in patch.items():
        prev_values[k] = getattr(obj, k)
        setattr(obj, k, v)
    return prev_values
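

# Illustrative usage sketch for patch_attr (the _Dummy class is made up for this example):
#
#     class _Dummy:
#         retries = 3
#
#     old = patch_attr(_Dummy, {"retries": 5})
#     assert _Dummy.retries == 5
#     patch_attr(_Dummy, old)  # re-applying the returned map undoes the patch
#     assert _Dummy.retries == 3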


@contextmanager
def reusable_pool_context(pool: multiprocessing.pool.Pool):
    """Obtain a multiprocessing_context for a DataLoader that reuses the existing pool's processes.

    Use as a context manager (with `with`). When providing this context to the DataLoader, you
    must specify num_workers to be at most the number of processes in the pool.

    >>> with torch.multiprocessing.Pool() as pool, \
    ...         reusable_pool_context(pool) as multiprocessing_context:
    ...     # ...
    ...     dataloader = DataLoader(
    ...         # ...
    ...         multiprocessing_context=multiprocessing_context,
    ...         num_workers=pool._processes,  # or fewer
    ...     )

    The multiprocessing_context is the default context, monkeypatched to use the given pool's
    processes rather than letting the DataLoader create its own.
    This saves some time in the setup/teardown of DataLoaders.

    NOTE: this implementation is fragile and relies on DataLoader's inner workings behaving in a
    very specific manner. It was tested with PyTorch 2.4.1. As long as you don't iterate
    over multiple DataLoaders at once, you should be okay.

    :param pool: a multiprocessing.Pool instance
    """
    global _was_init
    if not pool._pool:
        raise RuntimeError("Pool is not running")
    if not _was_init:
        multiprocessing.managers.SyncManager.register("QueueWithMPStubs", QueueWithMPStubs)
        _was_init = True

    proc_idx = 0
    multiprocessing_context = torch.multiprocessing.get_context()
    with torch.multiprocessing.Manager() as manager:
        # DataLoader's iterator calls this upon construction, so we use it to reset the process assignment
        def EventWrapper():
            nonlocal proc_idx
            proc_idx = 0
            return manager.Event()
        class ProcessWrapper:
            """
            Minimal facade that looks like multiprocessing.Process but routes
            .start() to pool.apply_async(target, args) on a fixed worker slot.
            """

            def __init__(self, target, args):
                nonlocal proc_idx, processes
                if proc_idx >= len(pool._pool):
                    raise RuntimeError(
                        f"Trying to start more processes {proc_idx=} than in pool {len(pool._pool)}. "
                        "This can happen if specifying an incorrect value for num_workers in your "
                        "DataLoader or if iterating over more than one DataLoader at the same time."
                    )
                self._pool = pool
                self._idx = proc_idx
                self._target = target
                self._args = args
                self._async_results = []
                processes[proc_idx] = self
                proc_idx += 1

            def start(self):
                async_result = self._pool.apply_async(self._target, self._args)
                self._async_results.append(async_result)

            @property
            def name(self):
                return self._pool._pool[self._idx].name

            @property
            def pid(self):
                return self._pool._pool[self._idx].pid

            def is_alive(self):
                # Pool itself is alive and the underlying worker process exists.
                if self._pool._state != multiprocessing.pool.RUN:
                    return False
                # Check that all submitted calls are either still running or have completed.
                for async_result in self._async_results:
                    if async_result.ready():
                        # If the worker function raised, .get() re-raises the exception so it surfaces here.
                        async_result.get()
                p = self._pool._pool[self._idx]
                return p.is_alive()

            def join(self, timeout=None):
                if timeout is None or timeout == 0:
                    timeout = 10_000
                for async_result in self._async_results:
                    async_result.wait(timeout)

            def terminate(self):
                # The pool owns the worker processes, so there is no worker loop to stop;
                # just forget any pending results.
                self._async_results.clear()

            @property
            def daemon(self):
                return True

            @daemon.setter
            def daemon(self, _value: bool):
                pass
        processes: list[ProcessWrapper | None] = [None] * len(pool._pool)
        old_values = patch_attr(
            multiprocessing_context,
            {"Process": ProcessWrapper, "Queue": manager.QueueWithMPStubs, "Event": EventWrapper},
        )
        try:
            yield multiprocessing_context
        finally:
            patch_attr(multiprocessing_context, old_values)
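

# A minimal end-to-end sketch of how this context could be used (not part of the original gist;
# the toy _RangeDataset, pool size, and batch size below are illustrative assumptions).
if __name__ == "__main__":
    from torch.utils.data import DataLoader, Dataset

    class _RangeDataset(Dataset):
        """Toy dataset that just returns its index as a tensor."""

        def __len__(self):
            return 64

        def __getitem__(self, idx):
            return torch.tensor(idx)

    with torch.multiprocessing.Pool(4) as pool, \
            reusable_pool_context(pool) as multiprocessing_context:
        # Both DataLoaders below reuse the same four pool workers instead of
        # spawning (and tearing down) a fresh set of worker processes each time.
        for _ in range(2):
            dataloader = DataLoader(
                _RangeDataset(),
                batch_size=8,
                num_workers=pool._processes,
                multiprocessing_context=multiprocessing_context,
            )
            for _batch in dataloader:
                pass  # real code would do work with the batch here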