Created
January 8, 2023 08:15
-
-
Save altescy/ae65acc225b9cc46b2f9f2c0f75d461d to your computer and use it in GitHub Desktop.
Iterator wrapper that generates values in a separate process and delivers them through a multiprocessing queue
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time
import traceback
from collections.abc import Iterable, Iterator
from multiprocessing import Process, Queue
from types import TracebackType
from typing import Generic, Type, TypeVar
# Type variable for the element type handled by the generic helpers in this file.
T = TypeVar("T")
class NumberIterator:
    """Deliberately slow counter that yields 1, 2, ..., ``maxval``.

    Every ``__next__`` call sleeps for 0.1 seconds first, which makes the
    class a convenient stand-in for an expensive data source.
    """

    def __init__(self, maxval: int) -> None:
        self._current = 0
        self._maxval = maxval

    def __iter__(self) -> Iterator[int]:
        return self

    def __next__(self) -> int:
        # The pause applies to every call, including the one that reports
        # exhaustion.
        time.sleep(0.1)
        if self._current < self._maxval:
            self._current += 1
            return self._current
        raise StopIteration
class QueueEnd:
    """Marker type: an instance placed on a queue means no more values follow."""
class MultiProcessQueueIterator(Generic[T]):
    """Iterator that yields values produced by ``iterator`` in a worker process.

    Entering the context starts a ``multiprocessing.Process`` that drains
    ``iterator`` and puts every value on a ``multiprocessing.Queue``.
    Iterating over this object takes the values back out of that queue.
    Leaving the context terminates and reaps the worker.

    Args:
        iterator: Source of values; it is consumed inside the worker process.
        maxsize: Forwarded to ``Queue(maxsize=...)``. It bounds how far the
            worker may run ahead of the consumer.
        timeout: Forwarded to ``Queue.get(timeout=...)`` on every
            ``__next__`` call; ``None`` waits without a limit.
    """

    class _WorkerFailure:
        """Queue message carrying the traceback text of a failed worker."""

        def __init__(self, details: str) -> None:
            self.details = details

    def __init__(
        self,
        iterator: Iterator[T],
        maxsize: int = 0,
        timeout: float | None = None,
    ) -> None:
        self._iterator = iterator
        self._timeout = timeout
        self._queue: Queue[T] = Queue(maxsize=maxsize)
        self._process: Process | None = None
        # Set once the end marker (or a failure) has been received, so later
        # ``__next__`` calls do not block on a queue that stays empty.
        self._exhausted = False

    def __enter__(self) -> "MultiProcessQueueIterator[T]":
        """Start the worker process and return this iterator."""
        self._process = Process(target=self._run)
        self._process.start()
        return self

    def __exit__(
        self,
        exc_type: Type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool:
        """Stop the worker; exceptions from the ``with`` body are not suppressed."""
        process, self._process = self._process, None
        if process is not None:
            process.terminate()
            # Reap the terminated worker so it does not linger as a zombie.
            process.join()
        return False

    def __iter__(self) -> Iterator[T]:
        return self

    def __next__(self) -> T:
        """Return the next value received from the worker.

        Raises:
            StopIteration: When the worker has sent its end marker, and on
                every call after that.
            RuntimeError: When the worker reported an exception raised while
                producing values.
            queue.Empty: When ``timeout`` elapses before a value arrives.
        """
        if self._exhausted:
            raise StopIteration
        value = self._queue.get(timeout=self._timeout)
        if isinstance(value, self._WorkerFailure):
            self._exhausted = True
            raise RuntimeError(
                f"worker process failed while producing values:\n{value.details}"
            )
        if isinstance(value, QueueEnd):
            self._exhausted = True
            raise StopIteration
        return value

    def _run(self) -> None:
        """Worker entry point: forward every source value, then an end marker."""
        try:
            for item in self._iterator:
                print(f"put {item}")
                self._queue.put(item)
        except Exception:
            # Report the failure to the consumer instead of leaving it
            # waiting for an end marker that would never arrive. The
            # traceback is sent as text so the message always pickles.
            self._queue.put(self._WorkerFailure(traceback.format_exc()))
        else:
            self._queue.put(QueueEnd())
def into_batch(iterable: Iterable[T], batch_size: int) -> Iterator[list[T]]:
    """Group consecutive items of ``iterable`` into lists of ``batch_size``.

    Items are pulled lazily; a batch is yielded as soon as it is full. The
    final batch may be shorter when the input length is not a multiple of
    ``batch_size``. An empty input yields nothing.

    Args:
        iterable: Source of items.
        batch_size: Number of items per batch; must be at least 1.

    Yields:
        Lists of up to ``batch_size`` items, in input order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1. Because this is a
            generator, the error surfaces when iteration starts.
    """
    # A size below 1 could never match the length check below, so the whole
    # input would silently be buffered into one batch.
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")
    batch: list[T] = []
    for item in iterable:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch
def main() -> None:
    """Demo: produce 50 numbers in a worker process and consume them in batches of 3."""
    source = NumberIterator(50)
    with MultiProcessQueueIterator(source, maxsize=5) as produced:
        for chunk in into_batch(produced, 3):
            print("waiting for batch to be processed")
            # Simulated processing time for one batch.
            time.sleep(0.5)
            print(f"get {chunk}")
# Run the demo only when this file is executed as a script. The guard also
# keeps multiprocessing start methods that re-import the main module from
# running the demo again inside the worker process.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment