Created October 25, 2019 23:59
-
-
Save iAnanich/f3cae79eaa4b913d0f4d38f95620365b to your computer and use it in GitHub Desktop.
AsyncIO Pipeline
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import typing | |
import asyncio | |
class Layer:
    """One stage of an asyncio pipeline: an input queue plus a worker task.

    Subclasses override :meth:`_start` (the worker body run as a task) and,
    optionally, :meth:`_stop` (a hook awaited while stopping).  Items are read
    from ``self.queue`` with :meth:`read_item`, acknowledged with
    :meth:`done_item`, and handed to the connected next layer with
    :meth:`forward_item`.
    """

    class STATES:
        """Lifecycle states stored in ``Layer.state``."""
        IDLE = 1
        RUNNING = 2
        GOING_TO_STOP = 3
        STOPPED = 4

    class DEFAULT:
        """Default constructor arguments."""
        # 0 means "unbounded" for asyncio.Queue.  A non-numeric placeholder
        # here would make put() raise TypeError as soon as the queue checks
        # whether it is full.
        QUEUE_MAX_SIZE = 0

    needs_next_layer: bool = False
    # When set, connect_next_layer() only accepts instances of this type;
    # otherwise any Layer is accepted.
    next_layer_type: typing.Optional[typing.Type['Layer']] = None

    def __init__(self, queue_max_size: int = DEFAULT.QUEUE_MAX_SIZE):
        """Create an idle layer whose input queue holds up to *queue_max_size* items."""
        self.next_layer: typing.Optional['Layer'] = None
        self.queue_max_size = queue_max_size
        self.state = self.STATES.IDLE
        self.queue: asyncio.Queue = asyncio.Queue(maxsize=queue_max_size)
        # Stays None until start() creates the worker task.
        self.running_task: typing.Optional[asyncio.Task] = None
        self.started_event = asyncio.Event()
        self.stopping_event = asyncio.Event()
        self.stopped_event = asyncio.Event()

    def connect_next_layer(self, next_layer: 'Layer'):
        """Register *next_layer* as the target of :meth:`forward_item`.

        Raises:
            TypeError: if *next_layer* is not an instance of
                ``next_layer_type`` (or of ``Layer`` when that is unset).
        """
        expected = self.next_layer_type or Layer
        if not isinstance(next_layer, expected):
            raise TypeError(
                f'{type(self).__name__} expects a next layer of type '
                f'{expected.__name__}, got {type(next_layer).__name__}'
            )
        self.next_layer = next_layer

    async def start(self):
        """Run :meth:`_start` as a task, wait for it, then stop the layer.

        If :meth:`stop` cancels the worker while this coroutine is awaiting
        it, the resulting ``CancelledError`` propagates out of ``start()``.
        """
        self.state = self.STATES.RUNNING
        self.started_event.set()
        self.running_task = asyncio.create_task(self._start())
        await self.running_task
        await self.stop()

    async def stop(self):
        """Drain the queue, run the :meth:`_stop` hook and cancel the worker.

        Safe to call before :meth:`start` and safe to call again once the
        layer has stopped (later calls return immediately so the ``_stop``
        hook is not re-run).
        """
        if self.state == self.STATES.STOPPED:
            return
        self.state = self.STATES.GOING_TO_STOP
        self.stopping_event.set()
        # Wait until every queued item has been acknowledged via done_item().
        await self.queue.join()
        await self._stop()
        if self.running_task is not None:
            self.running_task.cancel()
        self.state = self.STATES.STOPPED
        self.stopped_event.set()

    async def _start(self):
        """Worker body; overridden by subclasses."""
        pass

    async def _stop(self):
        """Stop hook awaited by :meth:`stop`; overridden by subclasses."""
        pass

    async def stop_at_event(self, event: asyncio.Event):
        """Wait until *event* is set, then stop this layer."""
        await event.wait()
        await self.stop()

    async def forward_item(self, obj):
        """Put *obj* on the next layer's queue.

        Raises:
            RuntimeError: if no next layer has been connected.
        """
        if self.next_layer is None:
            raise RuntimeError(
                f'{type(self).__name__} has no next layer to forward items to'
            )
        await self.next_layer.queue.put(obj)

    async def read_item(self):
        """Return the next item from this layer's queue, waiting if it is empty."""
        return await self.queue.get()

    def done_item(self):
        """Mark one previously read item as processed."""
        self.queue.task_done()

    def cancel(self):
        """Cancel the worker task, if one has been started."""
        if self.running_task is not None:
            self.running_task.cancel()
class Pipeline:
    """Ordered chain of layers that are started together and stopped in turn.

    Each layer is connected to its successor, every layer after the first is
    stopped once its predecessor has stopped, and the pipeline stops itself
    once the last layer has stopped.

    NOTE(review): the constructor calls ``asyncio.gather`` on coroutines, so
    it is expected to run inside a running event loop (as ``main()`` does) --
    confirm for the targeted Python version.
    """

    def __init__(self, layers: typing.Sequence['Layer']):
        """Connect *layers* in order and prepare the start/stop futures.

        Raises:
            ValueError: if *layers* is empty.
        """
        self.layers = tuple(layers)
        if not self.layers:
            # start() needs a last layer to watch; fail early and clearly.
            raise ValueError('Pipeline requires at least one layer')
        self._connect_layers()
        self.start_layers_future = self._create_start_future()
        self.stop_layers_future = self._create_stop_future()
        self.stop_self_task: typing.Optional[asyncio.Task] = None
        self.running_future: typing.Optional[asyncio.Future] = None

    async def start(self):
        """Run all layers until the last one stops, then stop the pipeline.

        Exceptions raised by layers are collected by ``gather`` (because of
        ``return_exceptions=True``) rather than raised here; they are printed
        once everything has finished.
        """
        self.stop_self_task = asyncio.create_task(
            self.stop_at_event(self.layers[-1].stopped_event)
        )
        self.running_future = asyncio.gather(
            self.start_layers_future,
            self.stop_layers_future,
            self.stop_self_task,
            return_exceptions=True,
        )
        results = await self.running_future
        self._report_errors(results)
        await self.stop()

    async def stop(self):
        """Stop and cancel every layer in order, then cancel the start future."""
        for layer in self.layers:
            await layer.stop()
            layer.cancel()
        self.start_layers_future.cancel()

    async def stop_at_event(self, event: asyncio.Event):
        """Wait until *event* is set, then stop the pipeline."""
        await event.wait()
        await self.stop()

    def _report_errors(self, results):
        """Print every non-cancellation exception found in gathered *results*."""
        for result in results:
            if isinstance(result, (list, tuple)):
                # Nested gather() results arrive as lists.
                self._report_errors(result)
            elif isinstance(result, asyncio.CancelledError):
                # Cancellation is how layers are shut down, not a failure.
                continue
            elif isinstance(result, BaseException):
                print(result)

    def _create_start_future(self) -> asyncio.Future:
        """Return a future that runs ``start()`` of every layer concurrently."""
        coros = [layer.start() for layer in self.layers]
        return asyncio.gather(*coros, return_exceptions=True)

    def _create_stop_future(self) -> asyncio.Future:
        """Return a future that stops each layer once its predecessor stopped."""
        coros = [
            downstream.stop_at_event(event=upstream.stopped_event)
            for upstream, downstream in zip(self.layers, self.layers[1:])
        ]
        return asyncio.gather(*coros, return_exceptions=True)

    def _connect_layers(self):
        """Connect every layer to the one that follows it."""
        for prev_layer, next_layer in zip(self.layers, self.layers[1:]):
            prev_layer.connect_next_layer(next_layer)
class ALayer(Layer):
    """Demo producer layer: forwards the integers 0 through 4 downstream."""

    async def _start(self):
        """Print a start marker, then announce, pause and forward each number."""
        print('a')
        number = 0
        while number < 5:
            print(f'put {number}')
            await asyncio.sleep(0.1)
            await self.forward_item(number)
            number += 1
class BLayer(Layer):
    """Demo consumer layer: reads items from its queue and prints them."""

    async def _start(self):
        """Print a start marker, then consume items until the task is cancelled."""
        print('b')
        while True:
            await self._consume_one()

    async def _consume_one(self):
        """Read one item, pause, print it and acknowledge it to the queue."""
        received = await self.read_item()
        await asyncio.sleep(0.1)
        print(f'got {received}')
        self.done_item()
async def main():
    """Build a two-layer demo pipeline (producer -> consumer) and run it."""
    pipeline = Pipeline([ALayer(), BLayer()])
    await pipeline.start()


# Script entry point: run the demo pipeline on a fresh event loop.
if __name__ == '__main__':
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment