|
import asyncio |
|
from contextlib import contextmanager |
|
from copy import deepcopy |
|
import functools |
|
import itertools |
|
import logging |
|
import os |
|
import queue |
|
import random |
|
import socket |
|
import socketserver |
|
import threading |
|
import time |
|
import typing as T |
|
|
|
logging.basicConfig()

# Root logger (no name argument), configured by basicConfig() above.
log = logging.getLogger()


def parse_log_level(raw: T.Union[str, int]) -> T.Union[str, int]:
    """Normalize a ``LOG_LVL`` value into something ``Logger.setLevel`` accepts.

    ``setLevel`` takes an int or an exact upper-case level name, but
    environment variables are always strings, so ``LOG_LVL=10`` or
    ``LOG_LVL=debug`` would otherwise raise ``ValueError`` at import time.

    Returns ints unchanged, digit-only strings as ints, and any other string
    stripped and upper-cased (unknown names still fail inside ``setLevel``).
    """
    if isinstance(raw, int):
        return raw
    text = raw.strip()
    if text.isdigit():
        return int(text)
    return text.upper()


log.setLevel(parse_log_level(os.environ.get("LOG_LVL", logging.INFO)))

# Wire format of every served value, passed as keyword arguments to
# int.to_bytes / int.from_bytes: 4 bytes, little-endian (unsigned, since
# neither call passes signed=True).
BYTE_INT = {"length": 4, "byteorder": "little"}
|
|
|
|
|
def endow_with_iterator(cls, iterator):
    """Return a subclass of *cls* whose instances all draw from *iterator*.

    The returned class carries a class attribute ``iterator``. Every instance
    shares it, so each ``next(self.iterator)`` yields the *next* value rather
    than restarting from the first one. Access is serialized with a lock
    because instances may be used from several threads at once (a
    ``ThreadingTCPServer`` builds one handler per request thread), and a
    generator advanced concurrently raises ``ValueError``.

    Args:
        cls: Base class to extend; its ``__init__`` is inherited unchanged.
        iterator: Object implementing ``__next__`` that supplies the values.

    Returns:
        A new subclass of *cls*.

    Raises:
        ValueError: If *iterator* has no ``__next__`` method. This is checked
            once, here, instead of on every instantiation.
    """
    if not hasattr(iterator, "__next__"):
        raise ValueError(f"{iterator} is not an iterator.")

    lock = threading.Lock()

    class _SharedIterator:
        """Lock-guarded view over the single shared *iterator*."""

        def __iter__(self):
            return self

        def __next__(self):
            with lock:
                return next(iterator)

    class WrappedClass(cls):
        iterator = _SharedIterator()

    return WrappedClass
|
|
|
|
|
class IterHandler(socketserver.BaseRequestHandler):
    """Answer each TCP connection with the next value of ``self.iterator``.

    ``IterHandler`` itself defines no ``iterator``; a subclass must supply
    one, for example the class returned by ``endow_with_iterator``. Each
    value is encoded with ``int.to_bytes(**BYTE_INT)``. With the current
    BYTE_INT (4 bytes, unsigned), that means an int in ``0 <= v < 2**32``;
    anything else raises ``OverflowError``.
    """

    iterator: T.Iterator  # supplied by a subclass

    # Artificial per-request latency, in seconds, applied before replying.
    # Subclasses may override it (e.g. 0 for fast tests).
    delay: float = 1

    def handle(self) -> None:
        log.debug("Incoming request: %s", self.request)
        try:
            value = next(self.iterator)
        except StopIteration:
            # A finite iterator has run out. Close the connection with no
            # payload instead of letting the server print a traceback.
            log.warning(
                "Iterator exhausted; closing %s without a response", self.request
            )
            return
        resp = value.to_bytes(**BYTE_INT)
        time.sleep(self.delay)
        self.request.sendall(resp)
        log.debug("Response sent: %s", resp)
|
|
|
|
|
def serve_iterable(
    conn: T.Tuple[str, int], iterator: T.Iterator, q: queue.Queue
) -> None:
    """Serve successive values of *iterator* over TCP from background threads.

    Binds a ``ThreadingTCPServer`` to *conn*, using a handler that answers
    each connection with the next value of *iterator*. The serve loop and a
    control thread are started as daemon threads, and this function returns
    immediately. The control thread blocks on *q*: falsy items are ignored,
    and the first truthy item stops the serve loop and closes the listening
    socket.

    Args:
        conn: ``(host, port)`` to bind.
        iterator: Object implementing ``__next__`` that supplies the values.
        q: Control queue; put a truthy value on it to request shutdown.
    """
    handler_cls = endow_with_iterator(IterHandler, iterator)
    server = socketserver.ThreadingTCPServer(conn, handler_cls)
    log.info("Serving at %s:%s", *conn)

    server_thread = threading.Thread(
        name="main server thread", daemon=True, target=server.serve_forever
    )

    def control() -> None:
        # Keep waiting until a truthy item arrives.
        while not q.get():
            pass
        # shutdown() only stops serve_forever(). server_close() is what
        # releases the listening socket (and, with ThreadingMixIn's default
        # block_on_close, waits for in-flight request threads).
        server.shutdown()
        server.server_close()

    control_thread = threading.Thread(
        name="control thread", daemon=True, target=control
    )

    server_thread.start()
    log.info("Server thread started")
    control_thread.start()
    log.info("Control thread started")
|
|
|
|
|
async def fetch(conn: T.Tuple[str, int], size: int) -> bytes:
    """Connect to *conn*, read a reply of *size* bytes, and close the connection.

    Args:
        conn: ``(host, port)`` to connect to.
        size: Number of bytes expected. A negative value reads until EOF,
            as ``StreamReader.read(-1)`` does.

    Returns:
        Exactly *size* bytes, or fewer (possibly ``b""``) if the peer closes
        the connection first.
    """
    reader, writer = await asyncio.open_connection(*conn)
    try:
        if size < 0:
            return await reader.read(size)
        try:
            # read(size) may return as soon as *any* bytes arrive, so a reply
            # split across TCP segments would come back truncated.
            return await reader.readexactly(size)
        except asyncio.IncompleteReadError as exc:
            # Peer closed early: return what arrived, as read() would have.
            return exc.partial
    finally:
        # Close our side so the transport and socket are not leaked. This is
        # best effort: a peer reset must not mask the result read above.
        writer.close()
        try:
            await writer.wait_closed()
        except ConnectionError:
            pass
|
|
|
|
|
@contextmanager
def setup(conn: T.Tuple[str, int], it: T.Iterator) -> T.Iterator[None]:
    """Run an iterator-backed TCP server for the duration of a ``with`` block.

    On entry, starts the server with ``serve_iterable(conn, it, q)``. On exit
    (normal or via exception), puts ``True`` on the control queue to request
    shutdown. The request is not awaited: this returns without joining the
    server threads.

    Args:
        conn: ``(host, port)`` to serve on.
        it: Iterator whose values are handed out, one per connection.
    """
    q: queue.Queue = queue.Queue()
    # Start outside the try: if start-up fails (e.g. the port is taken) there
    # is nothing to shut down, and the original exception propagates as is.
    serve_iterable(conn, it, q)
    try:
        yield
    finally:
        log.debug("Exiting...")
        q.put(True)
        log.info("Shutdown.")
|
|
|
|
|
async def test_retrieve(n: int) -> None:
    """Fetch and print *n* values from a local server cycling over 3, 2, 1.

    Each ``fetch`` opens a fresh connection and decodes the 4-byte reply
    using ``BYTE_INT``. Requests are issued sequentially.
    """
    # Ask the OS for an unused ephemeral port instead of guessing one at
    # random, which could hit a port already in use. A small window remains
    # between closing the probe and the server binding the port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(("localhost", 0))
        port = probe.getsockname()[1]
    conn_par = ("localhost", port)

    it = itertools.cycle((3, 2, 1))
    with setup(conn_par, it):
        for _ in range(n):
            res = await fetch(conn_par, BYTE_INT["length"])
            print(int.from_bytes(res, BYTE_INT["byteorder"]))
|
|
|
|
|
if __name__ == "__main__":
    # How many values to fetch; override with the ITER_TIMES env var.
    times = int(os.environ.get("ITER_TIMES", 5))
    asyncio.run(test_retrieve(times))