|
from sqlalchemy.pool import QueuePool |
|
from sqlalchemy import create_engine |
|
from contextlib import closing, contextmanager, ExitStack |
|
from sqlalchemy.sql import text |
|
import asyncio |
|
import logging |
|
import time |
|
|
|
|
|
class AsyncMysqlDatabase:
    """Run blocking SQLAlchemy queries from asyncio code via an executor.

    Every blocking call runs in ``self.executor`` (``None`` selects the
    event loop's default executor), and every query is wrapped by each hook
    in ``self.hooks``.

    Args:
        dsn: Database URL passed to ``create_engine``.
        executor: Executor used for all blocking calls; ``None`` means the
            running loop's default executor.
        hooks: Callables invoked as ``hook(query=..., params=...)`` that
            return a context manager entered around each query. When falsy
            (``None`` or an empty list) only :meth:`logging_hook` is used.
        pool_size: Size of the ``QueuePool``; ``max_overflow=0`` means no
            connections are opened beyond this number.
    """

    def __init__(self, dsn, executor=None, hooks=None, pool_size=10):
        self.engine = create_engine(
            dsn, poolclass=QueuePool, pool_size=pool_size, max_overflow=0
        )
        self.executor = executor
        # NOTE: ``or`` means an explicitly empty list also falls back to the
        # default logging hook.
        self.hooks = hooks or [self.logging_hook]

    @property
    def loop(self):
        """Return ``asyncio.get_event_loop()`` (kept for backward compatibility)."""
        return asyncio.get_event_loop()

    @staticmethod
    def _resolve_loop(loop):
        """Return *loop* if given, else the loop running the current coroutine."""
        # Only called from coroutines, so a running loop always exists;
        # get_running_loop() avoids get_event_loop()'s deprecated fallback.
        return loop if loop is not None else asyncio.get_running_loop()

    def _check_connection(self):
        """Check out one connection and immediately release it back to the pool."""
        with closing(self.engine.connect()):
            pass

    async def connect(self, loop=None):
        """Verify the database is reachable by opening and closing one connection.

        Args:
            loop: Event loop to use; defaults to the running loop.

        Raises:
            Whatever ``engine.connect()`` raises in the executor.
        """
        # Awaited so connection failures reach the caller instead of being
        # dropped with an un-awaited future.
        await self._resolve_loop(loop).run_in_executor(
            self.executor, self._check_connection
        )

    async def dispose(self, loop=None):
        """Dispose of the engine's connection pool in the executor and wait for it.

        Args:
            loop: Event loop to use; defaults to the running loop.
        """
        await self._resolve_loop(loop).run_in_executor(
            self.executor, self.engine.dispose
        )

    def sync_fetch_all(self, query, params=None):
        """Execute *query* and return every row as a ``dict`` (blocking).

        Args:
            query: SQL string, wrapped in ``text()`` before execution.
            params: Mapping of bind-parameter values; ``None`` means none.

        Returns:
            A list with one ``dict`` per result row.
        """
        params = {} if params is None else params
        with self._hook_scope(query=query, params=params):
            with closing(self.engine.execute(text(query), **params)) as cursor:
                return [dict(row) for row in cursor]

    def sync_fetch_one(self, query, params=None):
        """Execute *query* and return its first row as a ``dict`` (blocking).

        Args:
            query: SQL string, wrapped in ``text()`` before execution.
            params: Mapping of bind-parameter values; ``None`` means none.

        Returns:
            The first row as a ``dict``, or ``None`` if there are no rows.
        """
        params = {} if params is None else params
        with self._hook_scope(query=query, params=params):
            with closing(self.engine.execute(text(query), **params)) as cursor:
                row = cursor.fetchone()
                return None if row is None else dict(row)

    async def fetch_all(self, query, params=None, loop=None):
        """Run :meth:`sync_fetch_all` in ``self.executor`` and return its rows."""
        return await self._resolve_loop(loop).run_in_executor(
            self.executor, self.sync_fetch_all, query, params
        )

    async def fetch_one(self, query, params=None, default=None, loop=None):
        """Run :meth:`sync_fetch_one` in ``self.executor``.

        Args:
            query: SQL string.
            params: Mapping of bind-parameter values; ``None`` means none.
            default: Optional mapping. When given, the result is reduced to
                exactly these keys, and keys that are missing or ``None`` in
                the row (or all keys, when there is no row) take the value
                from *default*.
            loop: Event loop to use; defaults to the running loop.

        Returns:
            The row as a ``dict`` (or ``None``) when *default* is ``None``;
            otherwise the dict described under *default*.
        """
        row = await self._resolve_loop(loop).run_in_executor(
            self.executor, self.sync_fetch_one, query, params
        )
        if default is None:
            return row
        return self._apply_default(row, default)

    @staticmethod
    def _apply_default(row, default):
        """Project *row* onto *default*'s keys, filling only missing/``None`` values."""
        row = row or {}
        # ``is None`` rather than ``or`` so falsy column values such as 0,
        # "" or False are kept instead of being replaced by the default.
        return {
            key: fallback if row.get(key) is None else row[key]
            for key, fallback in default.items()
        }

    @contextmanager
    def logging_hook(self, query, params, **kwargs_dump):
        """Log the outcome and duration (in ms) of the wrapped query.

        Failures are logged via ``logger.exception`` and re-raised; successes
        are logged at DEBUG. Extra keyword arguments are accepted and ignored.
        """
        # A named logger avoids logging.debug()/exception(), which call
        # basicConfig() on an unconfigured root logger as a side effect.
        logger = logging.getLogger(__name__)
        # perf_counter is monotonic, so durations survive wall-clock changes.
        start = time.perf_counter()
        try:
            yield
        except Exception:
            logger.exception(
                {
                    "message": "Error running mysql query",
                    "query": query,
                    "params": params,
                    "took_ms": (time.perf_counter() - start) * 1000.0,
                }
            )
            raise
        else:
            logger.debug(
                {
                    "message": "Success running mysql query",
                    "query": query,
                    "params": params,
                    "took_ms": (time.perf_counter() - start) * 1000.0,
                }
            )

    @contextmanager
    def _hook_scope(self, **kwargs):
        """Enter every hook's context manager around the wrapped block.

        Each hook is called with ``**kwargs``; hooks are entered in
        ``self.hooks`` order and exited in reverse by the ``ExitStack``.
        """
        with ExitStack() as stack:
            for hook in self.hooks:
                stack.enter_context(hook(**kwargs))
            yield