Skip to content

Instantly share code, notes, and snippets.

@Andrei-Pozolotin
Last active January 25, 2020 16:03
Show Gist options
  • Save Andrei-Pozolotin/f5be7dec56840428c6245a1cef4a25eb to your computer and use it in GitHub Desktop.
Save Andrei-Pozolotin/f5be7dec56840428c6245a1cef4a25eb to your computer and use it in GitHub Desktop.
"""
robust cluster connection pool
"""
import asyncio
import logging
from contextvars import ContextVar
from typing import Awaitable
from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional

import asyncpg
import funcy
logger = logging.getLogger(__name__)


def override(func):
    """Marker decorator for methods that override a base-class method.

    It has no runtime effect: the decorated function is returned unchanged.
    """
    return func


# global cluster pool; holds None until a pool is stored in the context
POOL_DBMS: ContextVar["SmaprtPoolDMBS"] = ContextVar('POOL_DBMS', default=None)
class SmaprtPoolDMBS(asyncpg.pool.Pool):
    """
    dbms connection pool with automatic failover

    One background "ping" task per cluster member keeps a dedicated
    connection open and records in ``status_dict`` whether that member is
    reachable. A "main" monitor task periodically picks the first live
    member (in ``url_list`` order), points the pool at it and expires the
    connections that were opened against the previous member. Acquiring a
    connection blocks while no member is live.
    """

    # url of the member the pool currently connects to; None while none is live
    conn_url: Optional[str] = None
    # initialization blocker: waiters are notified when conn_url changes
    conn_block: Optional[asyncio.Condition] = None
    # connection state: url -> has-live
    status_dict: Optional[Dict[str, bool]] = None
    # track internal jobs for cancel
    monitor_task_list: Optional[List[asyncio.Task]] = None
    # dbms open connect timeout, seconds
    connect_timeout: Optional[float] = None
    # delay between monitor iterations, seconds (passed to asyncio.sleep)
    monitor_period: Optional[float] = None

    @override
    def __init__(self,
            # cluster members, main first
            url_list: List[str],
            # dbms open connect timeout, seconds
            connect_timeout: float = 3,
            # period for pool monitor tasks
            monitor_period: float = 1,
        ) -> None:
        """
        Prepare the pool state; no connection is opened here.

        :param url_list: cluster member urls; earlier entries are preferred
            when several members are live.
        :param connect_timeout: timeout, in seconds, for opening connections
            (also used to bound each liveness probe).
        :param monitor_period: delay, in seconds, between monitor iterations.
        """
        self.conn_url = None
        self.connect_timeout = connect_timeout
        self.monitor_period = monitor_period
        self.conn_block = asyncio.Condition()
        self.monitor_task_list = []
        # dict keeps insertion order, so url_list priority is preserved
        self.status_dict = dict.fromkeys(url_list, False)
        # TODO config
        super().__init__(
            None,  # no dsn yet: the monitor supplies it via set_connect_args()
            min_size=1,
            max_size=10,
            max_queries=50000,
            max_inactive_connection_lifetime=300,
            setup=None,
            init=None,
            loop=None,
            connection_class=asyncpg.connection.Connection,
            timeout=connect_timeout,
        )

    @override
    async def _async__init__(self) -> "SmaprtPoolDMBS":
        """
        setup cluster connection monitor

        Starts the main monitor task, waits until some member is live, then
        runs the base-class initialization. Waits indefinitely if no member
        ever becomes reachable.
        """
        if self._initialized:
            return self
        self.spawn_task(self.monitor_main(), name="pool-main")
        await self.block_conn_url()
        return await super()._async__init__()

    @override
    async def _acquire(self, timeout: float) -> "PoolConnectionProxy":
        """
        setup cluster connection proxy

        Waits for a live member before delegating to the base class.
        NOTE(review): the wait for a live member is not bounded by
        ``timeout``; only the base-class acquire receives it.
        """
        await self.block_conn_url()
        return await super()._acquire(timeout=timeout)

    @override
    async def close(self):
        """Cancel the monitor tasks, wait for them, then close the pool."""
        await self.monitor_close()
        await super().close()

    @override
    def terminate(self):
        """Cancel the monitor tasks (without waiting), then terminate the pool."""
        self.monitor_terminate()
        super().terminate()

    def spawn_task(self, coro: Awaitable, name: Optional[str] = None) -> asyncio.Task:
        """
        register pool monitor task

        Schedules ``coro`` as a task and remembers it so that it can be
        cancelled by monitor_terminate().
        """
        pool_task = asyncio.create_task(coro, name=name)
        self.monitor_task_list.append(pool_task)
        return pool_task

    def has_conn_url(self) -> bool:
        """check if connection is available"""
        return self.conn_url is not None

    async def block_conn_url(self) -> None:
        """block till connection is available"""
        if self.has_conn_url():
            return  # fast path: avoid taking the condition lock
        async with self.conn_block:
            await self.conn_block.wait_for(self.has_conn_url)

    def active_conn_url(self) -> Optional[str]:
        """
        extract current connection url

        :return: the first url, in ``url_list`` order, whose status is live,
            or None when no member is live.
        """
        return next(
            (url for url, has_live in self.status_dict.items() if has_live),
            None,
        )

    async def monitor_close(self) -> None:
        """Cancel all monitor tasks and wait until they have finished."""
        self.monitor_terminate()
        # return_exceptions=True: the expected CancelledError results are collected, not raised
        await asyncio.gather(*self.monitor_task_list, return_exceptions=True)

    def monitor_terminate(self) -> None:
        """Request cancellation of all monitor tasks without waiting for them."""
        for monitor_task in self.monitor_task_list:
            monitor_task.cancel()

    async def monitor_main(self) -> None:
        """
        track connection state for all cluster members

        Spawns one ping task per member, then loops forever: whenever the
        preferred live member changes, the pool is re-pointed at it and
        existing connections are expired. Runs until cancelled.
        """
        for url in self.status_dict:
            self.spawn_task(self.monitor_ping(url), name="pool-ping")
        while True:
            conn_url = self.active_conn_url()
            if conn_url != self.conn_url:
                logger.info(f"activate: {conn_url}")
                # set_connect_args() replaces the stored connect kwargs, so the
                # timeout configured in __init__ must be supplied again
                self.set_connect_args(conn_url, timeout=self.connect_timeout)
                await self.expire_connections()
                async with self.conn_block:
                    self.conn_url = conn_url
                    self.conn_block.notify_all()
            await self.monitor_sleep()

    async def monitor_ping(self, url: str) -> None:
        """
        track single member connection live/dead state

        Keeps one dedicated connection to ``url`` and probes it every
        ``monitor_period``; any failure marks the member dead and a new
        connection is attempted after a delay. Runs until cancelled.
        """
        while True:
            ping_conn = None
            try:
                ping_conn = await asyncpg.connect(
                    url,
                    timeout=self.connect_timeout,
                )
                if not self.status_dict[url]:
                    logger.info(f"success: {url}")
                    self.status_dict[url] = True
                while True:
                    # bound the probe so that a hung peer is detected as dead
                    await ping_conn.execute(
                        "SELECT 0",
                        timeout=self.connect_timeout,
                    )
                    await self.monitor_sleep()
            except Exception as error:
                # any failure means "dead"; log only on a live -> dead transition
                if self.status_dict[url]:
                    logger.info(f"failure: {url} :: {error}")
                    self.status_dict[url] = False
            finally:
                # also runs on task cancellation (CancelledError is not an
                # Exception on python 3.8+); release the probe connection
                if ping_conn is not None:
                    ping_conn.terminate()
            await self.monitor_sleep()

    async def monitor_sleep(self) -> None:
        """introduce monitor delay"""
        await asyncio.sleep(self.monitor_period)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment