Last active
April 2, 2021 01:36
-
-
Save Mulugruntz/b3fe059477d046904ff997c9e6119719 to your computer and use it in GitHub Desktop.
EdgeDB 1-beta1 skip nested transactions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
A decorator to avoid WET ("write everything twice") code.
I needed to call the same function sometimes from within
a transaction and sometimes from outside of one.
However, I wanted the function to always be executed in a
transaction (atomicity needed).
I didn't want to have to write the same function twice,
depending on whether it was inside a transaction or not.
This decorator is the answer to that problem.
Of course, the best answer would be nested transactions!
"""
from functools import wraps | |
import aiostream | |
from contextlib import AsyncExitStack | |
from edgedb import AsyncIOConnection, AsyncIOPool | |
from edgedb.retry import AsyncIOIteration | |
def edge_con_or_tx(func):
    """Make an async function always run inside an EdgeDB transaction.

    The wrapper looks for the first argument (positional arguments first,
    then keyword arguments) that is an ``AsyncIOConnection``, an
    ``AsyncIOPool`` or an ``AsyncIOIteration``:

    * an ``AsyncIOIteration`` is reused as-is and is only entered when its
      ``is_active()`` reports that it has not been started yet;
    * for a connection or a pool, ``retrying_transaction()`` is called on it
      and ``func`` is invoked with the yielded iteration substituted for the
      original argument.

    Args:
        func: the coroutine function to wrap. It must receive an EdgeDB
            connection, pool or transaction iteration among its arguments.

    Returns:
        A coroutine function with the same signature as ``func``.

    Raises:
        TypeError: when called without any EdgeDB connection, pool or
            transaction iteration among the positional or keyword arguments.
    """

    def is_edgedb_handle(value):
        # Resolved lazily, per argument, exactly like the inline isinstance
        # checks it replaces.
        return isinstance(
            value, (AsyncIOConnection, AsyncIOIteration, AsyncIOPool)
        )

    @wraps(func)
    async def inner(*args, **kwargs):
        connection = None
        arg_i = None
        kwarg_k = None
        for i, arg in enumerate(args):
            if is_edgedb_handle(arg):
                connection = arg
                arg_i = i
                break
        else:
            # No positional match: fall back to the keyword arguments.
            for k, v in kwargs.items():
                if is_edgedb_handle(v):
                    connection = v
                    kwarg_k = k
                    break
            else:
                raise TypeError(
                    f"Function {func.__name__} has no EdgeDB "
                    f"connection/transaction in its parameters."
                )

        if isinstance(connection, AsyncIOIteration):
            # Already a transaction iteration: expose it as a one-item stream
            # so both cases share the loop below.
            xs = aiostream.stream.iterate([connection])
        else:
            xs = aiostream.stream.preserve(connection.retrying_transaction())

        async with xs.stream() as streamer:
            async for tx in streamer:
                async with AsyncExitStack() as astack:
                    assert isinstance(tx, AsyncIOIteration)
                    if not tx.is_active():
                        # Only enter transactions we start ourselves; an
                        # active one is owned by the caller.
                        await astack.enter_async_context(tx)
                    # Build the call arguments on every attempt. The original
                    # only created the positional list when the handle was
                    # positional, so passing it by keyword failed with an
                    # unbound local.
                    call_args = list(args)
                    call_kwargs = dict(kwargs)
                    if arg_i is not None:
                        call_args[arg_i] = tx
                    else:
                        call_kwargs[kwarg_k] = tx
                    # NOTE(review): presumably a retryable failure is
                    # suppressed when the iteration exits, so the loop moves
                    # on to the next attempt - confirm against edgedb docs.
                    return await func(*call_args, **call_kwargs)

    return inner
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncContextManager, AsyncIterator

import edgedb

from edge_nested_tx import edge_con_or_tx
# Connection settings for the example; db_connect() expands this mapping as
# keyword arguments when it creates the pool.
DB_CONNECTION_PROPS = dict(
    host="localhost",
    port=5656,
    user="user",
    password="password",
    database="edgedb",
)
@asynccontextmanager
async def db_connect() -> AsyncIterator["edgedb.AsyncIOPool"]:
    """Provide an EdgeDB connection pool for the span of an ``async with``.

    The pool is created with ``edgedb.create_async_pool`` using
    ``min_size=2`` plus the settings in ``DB_CONNECTION_PROPS``, and is used
    as an async context manager around the ``yield``, so it is exited when
    the caller's block finishes.

    Yields:
        The pool object produced by ``edgedb.create_async_pool``.
    """
    # The annotation describes the undecorated async generator;
    # ``asynccontextmanager`` turns it into a context-manager factory.
    async with await edgedb.create_async_pool(
        min_size=2, **DB_CONNECTION_PROPS
    ) as pool:
        yield pool
@edge_con_or_tx
async def runme(connection):
    """Run two sample queries through ``connection`` and print each result.

    The function is wrapped with ``edge_con_or_tx`` so that both statements
    are issued through the same transaction object.
    """
    for statement in ("SELECT 2+2", "SELECT 3+3"):
        result = await connection.query_one(statement)
        print(result)
async def run_client():
    """Demonstrate ``runme`` both inside and outside an explicit transaction.

    Opens a pool with ``db_connect()``, calls ``runme`` from within a retrying
    transaction (passing the transaction object), and then calls it again with
    the pool itself so the decorator has to start a transaction on its own.
    """
    async with db_connect() as pool:
        async for tx in pool.retrying_transaction():
            async with tx:
                # Query through ``tx`` rather than the pool so this statement
                # takes part in the transaction it is demonstrating.
                print(await tx.query_one("SELECT 1+1"))
                # already inside a transaction, will reuse it
                await runme(tx)
        # not inside a transaction, will create one seamlessly
        await runme(pool)
if __name__ == "__main__":
    # asyncio.run() creates a fresh event loop, runs the coroutine to
    # completion and closes the loop; calling get_event_loop() without a
    # running loop is deprecated.
    asyncio.run(run_client())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment