Last active
May 27, 2025 11:57
-
-
Save tsvikas/5f859a484e53d4ef93400751d0a116de to your computer and use it in GitHub Desktop.
joblib.Parallel, but with a tqdm progressbar
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tqdm | |
from joblib import Parallel, delayed | |
class ParallelTqdm(Parallel):
    """joblib.Parallel, but with a tqdm progressbar.

    The progress bar is created lazily on the first dispatched batch, updated
    from joblib's ``print_progress`` hook, and closed when the call returns.
    The same instance may be called several times; every call gets a fresh
    progress bar.

    Additional parameters:
    ----------------------
    n_jobs: int, default: 0
        The maximum number of concurrently running jobs.
        Will pass `n_jobs if int(n_jobs)>0 else n_jobs-1` to joblib Parallel.
        i.e.: 1 means no parallel computing, n means n CPUs, 0 means all CPUs,
        and -n means (num_cpu-n) CPUs.
    total_tasks: int, default: None
        The number of expected jobs. Used in the tqdm progressbar.
        If None, try to infer from the length of the called iterator, and
        fallback to use the number of remaining items as soon as we finish
        dispatching.
        Note: use a list instead of an iterator if you want the total_tasks
        to be inferred from its length.
    show_joblib_header: bool, default: False
        If True, show joblib header before the progressbar.
    parallel_kwargs: dict, optional
        kwargs to pass to `joblib.Parallel`. The 'verbose' and 'n_jobs' keys
        are always overridden by `show_joblib_header` and `n_jobs`.
        The dict passed in is not modified.
    tqdm_kwargs: dict, optional
        kwargs to pass to `tqdm.tqdm`, like 'desc', 'ncols', 'disable'.
        The dict passed in is not modified.

    Raises:
    -------
    TypeError
        If `tqdm_kwargs` contains the key 'iterable'.
    ValueError
        If `tqdm_kwargs` contains a 'total' that differs from `total_tasks`.

    Usage:
    ------
    >>> from joblib_parallel_with_tqdm import ParallelTqdm, delayed
    >>> from time import sleep
    >>> ParallelTqdm(n_jobs=-1)([delayed(sleep)(i) for i in range(10)])
    80%|████████  | 8/10 [00:07<00:01,  1.08tasks/s]
    """

    def __init__(
        self,
        n_jobs: int = 0,
        *,
        total_tasks: int | None = None,
        show_joblib_header: bool = False,
        parallel_kwargs: dict | None = None,
        tqdm_kwargs: dict | None = None,
    ):
        # Work on copies: the caller's dicts must not be modified by the
        # 'verbose'/'n_jobs' overrides or by popping 'total' below.
        parallel_kwargs = {} if parallel_kwargs is None else dict(parallel_kwargs)
        tqdm_kwargs = {} if tqdm_kwargs is None else dict(tqdm_kwargs)

        # set the Parallel class
        parallel_kwargs["verbose"] = 1 if show_joblib_header else 0
        # Shift non-positive values by one so that 0 maps to joblib's -1.
        parallel_kwargs["n_jobs"] = n_jobs if int(n_jobs) > 0 else n_jobs - 1
        super().__init__(**parallel_kwargs)

        # prepare the tqdm kwargs
        if "iterable" in tqdm_kwargs:
            raise TypeError(
                "keyword argument 'iterable' is not supported in 'tqdm_kwargs'."
            )
        if "total" in tqdm_kwargs:
            total_from_tqdm = tqdm_kwargs.pop("total")
            if total_tasks is None:
                total_tasks = total_from_tqdm
            elif total_tasks != total_from_tqdm:
                raise ValueError(
                    "keyword argument 'total' for tqdm_kwargs is specified and different from 'total_tasks'"
                )
        # User-supplied tqdm kwargs take precedence over the default unit.
        self.tqdm_kwargs = dict(unit="tasks") | tqdm_kwargs
        self.total_tasks = total_tasks
        # True while `total_tasks` holds a value inferred during a call rather
        # than one supplied by the user; such a value is dropped on the next call.
        self._total_was_inferred = False
        self.progress_bar: tqdm.tqdm | None = None

    def __call__(self, iterable):
        # Reset per-call state so the instance can be reused: forget the
        # (already closed) bar of a previous call and any inferred total.
        self.progress_bar = None
        if self._total_was_inferred:
            self.total_tasks = None
            self._total_was_inferred = False
        try:
            if self.total_tasks is None:
                # try to infer total_tasks from the length of the called iterator
                try:
                    self.total_tasks = len(iterable)
                except (TypeError, AttributeError):
                    # No usable length (e.g. a generator); print_progress
                    # falls back to the dispatched count later on.
                    pass
                else:
                    self._total_was_inferred = True
            # call parent function
            return super().__call__(iterable)
        finally:
            # close tqdm progress bar
            if self.progress_bar is not None:
                self.progress_bar.close()

    __call__.__doc__ = Parallel.__call__.__doc__

    def dispatch_one_batch(self, iterator):
        # start progress_bar, if not started yet.
        if self.progress_bar is None:
            self.progress_bar = tqdm.tqdm(
                total=self.total_tasks,
                **self.tqdm_kwargs,
            )
        # call parent function
        return super().dispatch_one_batch(iterator)

    dispatch_one_batch.__doc__ = Parallel.dispatch_one_batch.__doc__

    def print_progress(self):
        """Display the process of the parallel execution using tqdm"""
        # Nothing to update if no batch has been dispatched yet.
        if self.progress_bar is None:
            return
        # if we finish dispatching, find total_tasks from the number of remaining items
        # (`_original_iterator is None` is taken to mean dispatching has finished)
        if self.total_tasks is None and self._original_iterator is None:
            self.progress_bar.total = self.total_tasks = self.n_dispatched_tasks
            self._total_was_inferred = True
            self.progress_bar.refresh()
        # update progressbar by the number of tasks completed since the last update
        self.progress_bar.update(self.n_completed_tasks - self.progress_bar.n)
@tsvikas This is so convenient! Any interest in making a small Python package with this and publishing it on PyPI?
Thanks!
I agree it's convenient, but from what I’ve seen, joblib doesn’t seem interested in supporting this feature. Keeping a separate package aligned with upstream changes could turn into a lot of overhead.
If this is something the community wants, maybe it’s worth revisiting the discussion in joblib/joblib#972? I’d definitely be happy to see native support there.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thanks for this, just referencing the issue: joblib/joblib#972