Last active
May 6, 2021 05:46
-
-
Save Flushot/f81e1f0db479d115d491a3e034f5b91d to your computer and use it in GitHub Desktop.
Wrapper for psycopg2 that handles connection pooling, transactions, cursors, and makes the API easier to deal with
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
PostgreSQL database utilities. | |
Wrapper for psycopg2 that handles connection pooling, transactions, cursors, and makes the API | |
easier to deal with. | |
""" | |
from collections import defaultdict
import contextlib
import functools
import logging
import re
import threading
import time
from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple, Type, Union  # Generator[yields, emits, returns]

import psycopg2
import psycopg2.extensions
import psycopg2.extras
import psycopg2.pool
from deprecated import deprecated
from tqdm.auto import tqdm
LOG = logging.getLogger(__name__) | |
FORMAT_SQL_QUERY_PAT = re.compile(r'(?:\s|\t|\n){2,}') | |
_default_pool: Optional[psycopg2.pool.ThreadedConnectionPool] = None | |
_default_pool_args: Optional[Tuple] = None | |
# TODO: rename to configure_default_pool
def configure_pool(*args, **kwargs) -> None:
    """
    Record the arguments used to build the default connection pool.

    Nothing is connected here: the arguments are only stored, and get_pool() consumes them
    the first time it has to create the pool. Applications relying on the default pool must
    therefore call this during initialization, before the first get_pool() call.

    Calling it again logs a warning and overwrites the stored arguments.

    For available pool args/kwargs, see:
    https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS

    :param args: positional arguments forwarded to the pool constructor.
    :param kwargs: keyword arguments forwarded to the pool constructor.
    """
    global _default_pool_args

    already_configured = _default_pool_args is not None
    if already_configured:
        LOG.warning('configure_pool() should only be called once')

    _default_pool_args = (args, kwargs)
# TODO: rename to get_default_pool
def get_pool() -> psycopg2.pool.ThreadedConnectionPool:
    """
    Return the default connection pool, creating it on first use.

    The default pool backs every function in this module that is not handed an explicit
    connection or pool. Prefer the connection() context manager over calling this directly,
    so that connections are returned to the pool automatically.

    :return: default connection pool.
    :raises RuntimeError: if configure_pool() has not been called yet.
    """
    global _default_pool

    # Fast path: the pool was already built by an earlier call.
    if _default_pool is not None:
        return _default_pool

    if _default_pool_args is None:
        raise RuntimeError('You must call configure_pool() before get_pool()')

    pool_args, pool_kwargs = _default_pool_args[0], _default_pool_args[1]
    _default_pool = psycopg2.pool.ThreadedConnectionPool(*pool_args, **pool_kwargs)
    return _default_pool
def get_conn(pool: Optional[psycopg2.pool.ThreadedConnectionPool] = None,
             retry_limit: Optional[int] = 5) -> \
        Tuple[psycopg2.extensions.connection, psycopg2.pool.ThreadedConnectionPool]:
    """
    Check out a database connection from a pool and test it to ensure it's working.

    If the connection is bad, it is handed back to the pool to be closed and the error is
    re-raised so that the retry() wrapper can check out another connection.

    You generally don't need to call this directly to get connections, and should
    almost always use the connection() context manager instead (so that your connection
    is automatically returned to the pool when you're finished).

    :param pool: optional connection pool to use (if unspecified, the default pool will be used).
    :param retry_limit: max number of attempts passed to retry() (None is passed through as-is).
    :return: tuple of:
             - connection that was fetched from the pool.
             - pool the connection came from.
    """
    # TODO: Handle PoolError('connection pool exhausted') with optional blocking
    if pool is None:
        pool = get_pool()

    @retry((psycopg2.Error,), retry_limit)
    def try_get_conn() -> psycopg2.extensions.connection:
        conn = pool.getconn()
        try:
            # Test connection to ensure it's alive. Creating the cursor is part of the
            # probe: it can itself fail on a connection that is already closed.
            cur = conn.cursor()
            try:
                cur.execute('select 1')
            finally:
                if not cur.closed:
                    cur.close()
        except psycopg2.Error:
            # Connection is bad: return it to the pool to be discarded, then propagate
            # the error. Swallowing it here would hand the closed connection to the
            # caller and prevent retry() from ever trying again.
            pool.putconn(conn, close=True)
            raise
        return conn

    return try_get_conn(), pool
@contextlib.contextmanager
def connection(pool: Optional[psycopg2.pool.ThreadedConnectionPool] = None) -> \
        Generator[psycopg2.extensions.connection, None, None]:
    """
    Context manager that checks out a connection via get_conn() and hands it back to the
    pool it came from when the block exits (normally or with an exception).
    Use this instead of get_conn().

    :param pool: optional connection pool to use (will fallback to default pool if unspecified).
    :yield: connection.
    """
    checked_out, source_pool = get_conn(pool)
    try:
        yield checked_out
    finally:
        # Nothing to hand back if get_conn() produced no connection.
        if checked_out is not None:
            source_pool.putconn(checked_out)
@contextlib.contextmanager | |
def _ensure_connection(conn: Optional[psycopg2.extensions.connection] = None) -> \ | |
Generator[psycopg2.extensions.connection, None, None]: | |
""" | |
Internal context manager that will ensure the block is supplied with a user-defined or default | |
connection. | |
:param conn: optional user-defined connection (if unspecified, default connection will be used). | |
:yield: connection. | |
""" | |
if conn: | |
yield conn | |
else: | |
with connection() as conn: | |
yield conn | |
class cursor:
    """
    Context manager for database cursor that has the following behavior:

    - Yields a DictCursor that returns dict-like rows (unless another cursor_factory is given).
    - Handles transaction behavior:
      - Commits transaction upon exit (or rolls back if there was an exception).
      - When nested (and when a user-defined connection is passed), supports nested
        transaction behavior using savepoints.
    - Closes the cursor when finished.
    - If the connection was checked out by this object, returns it to the pool when finished.
    """

    use_savepoints: bool = True

    class ConnOpts:
        """
        Per-connection options.
        """

        def __init__(self):
            self.nest_level = 0  # Nested transaction level
            self.thread_lock = threading.Lock()  # Thread-safe access to vars

    # Shared across all cursor instances, keyed by connection object.
    _conn_opts: Dict[psycopg2.extensions.connection, ConnOpts] = defaultdict(ConnOpts)

    def __init__(self,
                 conn: Optional[psycopg2.extensions.connection] = None,
                 pool: Optional[psycopg2.pool.ThreadedConnectionPool] = None,
                 **cursor_args):
        """
        :param conn: optional connection (if unspecified, default connection will be used).
        :param pool: optional connection pool (if conn is None and you need to pass a custom pool).
        :param cursor_args: keyword arguments forwarded to conn.cursor().
        :raises ValueError: if both conn and pool are given.
        """
        if conn is not None and pool is not None:
            raise ValueError('conn and pool are mutually exclusive')

        if conn is None:
            # "Managed" connection: checked out here, returned to the pool on exit.
            conn, pool = get_conn(pool)
            self._is_managed_connection = True
        else:
            self._is_managed_connection = False

        self._conn = conn
        self._pool = pool
        self._opts = self._conn_opts[conn]  # TODO: convert keys to weakref.ref()
        self._cursor_args = cursor_args
        if 'cursor_factory' not in self._cursor_args:
            self._cursor_args['cursor_factory'] = psycopg2.extras.DictCursor

    def _release(self) -> None:
        """
        Undo the bookkeeping done on entry: drop one nesting level and, for a managed
        connection, return it to the pool.
        """
        with self._opts.thread_lock:
            self._opts.nest_level -= 1
        if self._is_managed_connection and self._conn is not None:
            self._pool.putconn(self._conn)

    def __enter__(self) -> psycopg2.extensions.cursor:
        """
        :return: cursor.
        :raises RuntimeError: if the connection returned a closed cursor.
        """
        with self._opts.thread_lock:
            self._opts.nest_level += 1

        try:
            self._cur = self._conn.cursor(**self._cursor_args)
            if self._cur.closed:
                raise RuntimeError('Connection returned a closed cursor')
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so the nest level and any
            # managed connection must be released here or they would leak.
            self._release()
            raise

        return self._cur

    def __exit__(self,
                 exc_type: Optional[Type[BaseException]],
                 exc_val: Optional[BaseException],
                 exc_tb) -> bool:
        with self._opts.thread_lock:
            nest_level = self._opts.nest_level

        try:
            if exc_type is None:
                # Success
                if not self._is_managed_connection and nest_level > 1:
                    # Nested transaction (create new savepoint)
                    if self.use_savepoints:
                        with _ensure_cursor(self._cur) as cur:
                            cur.execute(f'savepoint level_{nest_level}')
                else:
                    # Topmost transaction (commit transaction)
                    self._conn.commit()
            else:
                # Failure
                # Level 1 or 2 has no previous savepoint to roll back to, so the whole
                # transaction is rolled back; the same applies when savepoints are disabled.
                rollback_transaction = not self.use_savepoints or nest_level <= 2
                target = 'transaction' if rollback_transaction else 'to previous savepoint'
                LOG.error(f'Rolling back {target} because of {exc_type} error: {exc_val!r}',
                          exc_info=exc_val)

                if rollback_transaction:
                    try:
                        self._conn.rollback()
                    except psycopg2.Error as ex:
                        LOG.error(f'Rollback failed because of error: {ex!r}', exc_info=True)
                elif not self._is_managed_connection:
                    # Nested transaction (roll back to previous savepoint)
                    try:
                        # Create new cursor (current may now be invalid)
                        with cursor(conn=self._conn) as cur:
                            cur.execute(f'rollback to savepoint level_{nest_level - 1}')
                    except psycopg2.Error as ex:
                        # Compound failure (rollback entire transaction to be safe)
                        LOG.error('Falling back to transaction rollback because savepoint '
                                  f'rollback failed: {ex!r}', exc_info=True)
                        try:
                            self._conn.rollback()
                        except psycopg2.Error as ex:
                            LOG.error(f'Fallback rollback failed because of error: {ex!r}',
                                      exc_info=True)
        finally:
            # Cleanup
            try:
                if not self._cur.closed:
                    self._cur.close()
            finally:
                self._release()

        return False  # Re-raise exceptions
@contextlib.contextmanager | |
def _ensure_cursor(cur: Optional[psycopg2.extensions.cursor] = None, **cursor_args) -> \ | |
Generator[psycopg2.extensions.cursor, None, None]: | |
""" | |
Internal context manager that will ensure the block is supplied with a user-defined or default | |
cursor. | |
:param cursor: optional user-defined cursor (if unspecified, default cursor/connection will be used). | |
:yield: cursor. | |
""" | |
if cur: | |
yield cur | |
else: | |
with cursor(**cursor_args) as cur: | |
yield cur | |
def fetchmany(statement: str,
              params: Optional[Tuple] = None,
              use_tqdm: Union[bool, dict] = False,
              cur: Optional[psycopg2.extensions.cursor] = None) -> \
        Generator[psycopg2.extras.DictRow, None, int]:
    """
    Executes SQL and returns multi-row results.

    This is a generator: the statement is only executed once iteration starts, and rows are
    pulled from the cursor in batches of cur.arraysize.

    Example:
        for row in fetchmany('select id, name from foo where bar = %s', (some_var,)):
            print(row['id'])

    :param statement: SQL statement.
    :param params: SQL statement parameters.
    :param use_tqdm: whether or not to use tqdm progress bar (can also be an options dict for tqdm).
    :param cur: optional user-defined cursor.
    :yield: row.
    :return: row count.
    """
    with _ensure_cursor(cur) as cur:
        query = cur.mogrify(statement, params)
        LOG.debug(f'SQL query: {_format_sql_query(query)}')
        cur.execute(query)

        # Progress bar is only shown when requested and there is at least one row.
        progress = None
        if (use_tqdm is True or isinstance(use_tqdm, dict)) and cur.rowcount > 0:
            tqdm_opts = use_tqdm if isinstance(use_tqdm, dict) else {}
            progress = tqdm(total=cur.rowcount, **tqdm_opts)

        try:
            while True:
                rows = cur.fetchmany(cur.arraysize)
                if not rows:
                    break

                for row in rows:
                    yield row
                    if progress is not None:
                        progress.update()
        finally:
            # Runs on normal exhaustion, on errors, and when the consumer abandons the
            # generator early, so the progress bar is always closed.
            if progress is not None:
                progress.close()

        return cur.rowcount
def fetchone(statement: str,
             params: Optional[Tuple] = None,
             cur: Optional[psycopg2.extensions.cursor] = None) -> psycopg2.extras.DictRow:
    """
    Execute SQL and return the first row of the result.

    The statement is rendered with its parameters, logged at debug level, executed, and the
    result of the cursor's fetchone() is returned.

    :param statement: SQL statement.
    :param params: SQL statement parameters.
    :param cur: optional user-defined cursor.
    :return: row.
    """
    with _ensure_cursor(cur) as active_cur:
        rendered = active_cur.mogrify(statement, params)
        LOG.debug(f'SQL query: {_format_sql_query(rendered)}')
        active_cur.execute(rendered)
        first_row = active_cur.fetchone()
    return first_row
def execute(statement: str,
            params: Optional[Union[Tuple, Dict]] = None,
            cur: Optional[psycopg2.extensions.cursor] = None) -> None:
    """
    Run a database operation (query or command) without returning anything.

    Parameters may be given as a sequence or a mapping and are bound to the placeholders in
    the statement: positional (%s) or named (%(name)s). The statement is rendered with its
    parameters, logged at debug level, then executed.

    The function returns None. If a query was executed on a user-defined cursor, its
    results can be retrieved from that cursor using fetch*() methods.

    :param statement: SQL statement.
    :param params: SQL statement parameters.
    :param cur: optional user-defined cursor.
    """
    with _ensure_cursor(cur) as active_cur:
        rendered = active_cur.mogrify(statement, params)
        LOG.debug(f'SQL query: {_format_sql_query(rendered)}')
        active_cur.execute(rendered)
def execute_values(statement: str,
                   values: List[Tuple],
                   template: Optional[str] = None,
                   page_size: Optional[int] = None,
                   fetch: bool = False,
                   cur: Optional[psycopg2.extensions.cursor] = None) -> \
        Generator[psycopg2.extras.DictRow, None, int]:
    """
    Execute a statement using VALUES with a sequence of parameters.

    NOTE: this function is a generator. Nothing is sent to the database until the returned
    generator is iterated, even when fetch is False (e.g. wrap the call in list(...)).

    :param statement: SQL statement to execute. It must contain a single %s placeholder, which will
                      be replaced by a VALUES list.
                      Example: "INSERT INTO mytable (id, f1, f2) VALUES %s".
    :param values: sequence of sequences or dictionaries with the arguments to send to the query.
                   The type and content must be consistent with template.
    :param template: the snippet to merge to every item in argslist to compose the query.
                     - If the argslist items are sequences it should contain positional placeholders
                       (e.g. "(%s, %s, %s)", or "(%s, %s, 42)" if there are constant values).
                     - If the argslist items are mappings it should contain named placeholders
                       (e.g. "(%(id)s, %(f1)s, 42)").
                     If not specified, assume the arguments are sequence and use a simple positional template
                     (i.e. (%s, %s, ...)), with the number of placeholders sniffed by the first element in argslist.
    :param page_size: maximum number of argslist items to include in every statement.
                      If there are more items the function will execute more than one statement.
                      Defaults to the length of the values parameter (at least 1).
    :param fetch: if True return the query results into a list (like in a fetchall()).
                  Useful for queries with RETURNING clause.
    :param cur: optional user-defined cursor.
    :yield: row (if fetch parameter is True).
    :return: row count, as reported by the cursor after execution.
    """
    # TODO: Add tqdm_opts parameter like with fetchmany()
    if page_size is None:
        # A page size of 0 (empty values) is not a usable page size for psycopg2's
        # pagination, so fall back to 1; an empty values list then simply executes nothing.
        page_size = len(values) or 1

    with _ensure_cursor(cur) as cur:
        LOG.debug(f'SQL query: {_format_sql_query(statement.encode("utf-8"))} -> {values!r}')
        result = psycopg2.extras.execute_values(cur,
                                                statement,
                                                values,
                                                template=template,
                                                page_size=page_size,
                                                fetch=fetch)
        if fetch:
            for row in result:
                yield row

        return cur.rowcount
def upsert(row: Dict[str, Any],
           table_name: str,
           primary_key: List[str],
           include_keys: Optional[List[str]] = None,
           exclude_keys: Optional[List[str]] = None,
           cur: Optional[psycopg2.extensions.cursor] = None) -> Any:
    """
    Insert a row, or update the existing row if the primary key conflicts, and return the
    stored row (INSERT ... ON CONFLICT ... DO UPDATE ... RETURNING *).

    NOTE(security): table_name, primary_key and the column names are interpolated into the
    SQL text unescaped. They must come from trusted code, never from user input. Row values
    are passed as bound parameters.

    :param row: mapping of column name to value.
    :param table_name: table to write to.
    :param primary_key: columns making up the conflict target.
    :param include_keys: columns to write (defaults to all keys of row). Columns missing
                         from row are written as None.
    :param exclude_keys: columns to leave out of the statement.
    :param cur: optional user-defined cursor.
    :return: row returned by the statement.
    :raises ValueError: if no columns are left to write after applying include/exclude.
    """
    if include_keys is None:
        include_keys = row.keys()
    if exclude_keys is None:
        exclude_keys = []

    item_keys = [k for k in include_keys if k not in exclude_keys]
    if not item_keys:
        raise ValueError('upsert() requires at least one column to write')

    update_keys = [k for k in item_keys if k not in primary_key]
    if not update_keys:
        # Only key columns are being written, which would leave the SET clause empty
        # (invalid SQL). Re-assign one key column to itself so the statement stays valid
        # and RETURNING still yields the existing row on conflict.
        update_keys = list(primary_key[:1])

    columns_sql = ', '.join(item_keys)
    placeholders_sql = ', '.join(['%s'] * len(item_keys))
    conflict_sql = ', '.join(primary_key)
    assignments_sql = ', '.join([f'{k} = excluded.{k}' for k in update_keys])

    with _ensure_cursor(cur) as cur:
        return fetchone(
            f'''
            insert into {table_name} ({columns_sql})
            values ({placeholders_sql})
            on conflict ({conflict_sql}) do update set
            {assignments_sql}
            returning *
            ''',
            tuple([row.get(key) for key in item_keys]),
            cur=cur)
def _format_sql_query(query: bytes) -> str:
    """
    Turn a raw SQL query into a compact one-line-friendly string for log output.

    The bytes are decoded as UTF-8, line endings are normalized, every run matched by
    FORMAT_SQL_QUERY_PAT is replaced by a single space, and the result is stripped.

    :param query: query to format.
    :return: formatted query.
    """
    decoded = query.decode('utf-8')
    unified = normalize_line_endings(decoded)
    collapsed = FORMAT_SQL_QUERY_PAT.sub(' ', unified)
    return collapsed.strip()
def normalize_line_endings(s: str) -> str:
    """
    Convert carriage-return based line endings into a single line feed character.

    Carriage return + line feed pairs are collapsed first, then every remaining lone
    carriage return is turned into a line feed.

    :param s: string with possibly abnormal line endings.
    :return: normalized string.
    """
    without_crlf = s.replace('\r\n', '\n')
    # Splitting on the leftover carriage returns and re-joining with line feeds is the
    # same as replacing each of them.
    return '\n'.join(without_crlf.split('\r'))
def retry(exc_types: Sequence[Type],
          max_attempts: Optional[int] = None,
          delay: int = 0,
          error_fn: Optional[Callable[[BaseException], None]] = None) -> Callable:
    """
    Decorator that automatically re-calls a function if it throws a set of expected exception types.

    The wrapped function is always called at least once. Exceptions that are not listed in
    exc_types propagate immediately. When the attempt limit is reached, the last exception
    is re-raised to the caller.

    :param exc_types: exception classes to retry on.
    :param max_attempts: max number of calls to make before re-throwing.
                         if None, there is no limit.
    :param delay: optional time delay between retry attempts (in seconds).
    :param error_fn: optional function to call (with exception) when an error occurs.
    :return: decorator to apply to the function that should be retried.
    """
    retryable_types = tuple(exc_types)

    def retry_decorator(f: Callable) -> Callable:
        @functools.wraps(f)
        def retryable_func(*args, **kwargs):
            attempt = 0
            while True:
                attempt += 1
                try:
                    return f(*args, **kwargs)
                except retryable_types as ex:
                    if error_fn is not None:
                        error_fn(ex)

                    # Out of attempts: surface the failure instead of silently
                    # returning None.
                    if max_attempts is not None and attempt >= max_attempts:
                        raise

                    LOG.warning(f'Retrying because of {ex.__class__.__name__} error: {ex!r}')
                    if delay > 0:
                        time.sleep(delay)

        return retryable_func

    return retry_decorator
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment