Last active: March 30, 2022 13:01
-
-
Save altescy/e0ca715baae602d292bfbc466ae379f6 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import shutil
import sys
import threading
import time
from collections import abc
from types import TracebackType
from typing import (Any, AsyncIterable, AsyncIterator, Callable, Dict,
                    Iterable, Iterator, List, Optional, TextIO, Type, TypeVar,
                    Union)
T = TypeVar("T")


class Progress:
    """Single-line text progress bar.

    Tracks a counter against an optional total and redraws one status line
    (title, percentage, bar, counts, elapsed/remaining time, rate) on every
    :meth:`update`. Iterating the instance (``for`` or ``async for``) wraps
    the given iterable and counts each item it yields.
    """

    def __init__(
        self,
        total_or_iterable: Optional[Union[int, Iterable[T], AsyncIterable[T]]] = None,
        *,
        title: Optional[str] = None,
        unit: str = "it",
        output: Optional[TextIO] = None,
        callback: Optional[Callable[["Progress"], None]] = None,
        partchars: str = " ▏▎▍▌▋▊▉",
        maxwidth: Optional[int] = None,
    ) -> None:
        """
        Args:
            total_or_iterable: An ``int`` total (iteration then yields
                ``range(total)``), a sized or unsized iterable, an async
                iterable, or ``None`` for manual :meth:`update` calls with no
                known total.
            title: Optional label drawn before the bar.
            unit: Unit name used in the counter and rate display.
            output: Stream to draw on; defaults to ``sys.stderr``.
            callback: If given, called with this instance on every
                :meth:`update` instead of :meth:`show`.
            partchars: Bar cell characters from emptiest to fullest; the last
                one fills whole cells, the others render the fractional cell.
            maxwidth: Upper bound on the rendered line width.

        Raises:
            ValueError: If ``total_or_iterable`` is none of the above.
        """
        self.title = title
        self.unit = unit
        self.output = output or sys.stderr
        self.callback = callback
        self.partchars = partchars
        self.maxwidth = maxwidth
        self.start_time = time.time()
        self.elapsed_time = 0.0
        self.value = 0
        self.total: Optional[int] = None
        self.iterable: Optional[
            Union[Iterable[T], Iterable[int], AsyncIterable[T], AsyncIterable[int]]
        ] = None
        if isinstance(total_or_iterable, int):
            self.total = total_or_iterable
            self.iterable = range(self.total)
        elif isinstance(total_or_iterable, abc.Sized):
            self.total = len(total_or_iterable)  # type: ignore[unreachable]
            self.iterable = total_or_iterable
        elif isinstance(total_or_iterable, (abc.Iterable, abc.AsyncIterable)):
            self.iterable = total_or_iterable
        elif total_or_iterable is not None:
            raise ValueError("The iterable argument must be iterable.")
        # Column widths; ProgressGroup overwrites these to align several bars.
        self.title_width: Optional[int] = None
        self.value_width = len(str(self.total)) if self.total is not None else None

    def __iter__(self) -> Union[Iterator[T], Iterator[int]]:
        """Yield the wrapped iterable's items, calling :meth:`update` per item.

        Raises:
            RuntimeError: On first iteration, if there is no iterable or it is
                an ``AsyncIterable``.
        """
        self.start_time = time.time()
        if self.iterable is None:
            raise RuntimeError(f"{self} is not iterable.")
        if isinstance(self.iterable, abc.AsyncIterable):
            raise RuntimeError(
                "__iter__ is unavailable because the given iterable is AsyncIterable."
            )
        for x in self.iterable:
            yield x
            self.update()

    async def __aiter__(self) -> Union[AsyncIterator[T], AsyncIterator[int]]:
        """Async counterpart of :meth:`__iter__` for async iterables.

        Raises:
            RuntimeError: On first iteration, if there is no iterable or it is
                a synchronous ``Iterable``.
        """
        self.start_time = time.time()
        if self.iterable is None:
            raise RuntimeError(f"{self} is not iterable.")
        if isinstance(self.iterable, abc.Iterable):
            raise RuntimeError(
                "__aiter__ is unavailable because the given iterable is Iterable."
            )
        async for x in self.iterable:
            yield x
            self.update()

    def update(
        self,
        value: int = 1,
    ) -> None:
        """Advance the counter by ``value`` and redraw (or run ``callback``)."""
        self.elapsed_time = time.time() - self.start_time
        self.value += value
        if self.callback is None:
            self.show()
        else:
            self.callback(self)

    def __enter__(self) -> "Progress":
        # Hide the cursor while the bar is being redrawn.
        self.output.write("\x1b[?25l")
        self.output.flush()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> bool:
        # Restore the cursor and move past the bar's line.
        self.output.write("\x1b[?25h")
        self.output.write("\n")
        self.output.flush()
        # Never suppress an exception raised inside the ``with`` body.
        return False

    @staticmethod
    def _format_time(seconds: float) -> str:
        """Format ``seconds`` as ``H:MM:SS.s``."""
        m, s = divmod(seconds, 60)
        h, m = divmod(m, 60)
        return f"{int(h):d}:{int(m):02d}:{s:04.1f}"

    def show(self) -> None:
        """Redraw the status line in place on ``self.output``."""
        elapsed_time = self.elapsed_time or (time.time() - self.start_time)
        line = ""
        components: Dict[str, Any] = {}
        if self.title is not None:
            line += "{title} "
            title_width = self.title_width or len(self.title)
            components["title"] = f"{self.title:<{title_width}}:"
        if self.total is not None:
            line += (
                "{percentage:5.1f}% "
                "|{bar:{bar_width}s}| "
                "{value:>{value_width}}/{total:>{value_width}} "
                "[{elapsed_time}<{remaining_time}, {average_its:.2f}{unit}/s]"
            )
            if self.value:
                # Clamp: overshooting the total must not yield a negative ETA.
                remaining_time = max(
                    0.0, (self.total - self.value) * elapsed_time / self.value
                )
            else:
                remaining_time = 0.0
            # A zero total (e.g. an empty list) is trivially complete; avoid 0/0.
            components["percentage"] = (
                100 * self.value / self.total if self.total else 100.0
            )
            # One-cell placeholder bar, used below to measure the fixed text.
            components["bar"] = "="
            components["bar_width"] = 1
            components["value"] = self.value
            components["total"] = self.total
            components["value_width"] = self.value_width or len(str(self.total))
            components["remaining_time"] = self._format_time(remaining_time)
        else:
            line += "{value}{unit} [{elapsed_time}, {average_its:.2f}{unit}/s]"
            components["value"] = f"{self.value}"
        components["elapsed_time"] = self._format_time(elapsed_time)
        components["unit"] = self.unit
        # Coarse clocks can report zero elapsed time right after start.
        components["average_its"] = (
            self.value / elapsed_time if elapsed_time > 0 else 0.0
        )
        if self.total:
            # shutil falls back to $COLUMNS / 80 columns when stdout is not a
            # terminal, where os.get_terminal_size() would raise OSError.
            terminal_width = shutil.get_terminal_size().columns
            total_width = (
                terminal_width
                if self.maxwidth is None
                else min(self.maxwidth, terminal_width)
            )
            # Give the bar whatever width the rest of the line leaves over.
            _line = line.format(**components)
            width = max(1, total_width - len(_line))
            ratio = self.value / self.total
            whole_width = int(ratio * width)
            part_width = int(len(self.partchars) * ((ratio * width) % 1))
            part_char = self.partchars[part_width]
            components["bar_width"] = width
            components["bar"] = (self.partchars[-1] * whole_width + part_char)[:width]
        line = line.format(**components)
        # Clear the current line and return to column 0 before drawing.
        self.output.write("\x1b[2K\r")
        self.output.write(f"{line}\r")
class ProgressGroup:
    """Render several Progress bars as a block of stacked lines.

    Each member's ``output`` and ``callback`` are rebound, so an update of any
    member redraws the whole block. Redraws are serialized by a lock so bars
    driven from different threads do not interleave their output.
    """

    def __init__(
        self,
        *progresses: "Progress",
        output: Optional[TextIO] = None,
    ) -> None:
        """
        Args:
            *progresses: Bars to manage. Their ``title_width``,
                ``value_width``, ``output`` and ``callback`` attributes are
                overwritten.
            output: Stream to draw on; defaults to ``sys.stderr``.

        Raises:
            ValueError: If ``output`` is given and is not writable.
        """
        if output is not None and not output.writable():
            raise ValueError("The given output is not writable.")
        self.progresses: List["Progress"] = []
        self.output = output or sys.stderr
        self.lock = threading.Lock()
        # Align titles and counters across rows; ``default`` keeps an empty
        # group from crashing in max().
        title_width = max(
            (len(progress.title or "") for progress in progresses), default=0
        )
        value_width = (
            max((progress.value_width or 0 for progress in progresses), default=0)
            or None
        )

        def callback(_: "Progress") -> None:
            self.show()

        for progress in progresses:
            progress.title_width = title_width
            progress.value_width = value_width
            progress.output = self.output
            progress.callback = callback
            self.progresses.append(progress)

    def __enter__(self) -> "ProgressGroup":
        # Reserve one line per bar, then hide the cursor while redrawing.
        self.output.write("\n" * len(self.progresses))
        self.output.write("\x1b[?25l")
        self.output.flush()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> bool:
        # Restore the cursor.
        self.output.write("\x1b[?25h")
        self.output.flush()
        # Never suppress an exception raised inside the ``with`` body.
        return False

    def show(self) -> None:
        """Move the cursor to the top of the block and redraw every bar."""
        if not self.progresses:
            # Nothing to draw; most terminals also treat "\x1b[0A" as
            # "up one line", which would corrupt the output.
            return
        with self.lock:
            self.output.write(f"\x1b[{len(self.progresses)}A")
            for progress in self.progresses:
                progress.show()
                self.output.write("\n")
            self.output.flush()
if __name__ == "__main__":
    import asyncio
    from concurrent.futures import ThreadPoolExecutor, wait

    #
    # Multi-threading
    #
    def task_threading(progress: Progress) -> None:
        for _ in progress:
            time.sleep(0.01)

    def main_threading() -> None:
        foo = Progress(100, title="foo - threading")
        bar = Progress(50, title="bar - threading")
        with ThreadPoolExecutor(max_workers=2) as executor:
            with ProgressGroup(foo, bar):
                futures = [
                    executor.submit(task_threading, foo),
                    executor.submit(task_threading, bar),
                ]
                wait(futures)
            # Surface any exception raised inside a worker thread.
            for future in futures:
                future.result()

    main_threading()

    #
    # Async
    #
    async def task_async(progress: Progress) -> None:
        for _ in progress:
            await asyncio.sleep(0.01)

    async def main_async() -> None:
        foo = Progress(100, title="foo - async")
        bar = Progress(50, title="bar - async")
        with ProgressGroup(foo, bar):
            # gather() schedules the coroutines as tasks itself; asyncio.wait()
            # rejects bare coroutine objects on Python 3.11+.
            await asyncio.gather(task_async(foo), task_async(bar))

    asyncio.run(main_async())

    #
    # Async iterator
    #
    async def arange(n: int) -> AsyncIterator[int]:
        """Async analogue of range(n) with a short delay per item."""
        for x in range(n):
            await asyncio.sleep(0.01)
            yield x

    async def task_aiter(progress: Progress) -> None:
        async for _ in progress:
            pass

    async def main_aiter() -> None:
        foo = Progress(arange(100), title="foo - async iterator", unit="it")
        bar = Progress(arange(50), title="bar - async iterator", unit="it")
        with ProgressGroup(foo, bar):
            await asyncio.gather(task_aiter(foo), task_aiter(bar))

    asyncio.run(main_aiter())

    #
    # Simple
    #
    with Progress(title="baz - iterator", unit="it") as baz:
        for _ in range(50):
            time.sleep(0.05)
            baz.update()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.