Last active
January 25, 2020 16:03
-
-
Save Andrei-Pozolotin/f5be7dec56840428c6245a1cef4a25eb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
robust cluster connection pool | |
""" | |
import asyncio | |
import logging | |
from contextvars import ContextVar | |
from typing import Awaitable | |
from typing import List | |
from typing import Mapping | |
import asyncpg | |
import funcy | |
logger = logging.getLogger(__name__)


def override(func):
    """
    Mark a method as overriding a base-class method.

    Documentation-only decorator: it returns ``func`` unchanged and has no
    runtime effect.
    """
    return func


# global cluster pool (context variable; None until some code sets it)
POOL_DBMS: ContextVar["SmaprtPoolDMBS"] = ContextVar('POOL_DBMS', default=None)
class SmaprtPoolDMBS(asyncpg.pool.Pool):
    """
    DBMS connection pool with automatic failover across cluster members.

    One background ping task per cluster member tracks whether that member
    accepts a connection and answers ``SELECT 0``.  A main monitor task picks
    the first live member (in ``url_list`` order) and re-points the pool at
    it, expiring existing pool connections whenever the active member
    changes.  Pool initialization and connection acquisition wait until some
    member is live.
    """

    # current connection url; None while no member is live
    conn_url: str = None
    # initialization blocker: waiters are woken when conn_url changes
    conn_block: asyncio.Condition = None
    # connection state: url -> has-live (insertion order is priority order)
    status_dict: Mapping[str, bool] = None
    # track internal jobs for cancel
    monitor_task_list: List[asyncio.Task] = None
    # period for pool monitor tasks, seconds
    monitor_period: float = None
    # dbms open connect / ping query timeout, seconds
    connect_timeout: float = None

    @override
    def __init__(
        self,
        # cluster members, main first
        url_list: List[str],
        # dbms open connect timeout, seconds
        connect_timeout: float = 3,
        # period for pool monitor tasks
        monitor_period: float = 1,
    ) -> None:
        """
        Configure the pool; monitor tasks start when the pool is awaited.

        :param url_list: cluster member DSNs, main (preferred) member first
        :param connect_timeout: timeout in seconds for opening a connection
            and for each health-check query
        :param monitor_period: delay in seconds between monitor iterations
        :raises ValueError: if ``url_list`` is empty
        """
        if not url_list:
            # with no members the monitor can never find a live one and
            # awaiting the pool would block forever
            raise ValueError("url_list must contain at least one url")
        self.conn_url = None
        self.connect_timeout = connect_timeout
        self.monitor_period = monitor_period
        self.conn_block = asyncio.Condition()
        self.monitor_task_list = []
        # every member starts as "not live" until its ping task succeeds
        self.status_dict = {url: False for url in url_list}
        # TODO config
        super().__init__(
            None,  # the dsn is installed later by monitor_main()
            min_size=1,
            max_size=10,
            max_queries=50000,
            max_inactive_connection_lifetime=300,
            setup=None,
            init=None,
            loop=None,
            connection_class=asyncpg.connection.Connection,
            timeout=connect_timeout,
        )

    @override
    async def _async__init__(self) -> "SmaprtPoolDMBS":
        """
        Start cluster monitoring, wait for a live member, then run the base
        pool initialization.
        """
        if self._initialized:
            return self
        self.spawn_task(self.monitor_main(), name="pool-main")
        await self.block_conn_url()
        return await super()._async__init__()

    @override
    async def _acquire(self, timeout: float) -> "PoolConnectionProxy":
        """
        Wait for a live cluster member, then acquire a pool connection.

        :param timeout: seconds, or None for no limit; the limit covers both
            the wait for a live member and the base pool acquisition
        """
        if timeout is None or self.has_conn_url():
            await self.block_conn_url()
            return await super()._acquire(timeout=timeout)
        # no live member yet: honour the caller's timeout while waiting,
        # then give the base acquisition only what remains of it
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        await asyncio.wait_for(self.block_conn_url(), timeout=timeout)
        return await super()._acquire(timeout=max(deadline - loop.time(), 0))

    @override
    async def close(self):
        """Terminate monitor tasks, then close the pool gracefully."""
        await self.monitor_close()
        await super().close()

    @override
    def terminate(self):
        """Cancel monitor tasks, then terminate the pool immediately."""
        self.monitor_terminate()
        super().terminate()

    def spawn_task(self, coro: Awaitable, name: str = None) -> asyncio.Task:
        """Start ``coro`` as a task and register it for cancellation."""
        pool_task = asyncio.create_task(coro, name=name)
        self.monitor_task_list.append(pool_task)
        return pool_task

    def has_conn_url(self) -> bool:
        """Return True when a live member has been activated."""
        return self.conn_url is not None

    async def block_conn_url(self):
        """Wait until a live member has been activated."""
        if not self.has_conn_url():
            async with self.conn_block:
                await self.conn_block.wait_for(self.has_conn_url)

    def active_conn_url(self) -> str:
        """Return the first live member url in priority order, or None."""
        return next(
            (url for url, is_live in self.status_dict.items() if is_live),
            None,
        )

    async def monitor_close(self) -> None:
        """Cancel monitor tasks and wait for them to finish."""
        self.monitor_terminate()
        await asyncio.gather(*self.monitor_task_list, return_exceptions=True)

    def monitor_terminate(self) -> None:
        """Request cancellation of every registered monitor task."""
        for monitor_task in self.monitor_task_list:
            monitor_task.cancel()

    async def monitor_main(self) -> None:
        """
        Track connection state for all cluster members, forever.

        Spawns one ping task per member, then every monitor period activates
        the first live member if it differs from the current one.
        """
        for url in self.status_dict:
            self.spawn_task(self.monitor_ping(url), name="pool-ping")
        while True:
            conn_url = self.active_conn_url()
            if conn_url != self.conn_url:
                logger.info(f"activate: {conn_url}")
                # set_connect_args() installs a fresh set of connect kwargs,
                # so the connect timeout must be passed again or it is lost
                self.set_connect_args(conn_url, timeout=self.connect_timeout)
                # existing connections point at the previous member
                await self.expire_connections()
                async with self.conn_block:
                    self.conn_url = conn_url
                    self.conn_block.notify_all()
            await self.monitor_sleep()

    async def monitor_ping(self, url: str) -> None:
        """
        Track live/dead state of a single cluster member, forever.

        Opens a dedicated connection and runs ``SELECT 0`` every monitor
        period.  Any error marks the member dead; the dedicated connection
        is then terminated and a reconnect is attempted after one period.
        The connection is also terminated when the task is cancelled.
        """
        while True:
            ping_conn = None
            try:
                ping_conn = await asyncpg.connect(
                    url,
                    timeout=self.connect_timeout,
                )
                if not self.status_dict[url]:
                    logger.info(f"success: {url}")
                self.status_dict[url] = True
                while True:
                    # bounded query time: a member that stops answering
                    # (e.g. a network partition) must still be marked dead
                    await ping_conn.execute(
                        "SELECT 0",
                        timeout=self.connect_timeout,
                    )
                    await self.monitor_sleep()
            except Exception as error:
                if self.status_dict[url]:
                    logger.info(f"failure: {url} :: {error}")
                self.status_dict[url] = False
            finally:
                # never leak the dedicated connection: runs on failure and
                # on cancellation alike
                if ping_conn is not None:
                    ping_conn.terminate()
            await self.monitor_sleep()

    async def monitor_sleep(self) -> None:
        """Pause for one monitor period."""
        await asyncio.sleep(self.monitor_period)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment