Skip to content

Instantly share code, notes, and snippets.

@altescy
Created January 8, 2023 08:15
Show Gist options
  • Save altescy/ae65acc225b9cc46b2f9f2c0f75d461d to your computer and use it in GitHub Desktop.
Save altescy/ae65acc225b9cc46b2f9f2c0f75d461d to your computer and use it in GitHub Desktop.
Iterator wrapper that generates values in a separate process and delivers them through a multiprocessing queue
import time
from collections.abc import Iterable, Iterator
from multiprocessing import Process, Queue
from types import TracebackType
from typing import Generic, Type, TypeVar
T = TypeVar("T")
class NumberIterator:
    """Iterator producing ``1, 2, ..., maxval`` with a pause before each value.

    The pause stands in for a slow producer so the demo has something worth
    running in the background.

    Args:
        maxval: Last value to produce. Values ``<= 0`` give an empty iterator.
        delay: Seconds to sleep before producing each value. Defaults to
            ``0.1``, the pause that was previously hard-coded.
    """

    def __init__(self, maxval: int, delay: float = 0.1) -> None:
        self._maxval = maxval
        self._delay = delay
        self._current = 0

    def __iter__(self) -> Iterator[int]:
        """Return the iterator itself."""
        return self

    def __next__(self) -> int:
        """Return the next number.

        Raises:
            StopIteration: Once ``maxval`` values have been produced.
        """
        # Check for exhaustion before sleeping: there is no work to simulate
        # once the sequence is finished, so the final call returns at once.
        if self._current >= self._maxval:
            raise StopIteration
        time.sleep(self._delay)
        self._current += 1
        return self._current
class QueueEnd:
    """Marker type whose instances signal that no more values will follow."""
class MultiProcessQueueIterator(Generic[T]):
def __init__(
self,
iterator: Iterator[T],
maxsize: int = 0,
timeout: float | None = None,
) -> None:
self._iterator = iterator
self._timeout = timeout
self._queue: Queuer[T] = Queue(maxsize=maxsize)
self._process: Process | None = None
def __enter__(self) -> Iterator:
self._process = Process(target=self._run)
self._process.start()
return self
def __exit__(
self,
exc_type: Type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> bool:
if self._process is not None:
self._process.terminate()
self._process = None
return False
def __iter__(self) -> Iterator[T]:
return self
def __next__(self) -> T:
value = self._queue.get(timeout=self._timeout)
if isinstance(value, QueueEnd):
raise StopIteration
return value
def _run(self) -> None:
for item in self._iterator:
print(f"put {item}")
self._queue.put(item)
self._queue.put(QueueEnd())
def into_batch(iterable: Iterable[T], batch_size: int) -> Iterator[list[T]]:
    """Group the items of ``iterable`` into lists of ``batch_size`` items.

    Items are pulled lazily. The final batch is shorter when the number of
    items is not a multiple of ``batch_size``; an empty input yields nothing.

    Args:
        iterable: Source of items.
        batch_size: Number of items per batch; must be at least 1.

    Yields:
        Lists of consecutive items, in their original order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (raised when
            iteration starts, because this is a generator).
    """
    # A non-positive size can never be reached by len(batch), which would
    # silently lump the whole input into one batch; reject it instead.
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    batch: list[T] = []
    # A plain for loop asks the source for exactly one item past its end,
    # which matters for sources whose next() blocks once they are drained.
    for item in iterable:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch
def main() -> None:
    """Demo: read 50 background-produced numbers in batches of three."""
    source = NumberIterator(50)
    with MultiProcessQueueIterator(source, maxsize=5) as values:
        batches = into_batch(values, 3)
        for chunk in batches:
            # The pause imitates per-batch work on the consumer side.
            print("waiting for batch to be processed")
            time.sleep(0.5)
            print(f"get {chunk}")
# Run the demo only when this file is executed as a script. multiprocessing
# start methods that import the main module in the child process (spawn,
# forkserver) rely on this guard so the children do not call main() again.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment