Created
May 28, 2025 11:31
-
-
Save daniel-kristjansson/ff435e2bef90e9d86b3f82ca896d393d to your computer and use it in GitHub Desktop.
async datastore usage
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import logging | |
from contextlib import asynccontextmanager | |
from dataclasses import dataclass | |
from functools import wraps | |
from typing import AsyncIterator, Union, Iterable, Callable | |
from ddtrace import tracer | |
from google.api_core.exceptions import Aborted | |
from google.api_core.retry import AsyncRetry | |
from google.auth import default as auth_default | |
from google.auth.exceptions import RefreshError, DefaultCredentialsError | |
from google.cloud import datastore, datastore_v1 | |
from google.cloud.datastore.helpers import entity_from_protobuf, entity_to_protobuf | |
from google.cloud.datastore_v1 import types | |
from app.datastore.exceptions import DatastoreConnectionError, LeakyAbort | |
# NOTE(review): calling basicConfig from a library module configures the root logger for the
# whole process (at DEBUG); consider moving this to the application's entry point.
logging.basicConfig(level=logging.DEBUG)  # Set to DEBUG to capture all levels of logs
logger = logging.getLogger(__name__)

# Shared exponential-backoff settings for both default policies below:
# first wait 100 ms, double each time, never wait more than 1 s between tries,
# and stop retrying after 5 s in total.
_DEFAULT_RETRY_SETTINGS = dict(
    initial=0.1,
    maximum=1.0,
    multiplier=2.0,
    deadline=5.0,
)

# Default for AsyncDatastore.default_read_retry (lookup, run_query, begin_transaction).
DATASTORE_DEFAULT_READ_ASYNC_RETRY = AsyncRetry(**_DEFAULT_RETRY_SETTINGS)
# Default for AsyncDatastore.default_write_retry (commit, rollback).
DATASTORE_DEFAULT_WRITE_ASYNC_RETRY = AsyncRetry(**_DEFAULT_RETRY_SETTINGS)
@dataclass()
class Transaction:
    """
    Handle for a Datastore transaction started by AsyncDatastore.

    AsyncDatastore sets ``closed`` to True after a successful transactional
    commit or rollback, so the transaction is not ended a second time.
    """

    # Transaction id returned by the BeginTransaction RPC.
    transaction_id: bytes
    # True once the transaction has been committed or rolled back.
    closed: bool = False
class AsyncDatastore:
    """
    Asyncio wrapper around ``datastore_v1.DatastoreAsyncClient``.

    Builds raw Datastore v1 requests (lookup, commit, run_query, begin/rollback
    transaction) scoped to ``project_id``. It converts between protobuf messages and
    ``google.cloud.datastore`` Key/Entity objects.
    """

    # Lazily created process-wide instance; see get_client() and close().
    _default_client: Union['AsyncDatastore', None] = None

    @classmethod
    def get_client(cls) -> 'AsyncDatastore':
        """
        Get the default datastore client, creating it on first use.

        :return: AsyncDatastore instance using default credentials and project id
        :raise: DatastoreConnectionError if the client cannot be initialized
        """
        if cls._default_client is not None:
            return cls._default_client
        try:
            datastore_credentials, datastore_project_id = auth_default()
            base_client = datastore_v1.DatastoreAsyncClient(credentials=datastore_credentials)
            cls._default_client = AsyncDatastore(base_client, datastore_project_id)
        except (RefreshError, DefaultCredentialsError) as e:
            logger.error("Failed to initialize datastore client %s", repr(e))
            # Chain the cause so the credential failure stays visible in tracebacks.
            raise DatastoreConnectionError() from e
        return cls._default_client

    @classmethod
    async def close(cls):
        """
        Close the default client's transport and forget the default instance.

        Does nothing if no default client exists. A later get_client() call builds a new client.
        """
        if cls._default_client is None:
            return
        client = cls._default_client.client
        # Cleared before awaiting, so get_client() calls made while the transport
        # is closing build a new client instead of returning this one.
        cls._default_client = None
        if client.transport:
            await client.transport.close()

    def __init__(self, client: datastore_v1.DatastoreAsyncClient,
                 project_id: str,
                 default_read_retry: AsyncRetry = DATASTORE_DEFAULT_READ_ASYNC_RETRY,
                 default_write_retry: AsyncRetry = DATASTORE_DEFAULT_WRITE_ASYNC_RETRY):
        """
        :param client: underlying generated async Datastore client
        :param project_id: project every request and key is scoped to
        :param default_read_retry: retry used for lookup/run_query/begin_transaction when none is given
        :param default_write_retry: retry used for commit/rollback when none is given
        """
        self.client = client
        self.project_id = project_id
        self.default_read_retry = default_read_retry
        self.default_write_retry = default_write_retry

    def key(self, *path_args, **kwargs):
        """
        Build a ``datastore.Key`` in this client's project.

        Any ``project`` keyword passed by the caller is overridden with ``self.project_id``.
        """
        kwargs["project"] = self.project_id
        return datastore.Key(*path_args, **kwargs)

    @asynccontextmanager
    async def _transaction(self, retry=None):
        """
        Don't use this directly.
        Use the transaction_with_retries decorator instead as it will handle retrying, and unclosed transactions.

        Begins a transaction and yields a Transaction handle. On exit, rolls it back
        unless a commit or rollback already closed it. If the body raised and the
        rollback also fails, the rollback error is logged rather than raised. That keeps
        the body's exception (e.g. ``Aborted``, which the retry decorator catches)
        from being masked.
        """
        request = types.BeginTransactionRequest(project_id=self.project_id)
        response = await self.client.begin_transaction(request=request, retry=retry or self.default_read_retry)
        transaction = Transaction(transaction_id=response.transaction)
        try:
            yield transaction
        except BaseException:
            if not transaction.closed:
                try:
                    await self.async_rollback(transaction)
                except Exception as rollback_error:
                    logger.warning("Failed to roll back datastore transaction: %r", rollback_error)
            raise
        else:
            if not transaction.closed:
                await self.async_rollback(transaction)

    @tracer.wrap(name="datastore", resource="rollback")
    async def async_rollback(self, transaction: Transaction, retry=None):
        """
        Roll back ``transaction`` and mark it closed.

        :param retry: retry policy; falls back to default_write_retry when falsy
        :return: the RPC's RollbackResponse
        """
        request = types.RollbackRequest(project_id=self.project_id, transaction=transaction.transaction_id)
        resp = await self.client.rollback(request=request, retry=retry or self.default_write_retry)
        transaction.closed = True
        return resp

    async def async_get(self, key, retry=None, transaction: Transaction | None = None) -> datastore.Entity | None:
        """Fetch a single entity by key; returns None when it does not exist."""
        entities = await self.async_multi_get([key], retry, transaction)
        return entities[0] if entities else None

    async def async_put(self, entity: datastore.Entity, retry=None, transaction: Transaction | None = None):
        """Upsert a single entity (see async_multi_put)."""
        return await self.async_multi_put([entity], retry, transaction)

    @tracer.wrap(name="datastore", resource="get")
    async def async_multi_get(
            self, keys: Iterable[datastore.Key], retry=None,
            transaction: Transaction | None = None,
            is_warmup: bool = False) -> list[datastore.Entity]:
        """
        Look up several keys and return the entities that exist.

        Keys the backend reports as ``deferred`` are looked up again until none
        remain. Missing keys are omitted. Empty ``keys`` returns [] without an RPC.
        NOTE(review): the result follows the backend's ``found`` order, which is not
        necessarily the order of ``keys``.

        :param transaction: read inside this transaction when given
        :param is_warmup: tag the current trace span with warmup=true
        """
        if is_warmup:
            if span := tracer.current_span():
                span.set_tag("warmup", "true")
        read_options = types.ReadOptions(transaction=transaction.transaction_id) if transaction else None
        pending_keys = [key.to_protobuf() for key in keys]
        found = []
        while pending_keys:
            lookup_request = types.LookupRequest(
                project_id=self.project_id, keys=pending_keys, read_options=read_options)
            response = await self.client.lookup(lookup_request, retry=retry or self.default_read_retry)
            found.extend(entity_from_protobuf(result.entity) for result in response.found)
            # Deferred keys were not looked up this round; ask for them again.
            pending_keys = list(response.deferred)
        return found

    async def _commit(self, mutations: list, retry=None,
                      transaction: Transaction | None = None) -> datastore_v1.types.CommitResponse:
        """
        Commit ``mutations``, transactionally when ``transaction`` is given.

        On a successful transactional commit the transaction is marked closed.
        """
        if transaction:
            mode = types.CommitRequest.Mode.TRANSACTIONAL
            transaction_id = transaction.transaction_id
        else:
            mode = types.CommitRequest.Mode.NON_TRANSACTIONAL
            transaction_id = None
        commit_request = types.CommitRequest(
            project_id=self.project_id,
            mode=mode,
            mutations=mutations,
            transaction=transaction_id,
        )
        resp = await self.client.commit(commit_request, retry=retry or self.default_write_retry)
        if transaction:
            transaction.closed = True
        return resp

    @tracer.wrap(name="datastore", resource="put")
    async def async_multi_put(self, entities: Iterable[datastore.Entity], retry=None,
                              transaction: Transaction | None = None) -> datastore_v1.types.CommitResponse:
        """
        Asynchronously put (upsert) multiple entities into Datastore in one commit.
        """
        mutations = [types.Mutation(upsert=entity_to_protobuf(entity)) for entity in entities]
        return await self._commit(mutations, retry, transaction)

    async def async_delete(self, key, retry=None, transaction: Transaction | None = None):
        """Delete a single entity by key (see async_multi_delete)."""
        await self.async_multi_delete([key], retry=retry, transaction=transaction)

    @tracer.wrap(name="datastore", resource="delete")
    async def async_multi_delete(self, keys, retry=None,
                                 transaction: Transaction | None = None) -> datastore_v1.types.CommitResponse:
        """
        Asynchronously delete multiple entities from Datastore in one commit.
        """
        mutations = [types.Mutation(delete=key.to_protobuf()) for key in keys]
        return await self._commit(mutations, retry, transaction)

    async def _iter_query_results(self, run_query_request, retry=None):
        """
        Run ``run_query_request`` page by page and yield each raw EntityResult.

        Follows ``end_cursor`` while the batch reports more results. Before yielding a
        result, it calls ``asyncio.sleep(0)`` if 10 ms or more have passed since the
        last such call, so long scans let other tasks run.
        """
        loop = asyncio.get_running_loop()
        last_sleep = loop.time()
        while True:
            response = await self.client.run_query(run_query_request, retry=retry or self.default_read_retry)
            for entity_result in response.batch.entity_results:
                if loop.time() - last_sleep >= 0.01:
                    await asyncio.sleep(0)  # yield to other tasks every 10 ms
                    last_sleep = loop.time()
                yield entity_result
            if response.batch.more_results not in (
                    types.QueryResultBatch.MoreResultsType.NOT_FINISHED,
                    types.QueryResultBatch.MoreResultsType.MORE_RESULTS_AFTER_LIMIT,
                    types.QueryResultBatch.MoreResultsType.MORE_RESULTS_AFTER_CURSOR,
            ):
                return
            run_query_request.query.start_cursor = response.batch.end_cursor

    async def async_iter_entities(
            self, kind, retry=None, transaction: Transaction | None = None) -> AsyncIterator[datastore.Entity]:
        """
        Asynchronously iterate over Datastore entities of a given kind.
        """
        query = types.Query(kind=[types.KindExpression(name=kind)])
        run_query_request = types.RunQueryRequest(
            project_id=self.project_id,
            query=query,
            read_options=types.ReadOptions(transaction=transaction.transaction_id if transaction else None),
        )
        async for entity_result in self._iter_query_results(run_query_request, retry):
            yield entity_from_protobuf(entity_result.entity)

    async def async_iter_keys(
            self, kind, retry=None, transaction: Transaction | None = None) -> AsyncIterator[str]:
        """
        Asynchronously iterate over the key names of Datastore entities of a given kind.

        Runs a ``__key__`` projection query and yields the ``name`` of each key's last
        path element. NOTE(review): keys with numeric ids presumably yield "" here;
        confirm that this kind uses named keys.

        :param transaction: optional; when given, the query reads inside that transaction
        """
        query = types.Query(
            kind=[types.KindExpression(name=kind)],
            projection=[types.Projection(property=types.PropertyReference(name="__key__"))],
        )
        read_options = types.ReadOptions(transaction=transaction.transaction_id) if transaction else None
        run_query_request = types.RunQueryRequest(project_id=self.project_id, query=query, read_options=read_options)
        async for entity_result in self._iter_query_results(run_query_request, retry):
            yield entity_result.entity.key.path[-1].name
def get_client() -> AsyncDatastore:
    """Return the shared default AsyncDatastore (shortcut for AsyncDatastore.get_client())."""
    default_client = AsyncDatastore.get_client()
    return default_client
def transaction_with_retries(max_attempts=3, begin_transaction_retry=None, retry_delay: float = 0.050):
    """
    This is a transaction decorator that can be used to wrap a function that needs to be executed within a
    Datastore transaction. The name of the transaction parameter must be 'transaction', and it must be
    passed as a keyword argument.
    This currently only works with the default client, i.e. get_client().

    You can use a function wrapped with this decorator within a larger transaction. However, the transaction
    ends on the first database write. i.e. you can use lots of gets but only one write. If you need multi-write
    transactions then you need to implement it yourself.

    Each call of the wrapped function is traced as a "datastore"/"transaction" span.

    :param max_attempts: total attempts (>= 1) when the wrapper starts its own transaction; an ``Aborted``
        raised by the body or the commit is retried until attempts run out, then re-raised.
    :param begin_transaction_retry: retry policy for BeginTransaction; None uses the client's default.
    :param retry_delay: seconds to sleep between attempts after an ``Aborted`` (default 50 ms).
    :raise ValueError: at decoration time if max_attempts < 1; at call time if a closed transaction is passed.
    """
    # Without this check, range(0) would skip the loop and the wrapper would return None without calling func.
    if max_attempts < 1:
        raise ValueError('max_attempts must be at least 1')

    def decorator(func: Callable):
        # The span is on the wrapper so every call is traced, not the one-off decoration step.
        @tracer.wrap(name="datastore", resource="transaction")
        @wraps(func)
        async def wrapper(*args, **kwargs):
            outer = kwargs.get('transaction')
            if outer:
                # Join the caller's transaction: no new transaction and no retries here.
                if outer.closed:
                    raise ValueError('Transaction is closed')
                return await func(*args, **kwargs)
            for attempt in range(max_attempts):
                try:
                    async with get_client()._transaction(retry=begin_transaction_retry) as transaction:
                        kwargs['transaction'] = transaction
                        return await func(*args, **kwargs)
                except Aborted:
                    if attempt >= max_attempts - 1:
                        raise
                    await asyncio.sleep(retry_delay)
                except LeakyAbort as e:
                    # Converted to Aborted and propagated, not retried by this loop.
                    raise e.to_aborted() from e
        return wrapper
    return decorator
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment