Last active
December 29, 2023 11:24
-
-
Save surenkov/291fbd0bdd2638bc19ea6b6d7c84e3e3 to your computer and use it in GitHub Desktop.
PostgreSQL table and advisory lock context managers for Django Python apps
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import contextlib
import enum
import hashlib
import threading
import typing as t

from django.apps import apps
from django.db import models, transaction, connection, OperationalError, InternalError
from psycopg2 import sql
# Public API of this module. The lock classes are exported through their
# snake_case aliases ``table_lock`` / ``advisory_lock``.
__all__ = ("table_lock", "advisory_lock", "LockMode", "LockNowaitError", "LockTimeoutError")
class LockMode(str, enum.Enum): | |
AccessShare = "ACCESS SHARE" | |
RowShare = "ROW SHARE" | |
RowExclusive = "ROW EXCLUSIVE" | |
ShareUpdateExclusive = "SHARE UPDATE EXCLUSIVE" | |
Share = "SHARE" | |
ShareRowExclusive = "SHARE ROW EXCLUSIVE" | |
Exclusive = "EXCLUSIVE" | |
AccessExclusive = "ACCESS EXCLUSIVE" | |
class LockTimeoutError(OperationalError): | |
""" Issued when lock timeout has been reached """ | |
class LockNowaitError(LockTimeoutError): | |
""" Issued while trying to lock already locked table with nowait set to True""" | |
class TableLock(contextlib.ContextDecorator):
    """Take a PostgreSQL ``LOCK TABLE`` on one or more tables when entered.

    Usable as a context manager or as a decorator. PostgreSQL holds table
    locks until the current transaction ends, and it reports an error for
    ``LOCK TABLE`` outside a transaction block. Enter it inside
    ``transaction.atomic()``. Leaving the context releases nothing.

    :param tables: table names; each is quoted as an SQL identifier.
    :param using: database alias; ``None`` selects the default connection.
    :param mode: a :class:`LockMode`, or its SQL spelling such as ``"SHARE"``.
        ``None`` omits the ``IN ... MODE`` clause, so PostgreSQL's default
        (``ACCESS EXCLUSIVE``) applies.
    :param nowait: fail at once with :class:`LockNowaitError` if the lock
        is not immediately available.
    :param timeout: seconds to wait before failing with
        :class:`LockTimeoutError`; ``0`` behaves like ``nowait=True``.
    :raises ValueError: if no table is given, ``nowait`` and ``timeout`` are
        combined, ``timeout`` is negative, or ``mode`` is unknown.
    """

    def __init__(
        self,
        *tables: str,
        using: t.Optional[str] = None,
        mode: t.Optional[t.Union[LockMode, str]] = None,
        nowait: bool = False,
        timeout: t.Optional[int] = None,
    ):
        if not tables:
            # Without a table name the statement would be "LOCK TABLE" and fail only at lock time.
            raise ValueError("At least one table is required")
        if nowait and timeout:
            raise ValueError("Can't set both nowait and timeout options")
        if timeout == 0:
            # A zero timeout means "don't wait at all", which is exactly NOWAIT.
            timeout, nowait = None, True  # type: ignore
        elif timeout is not None and timeout <= 0:
            raise ValueError("Lock timeout should be a positive integer")
        self.tables = tables
        self.using = using
        # Accept LockMode members and their plain SQL spelling. An unknown mode
        # fails here with ValueError instead of later with AttributeError.
        self.mode = None if mode is None else LockMode(mode)
        self.nowait = nowait
        self.timeout = timeout

    @classmethod
    def for_model(cls, *model_refs: t.Union[str, t.Type[models.Model]], **kwargs):
        """Lock the database tables behind the given Django models.

        Each item is a model class or an ``"app_label.ModelName"`` string,
        which is resolved with ``apps.get_model``. ``kwargs`` are passed to
        the constructor unchanged.
        """
        model_classes = (apps.get_model(ref) if isinstance(ref, str) else ref for ref in model_refs)
        return cls(*(model._meta.db_table for model in model_classes), **kwargs)

    def __enter__(self):
        query, params = self._prepare_lock_query()
        self._exec_lock_query(query, params)
        return self

    def __exit__(self, *exc_info):
        # Table locks can only be released by ending the transaction; never suppress errors.
        return False

    def _exec_lock_query(self, lock_query: sql.Composable, params: tuple[t.Any, ...]):
        """Run the lock statement and translate lock failures into this module's errors."""
        conn = transaction.get_connection(self.using)
        with conn.cursor() as c:
            previous_timeout = None
            if self.timeout is not None:
                # Remember the current setting. The SET LOCAL inside lock_query
                # should govern only this LOCK TABLE, not every later statement
                # in the transaction.
                c.execute("SELECT current_setting('lock_timeout')")
                previous_timeout = c.fetchone()[0]
            try:
                c.execute(lock_query, params)
            except OperationalError as e:
                if self.nowait:
                    raise LockNowaitError from e
                if self.timeout is not None:
                    raise LockTimeoutError(self.timeout) from e
                raise
            if previous_timeout is not None:
                # set_config(..., true) has SET LOCAL scope. On the error path the
                # transaction/savepoint is rolled back anyway, undoing the SET LOCAL.
                c.execute("SELECT set_config('lock_timeout', %s, true)", [previous_timeout])

    def _prepare_lock_query(self):
        """Return ``(query, params)`` for the LOCK statement.

        When a timeout is configured, the statement is prefixed with
        ``SET LOCAL lock_timeout``.
        """
        empty, params, mode = sql.SQL(""), (), self.mode
        query = sql.SQL("LOCK TABLE {tables} {lock_mode} {nowait}").format(
            tables=sql.SQL(", ").join(map(sql.Identifier, self.tables)),
            # mode is always a LockMode member here, so its value is a fixed SQL keyword.
            lock_mode=(empty if mode is None else sql.SQL(f"IN {mode.value} MODE")),
            nowait=(sql.SQL("NOWAIT") if self.nowait else empty),
        )
        if self.timeout is not None:
            query = sql.SQL("SET LOCAL lock_timeout = %s; {}").format(query)
            params = (f"{self.timeout}s",)
        return query, params
class AdvisoryLock(contextlib.ContextDecorator):
    """Hold a PostgreSQL advisory lock for the duration of a ``with`` block.

    Inside an atomic block the transaction-scoped ``pg_advisory_xact_lock``
    family is used, and PostgreSQL releases the lock when the transaction
    ends. Otherwise the session-scoped ``pg_advisory_lock`` family is used,
    and the lock is released explicitly on exit.

    One instance may be entered repeatedly, nested, or from several threads
    (for example when used as a decorator). The connection and scope of each
    acquisition are tracked per thread, so every exit releases exactly what
    its matching enter acquired.

    :param lock_id: one key (int, str or bytes) mapped to a bigint, or two
        keys mapped to a pair of 32-bit ints.
    :param using: database alias; ``None`` selects the default connection.
    :param shared: use the ``*_shared`` variant of the lock functions.
    :param nowait: use the ``pg_try_*`` variant and raise
        :class:`LockNowaitError` if the lock is busy.
    :param timeout: seconds to wait before raising :class:`LockTimeoutError`.
        Allowed only inside a transaction block. ``0`` behaves like
        ``nowait=True``.
    :raises ValueError: on conflicting or negative options, or unsupported keys.
    """

    def __init__(
        self,
        *lock_id: t.Union[int, str, bytes],
        using: t.Optional[str] = None,
        shared: bool = False,
        nowait: bool = False,
        timeout: t.Optional[int] = None,
    ):
        if nowait and timeout:
            raise ValueError("Can't set both nowait and timeout options")
        if timeout == 0:
            # A zero timeout means "don't wait at all", i.e. the try_ variant.
            timeout, nowait = None, True  # type: ignore
        elif timeout is not None and timeout < 0:
            raise ValueError("Timeout should be a positive integer")
        self.lock_id = self._validate_lock_id(lock_id)
        self.using = using
        self.shared = shared
        self.nowait = nowait
        self.timeout = timeout
        # Per-thread stack of (connection, in_transaction) pairs, one per active acquisition.
        self._held = threading.local()

    def __enter__(self):
        conn = transaction.get_connection(self.using)
        in_transaction = conn.in_atomic_block
        if self.timeout is not None and not in_transaction:
            raise InternalError("Timeout can only be used in transaction block")
        lock_func = self._prepare_lock_function(in_transaction)
        with conn.cursor() as c:
            lock_func.exec(c)
        # Record the acquisition only once the lock is held. A failing enter
        # never reaches __exit__, so nothing needs undoing.
        self._held_stack().append((conn, in_transaction))
        return self

    def __exit__(self, *exc_info):
        conn, in_transaction = self._held_stack().pop()
        unlock_func = self._prepare_unlock_function(in_transaction)
        if unlock_func is not None:
            with conn.cursor() as c:
                unlock_func.exec(c)
        return False

    def _held_stack(self) -> list:
        """Return this thread's stack of active acquisitions, creating it on first use."""
        stack = getattr(self._held, "stack", None)
        if stack is None:
            stack = self._held.stack = []
        return stack

    def _prepare_lock_function(self, in_transaction: bool) -> "_AdvisoryLockFunction":
        """Build the ``pg_[try_]advisory[_xact]_lock[_shared](...)`` call for this lock."""
        expr = sql.SQL("pg_")
        ctor, args = _AdvisoryLockFunction, {}
        if self.nowait:
            ctor = _NowaitAdvisoryLockFuncion
            expr += sql.SQL("try_")
        elif self.timeout is not None:
            ctor = _TimeoutAdvisoryLockFunction
            args["timeout"] = self.timeout
        expr += sql.SQL("advisory_xact_lock" if in_transaction else "advisory_lock")
        if self.shared:
            expr += sql.SQL("_shared")
        return ctor(expr, self.lock_id, **args)

    def _prepare_unlock_function(self, in_transaction: bool) -> t.Optional["_AdvisoryLockFunction"]:
        """Build the matching ``pg_advisory_unlock[_shared]`` call.

        Returns ``None`` for transaction-scoped locks, which PostgreSQL
        releases at transaction end.
        """
        if in_transaction:
            return None
        expr = sql.SQL("pg_advisory_unlock")
        if self.shared:
            expr += sql.SQL("_shared")
        return _AdvisoryLockFunction(expr, self.lock_id)

    def _validate_lock_id(self, lock_id) -> t.Union[tuple[int], tuple[int, int]]:
        """Map application lock keys onto the integers PostgreSQL expects.

        The ``pg_advisory_lock`` family takes either one signed bigint or two
        signed 32-bit ints, so each key part is reduced to ``bit_len`` bits:

        * Ints that fit in ``bit_len`` bits are passed through. This covers
          signed values, and unsigned values up to ``2**bit_len - 1``, which
          are reinterpreted as two's complement.
        * Larger ints are hashed (md5 of their signed byte encoding) and
          truncated, so differing high bits still give different keys.
        * str/bytes are md5-hashed and folded with the historical algorithm.
          It is kept unchanged so existing string keys map to the same
          PostgreSQL key.

        :raises ValueError: for an unsupported key type or number of parts.
        """

        def to_signed(value: int, bit_len: int) -> int:
            # Interpret the low ``bit_len`` bits of ``value`` as a two's-complement number.
            sign_bit = 1 << (bit_len - 1)
            return (value & (sign_bit - 1)) - (value & sign_bit)

        def fold_digest(value: int, bit_len: int) -> int:
            # Historical reduction for hashed strings: halve the width until it
            # fits, then XOR the low half with the next-higher half.
            width = value.bit_length()
            while width > bit_len:
                width >>= 1
            low_mask = (1 << width) - 1
            high_mask = low_mask << width
            return (value & low_mask) ^ ((value & high_mask) >> width)

        def prep_id_part(id_val: t.Union[int, str, bytes], bit_len: t.Literal[32, 64]) -> int:
            if isinstance(id_val, (str, bytes, bytearray)):
                raw = id_val.encode() if isinstance(id_val, str) else id_val
                digest = hashlib.md5(raw, usedforsecurity=False).digest()
                id_hash = int.from_bytes(digest, "little", signed=True)
                return to_signed(fold_digest(id_hash, bit_len), bit_len)
            if isinstance(id_val, int):
                if -(1 << (bit_len - 1)) <= id_val < (1 << bit_len):
                    return to_signed(id_val, bit_len)
                # Out of range: hash instead of truncating. Dropping the high
                # bits would map e.g. 1 << 64 and 0 to the same key.
                raw = id_val.to_bytes((id_val.bit_length() + 8) // 8, "little", signed=True)
                digest = hashlib.md5(raw, usedforsecurity=False).digest()
                return int.from_bytes(digest[: bit_len // 8], "little", signed=True)
            raise ValueError(f"Can't prepare lock id from {id_val}")

        if len(lock_id) == 1:
            return (prep_id_part(lock_id[0], 64),)
        if len(lock_id) == 2:
            first, second = lock_id
            return prep_id_part(first, 32), prep_id_part(second, 32)
        raise ValueError(f"Unsupported number of arguments: {lock_id}")
class _AdvisoryLockFunction:
    """A ``SELECT <func>(<key args>)`` statement for one advisory-lock function call.

    ``func`` is the composed SQL name of a ``pg_*advisory*`` function.
    ``lock_id`` is the one- or two-element integer key bound to its arguments.
    """

    __slots__ = "func", "lock_id"

    def __init__(
        self,
        func: sql.Composable,
        lock_id: t.Union[tuple[int], tuple[int, int]],
        *args,
        **kwargs,
    ):
        # Extra arguments are accepted and ignored. AdvisoryLock passes
        # subclass-specific keywords (e.g. timeout) through one shared call.
        self.func = func
        self.lock_id = lock_id

    def get_lock_expr(self):
        placeholders = sql.SQL(", ").join([sql.Placeholder()] * len(self.lock_id))
        return sql.SQL("SELECT {}({})").format(self.func, placeholders)

    def exec(self, cursor):
        statement = self.get_lock_expr()
        cursor.execute(statement, self.lock_id)
class _NowaitAdvisoryLockFuncion(_AdvisoryLockFunction):
    """Call to a ``pg_try_*`` advisory-lock function.

    The function returns a boolean instead of waiting. ``False`` (lock busy)
    is turned into LockNowaitError.
    """

    __slots__ = ()

    def exec(self, cursor):
        super().exec(cursor)
        row = cursor.fetchone()
        acquired = row[0]
        if acquired:
            return
        raise LockNowaitError(self.lock_id)
class _TimeoutAdvisoryLockFunction(_AdvisoryLockFunction):
    """Blocking advisory-lock call whose wait is bounded by ``lock_timeout``.

    Intended for use inside a transaction; ``SET LOCAL`` lasts only until the
    transaction ends. A wait longer than ``timeout`` seconds raises
    LockTimeoutError. After the lock is acquired, the previous
    ``lock_timeout`` is restored, so later statements in the transaction are
    not affected.
    """

    __slots__ = ("timeout",)

    def __init__(
        self,
        func: sql.Composable,
        lock_id: t.Union[tuple[int], tuple[int, int]],
        timeout: int,
        *args,
        **kwargs,
    ):
        super().__init__(func, lock_id, *args, **kwargs)
        self.timeout = timeout

    def exec(self, cursor):
        # Save the current value so the bounded wait does not leak into the rest of the transaction.
        cursor.execute("SELECT current_setting('lock_timeout')")
        previous_timeout = cursor.fetchone()[0]
        cursor.execute("SET LOCAL lock_timeout = %s", [f"{self.timeout}s"])
        try:
            cursor.execute(self.get_lock_expr(), self.lock_id)
        except OperationalError as e:
            # No restore needed here: the failed statement forces a rollback
            # (of the transaction or savepoint), which undoes the SET LOCAL.
            raise LockTimeoutError(self.lock_id, self.timeout) from e
        # set_config(..., true) has the same transaction-local scope as SET LOCAL.
        cursor.execute("SELECT set_config('lock_timeout', %s, true)", [previous_timeout])
# snake_case aliases so the locks read like context-manager functions at call sites.
table_lock, advisory_lock = TableLock, AdvisoryLock
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytest | |
from threading import Thread | |
from django.apps import apps | |
from django.db import transaction, InternalError | |
from django_pg_locks import table_lock, advisory_lock, LockNowaitError, LockTimeoutError | |
@pytest.mark.django_db
def test_table_lock(reraise):
    """A NOWAIT lock from another thread's connection fails while we hold the table."""

    def raise_nowait_err():
        with reraise, pytest.raises(LockNowaitError):
            with transaction.atomic(), table_lock("flight", nowait=True):
                assert False, "Should not reach there"

    with transaction.atomic(), table_lock("flight"):
        thread = Thread(target=raise_nowait_err)
        thread.start()
        thread.join(timeout=3)
        # join() returns silently on timeout; fail loudly rather than pass on a hung thread.
        assert not thread.is_alive(), "Lock contender thread did not finish in time"
@pytest.mark.django_db
def test_model_table_lock(reraise):
    """for_model() accepts model classes and "app.Model" labels alike."""
    Flight = apps.get_model("flight", "Flight")

    def raise_nowait_err():
        with reraise, pytest.raises(LockNowaitError):
            with transaction.atomic(), table_lock.for_model(Flight, "flight.Airport", nowait=True):
                assert False, "Should not reach there"

    with transaction.atomic(), table_lock.for_model(Flight, "flight.Airport"):
        thread = Thread(target=raise_nowait_err)
        thread.start()
        thread.join(timeout=3)
        # join() returns silently on timeout; fail loudly rather than pass on a hung thread.
        assert not thread.is_alive(), "Lock contender thread did not finish in time"
@pytest.mark.django_db
def test_table_lock_timeout(reraise):
    """Bad option combinations are rejected; contenders time out or fail fast."""
    # One pytest.raises per call: if both shared one block, the first raise
    # would skip the second check entirely.
    with pytest.raises(ValueError):
        table_lock("asdf", nowait=True, timeout=10)
    with pytest.raises(ValueError):
        table_lock("asdf", timeout=-1)

    def raise_flight_lock_timeout_err():
        with reraise, pytest.raises(LockTimeoutError):
            with transaction.atomic(), table_lock("flight", timeout=1):
                assert False, "Should not reach there"

    def raise_flight_lock_nowait_err():
        with reraise, pytest.raises(LockNowaitError):
            with transaction.atomic(), table_lock("flight", timeout=0):
                assert False, "Should not reach there"

    with transaction.atomic(), table_lock("flight"):
        threads = [Thread(target=func) for func in (raise_flight_lock_timeout_err, raise_flight_lock_nowait_err)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join(timeout=3)
        assert not any(thread.is_alive() for thread in threads), "Lock contender threads did not finish in time"
@pytest.mark.django_db
def test_advisory_lock(reraise):
    """A held advisory key blocks NOWAIT contenders for that key only."""

    def raise_nowait_err():
        with reraise, pytest.raises(LockNowaitError), advisory_lock(1, nowait=True):
            assert False, "Should not reach there"

    def lock_passes():
        with reraise, advisory_lock(2, nowait=True):
            pass

    with advisory_lock(1):
        threads = [Thread(target=raise_nowait_err), Thread(target=lock_passes)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join(timeout=3)
        # join() returns silently on timeout; fail loudly rather than pass on a hung thread.
        assert not any(thread.is_alive() for thread in threads), "Lock contender threads did not finish in time"
@pytest.mark.django_db
def test_advisory_lock_timeout(reraise):
    """Timeout/nowait validation, timeout expiry, and the transaction-only rule."""
    # One pytest.raises per call: if both shared one block, the first raise
    # would skip the second check entirely.
    with pytest.raises(ValueError):
        advisory_lock("asdf", 1, nowait=True, timeout=10)
    with pytest.raises(ValueError):
        advisory_lock("asdf", 1, timeout=-1)

    def raise_timeout_err():
        with reraise, pytest.raises(LockTimeoutError):
            with transaction.atomic(), advisory_lock("asdf", 1, timeout=1):
                assert False, "Should not reach here"

    def raise_nowait_err():
        with reraise, pytest.raises(LockNowaitError):
            with transaction.atomic(), advisory_lock("asdf", 1, timeout=0):
                assert False, "Should not reach here"

    def raise_internal_err():
        with reraise, pytest.raises(InternalError):
            with advisory_lock("asdf", 2, timeout=1):
                assert False, "Should not be run outside of transaction"

    with advisory_lock("asdf", 1):
        threads = [Thread(target=func) for func in (raise_timeout_err, raise_nowait_err, raise_internal_err)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join(timeout=3)
        assert not any(thread.is_alive() for thread in threads), "Lock contender threads did not finish in time"
@pytest.mark.parametrize(["lock_id"], [
    # Parenthesised: `1 << 63 - 1` parses as `1 << 62`, not the max bigint.
    [((1 << 63) - 1,)],
    [(1 << 63,)],
    [(1 << 64,)],
    [(1 << 65,)],
    [(-1,)],
    [(-(1 << 63),)],
    [((1 << 31) - 1, 1 << 31)],
    [(1 << 32, 1 << 33)],
    [(-5, -(1 << 31))],
])
@pytest.mark.django_db
def test_large_lock_ids(lock_id):
    """Keys of any size are reduced into PostgreSQL's argument range and can be locked."""
    lock = advisory_lock(*lock_id)
    # One key maps to a bigint; two keys map to a pair of 32-bit ints.
    bits = 64 if len(lock_id) == 1 else 32
    low, high = -(1 << (bits - 1)), 1 << (bits - 1)
    assert all(low <= part < high for part in lock.lock_id)
    with lock:
        pass
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment