Skip to content

Instantly share code, notes, and snippets.

@altescy
Last active March 30, 2022 13:01
Show Gist options
  • Select an option

  • Save altescy/e0ca715baae602d292bfbc466ae379f6 to your computer and use it in GitHub Desktop.

Select an option

Save altescy/e0ca715baae602d292bfbc466ae379f6 to your computer and use it in GitHub Desktop.
import os
import shutil
import sys
import threading
import time
from collections import abc
from types import TracebackType
from typing import (Any, AsyncIterable, AsyncIterator, Callable, Dict,
                    Iterable, Iterator, List, Optional, TextIO, Type, TypeVar,
                    Union)
T = TypeVar("T")


class Progress:
    """A single-line text progress bar that can wrap a sync or async iterable.

    A ``Progress`` can be driven in three ways:

    - iterating over it (``for x in progress``);
    - async-iterating over it (``async for x in progress``);
    - calling :meth:`update` manually.

    Each update re-renders the line via :meth:`show`. If ``callback`` is set,
    the update calls it instead; this is how :class:`ProgressGroup` takes
    over rendering.

    Args:
        total_or_iterable: One of:

            - an ``int`` total (iterating then yields ``range(total)``);
            - a sized iterable (its ``len`` becomes the total);
            - an unsized iterable or async iterable;
            - ``None`` for a purely manual counter.

            When a total is known, a bar, a percentage and a remaining-time
            estimate are rendered.
        title: Optional label printed before the bar.
        unit: Unit name shown in the counter and rate (default ``"it"``).
        output: Stream to render to; defaults to ``sys.stderr``.
        callback: If given, called with this instance on every update
            instead of rendering directly.
        partchars: Characters used for the fractional last cell of the bar,
            ordered from emptiest to fullest. The last one also fills whole
            cells.
        maxwidth: Optional cap on the rendered line width.

    Raises:
        ValueError: If ``total_or_iterable`` is not an int, an iterable, an
            async iterable, or ``None``.
    """

    def __init__(
        self,
        total_or_iterable: Optional[Union[int, Iterable[T], AsyncIterable[T]]] = None,
        *,
        title: Optional[str] = None,
        unit: str = "it",
        output: Optional[TextIO] = None,
        callback: Optional[Callable[["Progress"], None]] = None,
        partchars: str = " ▏▎▍▌▋▊▉",
        maxwidth: Optional[int] = None,
    ) -> None:
        self.title = title
        self.unit = unit
        self.output = output if output is not None else sys.stderr
        self.callback = callback
        self.partchars = partchars
        self.maxwidth = maxwidth
        self.start_time = time.time()
        self.elapsed_time = 0.0
        self.value = 0
        self.total: Optional[int] = None
        self.iterable: Optional[
            Union[Iterable[T], Iterable[int], AsyncIterable[T], AsyncIterable[int]]
        ] = None
        if isinstance(total_or_iterable, int):
            self.total = total_or_iterable
            self.iterable = range(self.total)
        elif isinstance(total_or_iterable, abc.Sized):
            self.total = len(total_or_iterable)  # type: ignore[unreachable]
            self.iterable = total_or_iterable
        elif isinstance(total_or_iterable, (abc.Iterable, abc.AsyncIterable)):
            self.iterable = total_or_iterable
        elif total_or_iterable is not None:
            raise ValueError("The iterable argument must be iterable.")
        # Column widths. A ProgressGroup overwrites these so stacked bars align.
        self.title_width: Optional[int] = None
        self.value_width = len(str(self.total)) if self.total is not None else None

    def __iter__(self) -> Union[Iterator[T], Iterator[int]]:
        """Yield items from the wrapped iterable and count each one.

        The counter is bumped after control returns from the consumer. The
        bar therefore reflects completed items.

        Raises:
            RuntimeError: On the first ``next()``, if there is no wrapped
                iterable or it is async-iterable.
        """
        self.start_time = time.time()
        if self.iterable is None:
            raise RuntimeError(f"{self} is not iterable.")
        if isinstance(self.iterable, abc.AsyncIterable):
            raise RuntimeError(
                "__iter__ is unavailable because the given iterable is AsyncIterable."
            )
        for x in self.iterable:
            yield x
            self.update()

    async def __aiter__(self) -> Union[AsyncIterator[T], AsyncIterator[int]]:
        """Async counterpart of :meth:`__iter__` for async iterables.

        Raises:
            RuntimeError: If there is no wrapped iterable or it is a plain
                (synchronous) iterable.
        """
        self.start_time = time.time()
        if self.iterable is None:
            raise RuntimeError(f"{self} is not iterable.")
        if isinstance(self.iterable, abc.Iterable):
            raise RuntimeError(
                "__aiter__ is unavailable because the given iterable is Iterable."
            )
        async for x in self.iterable:
            yield x
            self.update()

    def update(
        self,
        value: int = 1,
    ) -> None:
        """Advance the counter by ``value``, then redraw or notify the callback."""
        self.elapsed_time = time.time() - self.start_time
        self.value += value
        if self.callback is None:
            self.show()
            # show() ends with "\r" rather than "\n". A block-buffered stream
            # would hold the bar back indefinitely, so flush explicitly.
            self.output.flush()
        else:
            self.callback(self)

    def __enter__(self) -> "Progress":
        # Hide the terminal cursor while the bar is being redrawn.
        self.output.write("\x1b[?25l")
        self.output.flush()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> bool:
        # Restore the cursor and move past the bar's line.
        self.output.write("\x1b[?25h")
        self.output.write("\n")
        self.output.flush()
        return exc_type is None and exc_value is None and traceback is None

    @staticmethod
    def _format_time(seconds: float) -> str:
        """Format a duration as ``H:MM:SS.s``.

        The value is rounded to tenths of a second before it is split into
        fields. For example, 59.96 s renders as ``0:01:00.0`` rather than
        ``0:00:60.0``. Negative durations are clamped to zero; these occur
        when ``value`` overshoots ``total``.
        """
        tenths = int(round(max(seconds, 0.0) * 10))
        minutes, tenths = divmod(tenths, 600)
        hours, minutes = divmod(minutes, 60)
        secs, frac = divmod(tenths, 10)
        return f"{hours:d}:{minutes:02d}:{secs:02d}.{frac:d}"

    def show(self) -> None:
        """Render the current state as a single line on ``self.output``.

        The line is written after an erase-line escape and ends with ``\\r``.
        The next call therefore overwrites it in place.
        """
        elapsed_time = self.elapsed_time or (time.time() - self.start_time)
        line = ""
        components: Dict[str, Any] = {}
        if self.title is not None:
            line += "{title} "
            components[
                "title"
            ] = f"{self.title:<{self.title_width or len(self.title)}}:"
        if self.total is not None:
            line += (
                "{percentage:5.1f}% "
                "|{bar:{bar_width}s}| "
                "{value:>{value_width}}/{total:>{value_width}} "
                "[{elapsed_time}<{remaining_time}, {average_its:.2f}{unit}/s]"
            )
            if self.value:
                remaining_time = (self.total - self.value) * elapsed_time / self.value
            else:
                remaining_time = 0
            # An empty job (total == 0) counts as already complete rather
            # than dividing by zero.
            components["percentage"] = (
                100 * self.value / self.total if self.total else 100.0
            )
            # Placeholder one-cell bar, used only to measure the fixed-width
            # text. The real bar is filled in below once the width is known.
            components["bar"] = "="
            components["bar_width"] = 1
            components["value"] = self.value
            components["total"] = self.total
            components["value_width"] = self.value_width or len(str(self.total))
            components["remaining_time"] = self._format_time(remaining_time)
        else:
            line += "{value}{unit} [{elapsed_time}, {average_its:.2f}{unit}/s]"
            components["value"] = f"{self.value}"
        components["elapsed_time"] = self._format_time(elapsed_time)
        components["unit"] = self.unit
        # The clock may not have advanced yet (coarse timers); avoid 0-division.
        components["average_its"] = self.value / elapsed_time if elapsed_time > 0 else 0.0
        if self.total:
            # shutil falls back to $COLUMNS or 80 columns when no terminal is
            # attached. os.get_terminal_size() would raise OSError instead.
            terminal_width = shutil.get_terminal_size().columns
            total_width = (
                terminal_width
                if self.maxwidth is None
                else min(self.maxwidth, terminal_width)
            )
            _line = line.format(**components)
            # The bar gets whatever the fixed-width text leaves. The
            # placeholder cell is not subtracted, so one spare column keeps
            # the line from wrapping.
            width = max(1, total_width - len(_line))
            ratio = self.value / self.total
            whole_width = int(ratio * width)
            part_width = int(len(self.partchars) * ((ratio * width) % 1))
            part_char = self.partchars[part_width]
            components["bar_width"] = width
            components["bar"] = (self.partchars[-1] * whole_width + part_char)[:width]
        line = line.format(**components)
        # Erase the current line, return to column 0, then draw.
        self.output.write("\x1b[2K\r")
        self.output.write(f"{line}\r")
class ProgressGroup:
    """Render several :class:`Progress` bars as a stacked block, redrawn in place.

    On construction, every progress is re-pointed at this group's output and
    given a callback. Any ``update()`` on one of them therefore redraws the
    whole group. Titles and counters are padded to common widths so the bars
    line up.

    Use the group as a context manager. Entering reserves one line per
    progress and hides the cursor; exiting shows the cursor again.

    Args:
        *progresses: The progress bars to manage, from top to bottom. May be
            empty.
        output: Writable stream to render to; defaults to ``sys.stderr``.

    Raises:
        ValueError: If ``output`` is given but is not writable.
    """

    def __init__(
        self,
        *progresses: Progress,
        output: Optional[TextIO] = None,
    ) -> None:
        if output is not None and not output.writable():
            raise ValueError("The given output is not writable.")
        self.progresses: List[Progress] = []
        self.output = output if output is not None else sys.stderr
        # Serializes redraws: member bars may be updated from different threads.
        self.lock = threading.Lock()
        # default= keeps an empty group from crashing with max() of nothing.
        title_width = max(
            (len(progress.title or "") for progress in progresses), default=0
        )
        value_width = (
            max((progress.value_width or 0 for progress in progresses), default=0)
            or None
        )

        def callback(_: Progress) -> None:
            self.show()

        for progress in progresses:
            progress.title_width = title_width
            progress.value_width = value_width
            progress.output = self.output
            progress.callback = callback
            self.progresses.append(progress)

    def __enter__(self) -> "ProgressGroup":
        # Reserve one line per bar, then hide the cursor during redraws.
        self.output.write("\n" * len(self.progresses))
        self.output.write("\x1b[?25l")
        self.output.flush()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> bool:
        self.output.write("\x1b[?25h")
        self.output.flush()
        return exc_type is None and exc_value is None and traceback is None

    def show(self) -> None:
        """Redraw every bar in the group, one per reserved line."""
        with self.lock:
            if self.progresses:
                # Move the cursor back to the first reserved line. Skip this
                # when empty: terminals treat "\x1b[0A" as "up 1 line".
                self.output.write(f"\x1b[{len(self.progresses)}A")
            for progress in self.progresses:
                progress.show()
                self.output.write("\n")
            self.output.flush()
if __name__ == "__main__":
    # Demo: exercise Progress / ProgressGroup with threads, asyncio and
    # manual updates.
    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    #
    # Multi-threading
    #
    def task_threading(progress: Progress) -> None:
        """Consume ``progress`` in a worker thread, sleeping per item."""
        for _ in progress:
            time.sleep(0.01)

    def main_threading() -> None:
        """Drive two grouped bars from two worker threads."""
        foo = Progress(100, title="foo - threading")
        bar = Progress(50, title="bar - threading")
        with ThreadPoolExecutor(max_workers=2) as executor:
            with ProgressGroup(foo, bar):
                futures = [
                    executor.submit(task_threading, foo),
                    executor.submit(task_threading, bar),
                ]
                # result() waits like concurrent.futures.wait() but also
                # re-raises any exception from a worker instead of hiding it.
                for future in futures:
                    future.result()

    main_threading()

    #
    # Async
    #
    async def task_async(progress: Progress) -> None:
        """Consume a sync-iterable ``progress`` cooperatively."""
        for _ in progress:
            await asyncio.sleep(0.01)

    async def main_async() -> None:
        """Drive two grouped bars from two concurrent coroutines."""
        foo = Progress(100, title="foo - async")
        bar = Progress(50, title="bar - async")
        with ProgressGroup(foo, bar):
            # asyncio.wait() no longer accepts bare coroutines (TypeError
            # since Python 3.11); gather() wraps them in tasks itself.
            await asyncio.gather(task_async(foo), task_async(bar))

    # asyncio.run() creates and closes its own loop. Calling
    # get_event_loop() with no running loop is deprecated.
    asyncio.run(main_async())

    #
    # Async Iterator
    #
    async def arange(n: int) -> AsyncIterator[int]:
        """Asynchronously yield 0..n-1, sleeping briefly before each value."""
        # Named arange rather than aiter to avoid shadowing the builtin (3.10+).
        for x in range(n):
            await asyncio.sleep(0.01)
            yield x

    async def task_aiter(progress: Progress) -> None:
        """Drain an async-iterable ``progress``."""
        async for _ in progress:
            pass

    async def main_aiter() -> None:
        """Drive two grouped bars that wrap async iterators."""
        foo = Progress(arange(100), title="foo - async iterator", unit="it")
        bar = Progress(arange(50), title="bar - async iterator", unit="it")
        with ProgressGroup(foo, bar):
            await asyncio.gather(task_aiter(foo), task_aiter(bar))

    asyncio.run(main_aiter())

    #
    # Simple
    #
    with Progress(title="baz - iterator", unit="it") as baz:
        for _ in range(50):
            time.sleep(0.05)
            baz.update()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment