Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save daniel-kristjansson/ff435e2bef90e9d86b3f82ca896d393d to your computer and use it in GitHub Desktop.
Save daniel-kristjansson/ff435e2bef90e9d86b3f82ca896d393d to your computer and use it in GitHub Desktop.
async datastore usage
import asyncio
import logging
from contextlib import asynccontextmanager
from dataclasses import dataclass
from functools import wraps
from typing import AsyncIterator, Union, Iterable, Callable
from ddtrace import tracer
from google.api_core.exceptions import Aborted
from google.api_core.retry import AsyncRetry
from google.auth import default as auth_default
from google.auth.exceptions import RefreshError, DefaultCredentialsError
from google.cloud import datastore, datastore_v1
from google.cloud.datastore.helpers import entity_from_protobuf, entity_to_protobuf
from google.cloud.datastore_v1 import types
from app.datastore.exceptions import DatastoreConnectionError, LeakyAbort
# NOTE(review): this configures the *root* logger at import time, which overrides whatever logging
# setup the host application does later; consider moving it to the application entry point.
logging.basicConfig(level=logging.DEBUG)  # Set to DEBUG to capture all levels of logs
logger = logging.getLogger(__name__)


def _default_backoff() -> AsyncRetry:
    """
    Build the exponential-backoff policy shared by reads and writes.

    First delay 100 ms, doubled on every retry, capped at 1 s between attempts,
    with an overall budget of 5 s before giving up.
    """
    return AsyncRetry(initial=0.1, maximum=1.0, multiplier=2.0, deadline=5.0)


# Separate (but equally configured) instances so either policy can be tuned independently later.
DATASTORE_DEFAULT_READ_ASYNC_RETRY = _default_backoff()
DATASTORE_DEFAULT_WRITE_ASYNC_RETRY = _default_backoff()
@dataclass
class Transaction:
    """
    Handle for a Datastore transaction opened through AsyncDatastore.

    ``closed`` is flipped to True by AsyncDatastore once the transaction has been committed
    or rolled back; transaction_with_retries refuses to reuse a closed handle.
    """

    # Opaque id returned by the BeginTransaction RPC; sent back with reads, commits and rollbacks.
    transaction_id: bytes
    # True after a commit or rollback has completed for this transaction.
    closed: bool = False
class AsyncDatastore:
    """
    Asyncio wrapper around ``datastore_v1.DatastoreAsyncClient``.

    Translates ``google.cloud.datastore`` ``Key``/``Entity`` objects to and from the v1 protobuf
    request/response types. Whenever a caller passes no ``retry``, it applies ``default_read_retry``
    to lookups, queries and BeginTransaction, and ``default_write_retry`` to commits and rollbacks.
    """

    # Lazily-created shared instance returned by get_client(); reset by close().
    _default_client: Union['AsyncDatastore', None] = None

    # Upper bound on Lookup round-trips spent chasing ``deferred`` keys in async_multi_get.
    _MAX_LOOKUP_ROUNDS = 128

    # Batch states meaning another RunQuery page should be requested starting at ``end_cursor``.
    _MORE_RESULTS_TYPES = (
        types.QueryResultBatch.MoreResultsType.NOT_FINISHED,
        types.QueryResultBatch.MoreResultsType.MORE_RESULTS_AFTER_LIMIT,
        types.QueryResultBatch.MoreResultsType.MORE_RESULTS_AFTER_CURSOR,
    )

    @classmethod
    def get_client(cls) -> 'AsyncDatastore':
        """
        Get the default datastore client.

        Created on first call from Application Default Credentials and cached on the class.

        :return: AsyncDatastore instance using default credentials and project id
        :raise: DatastoreConnectionError if the client cannot be initialized (credentials cannot be
            loaded/refreshed, or no project id can be determined from the environment)
        """
        if cls._default_client is not None:
            return cls._default_client
        try:
            datastore_credentials, datastore_project_id = auth_default()
            base_client = datastore_v1.DatastoreAsyncClient(credentials=datastore_credentials)
        except (RefreshError, DefaultCredentialsError) as e:
            logger.error("Failed to initialize datastore client %s", repr(e))
            raise DatastoreConnectionError() from e
        if not datastore_project_id:
            # google.auth.default() may return no project id; every request and key() needs one.
            logger.error("Failed to initialize datastore client: no default project id")
            raise DatastoreConnectionError()
        cls._default_client = AsyncDatastore(base_client, datastore_project_id)
        return cls._default_client

    @classmethod
    async def close(cls):
        """Forget the cached default client and close its transport, if one was created."""
        if cls._default_client is None:
            return
        client = cls._default_client.client
        # Clear the cache first so a failing close() can't leave a half-closed client cached.
        cls._default_client = None
        if client.transport:
            await client.transport.close()

    def __init__(self, client: datastore_v1.DatastoreAsyncClient,
                 project_id: str,
                 default_read_retry: AsyncRetry = DATASTORE_DEFAULT_READ_ASYNC_RETRY,
                 default_write_retry: AsyncRetry = DATASTORE_DEFAULT_WRITE_ASYNC_RETRY):
        """
        :param client: low-level async Datastore client used for every RPC
        :param project_id: project every request and key is issued against
        :param default_read_retry: retry used for lookups/queries/BeginTransaction when none is passed
        :param default_write_retry: retry used for commits/rollbacks when none is passed
        """
        self.client = client
        self.project_id = project_id
        self.default_read_retry = default_read_retry
        self.default_write_retry = default_write_retry

    def key(self, *path_args, **kwargs) -> datastore.Key:
        """Build a ``datastore.Key`` in this client's project (any ``project`` kwarg is overridden)."""
        kwargs["project"] = self.project_id
        return datastore.Key(*path_args, **kwargs)

    @asynccontextmanager
    async def _transaction(self, retry=None):
        """
        Don't use this directly.
        Use the transaction_with_retries decorator instead as it will handle retrying, and unclosed transactions.

        Begins a transaction, yields its Transaction handle, and rolls it back on exit unless a
        commit or rollback already closed it.
        """
        request = types.BeginTransactionRequest(project_id=self.project_id)
        response = await self.client.begin_transaction(request=request, retry=retry or self.default_read_retry)
        transaction = Transaction(transaction_id=response.transaction)
        try:
            yield transaction
        except BaseException:
            # The body failed (e.g. commit raised Aborted). Try to release the transaction, but never
            # let a rollback error replace the original exception: transaction_with_retries decides
            # whether to retry based on that exception's type.
            if not transaction.closed:
                try:
                    await self.async_rollback(transaction)
                except Exception:
                    logger.warning("Rollback after failed transaction body also failed", exc_info=True)
            raise
        else:
            # Read-only (or otherwise uncommitted) transaction: release it explicitly.
            if not transaction.closed:
                await self.async_rollback(transaction)

    @tracer.wrap(name="datastore", resource="rollback")
    async def async_rollback(self, transaction: Transaction, retry=None):
        """Roll back ``transaction`` and mark it closed once the RPC succeeds."""
        request = types.RollbackRequest(project_id=self.project_id, transaction=transaction.transaction_id)
        resp = await self.client.rollback(request=request, retry=retry or self.default_write_retry)
        transaction.closed = True
        return resp

    async def async_get(self, key, retry=None, transaction: Transaction | None = None) -> datastore.Entity | None:
        """Fetch a single entity, or None if it does not exist."""
        entities = await self.async_multi_get([key], retry, transaction)
        return entities[0] if entities else None

    async def async_put(self, entity: datastore.Entity, retry=None, transaction: Transaction | None = None):
        """Upsert a single entity; returns the CommitResponse (and closes ``transaction`` if given)."""
        return await self.async_multi_put([entity], retry, transaction)

    @tracer.wrap(name="datastore", resource="get")
    async def async_multi_get(
            self, keys: Iterable[datastore.Key], retry=None,
            transaction: Transaction | None = None,
            is_warmup: bool = False) -> list[datastore.Entity]:
        """
        Look up several entities at once.

        Missing keys are simply absent from the result, and the server does not guarantee that
        results come back in the same order as ``keys``.

        :param keys: keys to fetch (consumed once)
        :param retry: AsyncRetry for each Lookup RPC; defaults to ``default_read_retry``
        :param transaction: read inside this transaction when given
        :param is_warmup: tag the current trace span with ``warmup=true``
        :raise RuntimeError: if the server keeps deferring keys for ``_MAX_LOOKUP_ROUNDS`` rounds
        """
        if is_warmup:
            if span := tracer.current_span():
                span.set_tag("warmup", "true")
        pb_keys = [key.to_protobuf() for key in keys]
        if not pb_keys:
            # Nothing to look up; skip the RPC entirely.
            return []
        read_options = types.ReadOptions(transaction=transaction.transaction_id) if transaction is not None else None
        entities: list[datastore.Entity] = []
        # Datastore may process only part of a batch and hand the rest back in ``deferred``;
        # keep re-issuing the lookup for those keys until none remain.
        for _ in range(self._MAX_LOOKUP_ROUNDS):
            lookup_request = types.LookupRequest(project_id=self.project_id, keys=pb_keys, read_options=read_options)
            response = await self.client.lookup(lookup_request, retry=retry or self.default_read_retry)
            entities.extend(entity_from_protobuf(found.entity) for found in response.found)
            pb_keys = list(response.deferred)
            if not pb_keys:
                return entities
        raise RuntimeError("Datastore lookup failed to make progress on deferred keys")

    async def _commit(self, mutations: list, retry,
                      transaction: Transaction | None) -> datastore_v1.types.CommitResponse:
        """Send one Commit RPC; transactional when ``transaction`` is given, which is then marked closed."""
        if transaction is not None:
            mode = types.CommitRequest.Mode.TRANSACTIONAL
            transaction_id = transaction.transaction_id
        else:
            mode = types.CommitRequest.Mode.NON_TRANSACTIONAL
            transaction_id = None
        commit_request = types.CommitRequest(
            project_id=self.project_id,
            mode=mode,
            mutations=mutations,
            transaction=transaction_id,
        )
        resp = await self.client.commit(commit_request, retry=retry or self.default_write_retry)
        if transaction is not None:
            transaction.closed = True
        return resp

    @tracer.wrap(name="datastore", resource="put")
    async def async_multi_put(self, entities: Iterable[datastore.Entity], retry=None,
                              transaction: Transaction | None = None) -> datastore_v1.types.CommitResponse:
        """
        Asynchronously put (upsert) multiple entities into Datastore in a single commit.

        Committing inside ``transaction`` ends it: the handle is marked closed afterwards.
        """
        mutations = [types.Mutation(upsert=entity_to_protobuf(entity)) for entity in entities]
        return await self._commit(mutations, retry, transaction)

    async def async_delete(self, key, retry=None,
                           transaction: Transaction | None = None) -> datastore_v1.types.CommitResponse:
        """Delete a single entity; returns the CommitResponse, mirroring async_put."""
        return await self.async_multi_delete([key], retry=retry, transaction=transaction)

    @tracer.wrap(name="datastore", resource="delete")
    async def async_multi_delete(self, keys, retry=None,
                                 transaction: Transaction | None = None) -> datastore_v1.types.CommitResponse:
        """
        Asynchronously delete multiple entities from Datastore in a single commit.

        Committing inside ``transaction`` ends it: the handle is marked closed afterwards.
        """
        mutations = [types.Mutation(delete=key.to_protobuf()) for key in keys]
        return await self._commit(mutations, retry, transaction)

    async def _iter_query_results(self, run_query_request, retry) -> AsyncIterator["types.EntityResult"]:
        """
        Run ``run_query_request`` page by page and yield each raw EntityResult.

        Follows ``end_cursor`` (mutating the request's ``start_cursor``) while the batch reports more
        results, and yields control to the event loop at least every 10 ms so long scans don't starve
        other tasks.
        """
        loop = asyncio.get_running_loop()
        last_sleep = loop.time()
        more_results = True
        while more_results:
            response = await self.client.run_query(run_query_request, retry=retry or self.default_read_retry)
            for entity_result in response.batch.entity_results:
                if loop.time() - last_sleep >= 0.01:
                    await asyncio.sleep(0)  # yield to other tasks every 10 ms
                    last_sleep = loop.time()
                yield entity_result
            more_results = response.batch.more_results in self._MORE_RESULTS_TYPES
            if more_results:
                run_query_request.query.start_cursor = response.batch.end_cursor

    async def async_iter_entities(
            self, kind, retry=None, transaction: Transaction | None = None) -> AsyncIterator[datastore.Entity]:
        """
        Asynchronously iterate over Datastore entities of a given kind.
        """
        query = types.Query(kind=[types.KindExpression(name=kind)])
        run_query_request = types.RunQueryRequest(
            project_id=self.project_id,
            query=query,
            read_options=types.ReadOptions(transaction=transaction.transaction_id if transaction else None),
        )
        async for entity_result in self._iter_query_results(run_query_request, retry):
            yield entity_from_protobuf(entity_result.entity)

    async def async_iter_keys(self, kind, retry=None) -> AsyncIterator[str]:
        """
        Asynchronously iterate over the names of Datastore keys of a given kind (keys-only query).

        NOTE: only the last path element's ``name`` is yielded; keys identified by a numeric id
        have an empty name and therefore come through as "".
        """
        query = types.Query(
            kind=[types.KindExpression(name=kind)],
            projection=[types.Projection(property=types.PropertyReference(name="__key__"))],
        )
        run_query_request = types.RunQueryRequest(project_id=self.project_id, query=query)
        async for entity_result in self._iter_query_results(run_query_request, retry):
            yield entity_result.entity.key.path[-1].name
def get_client() -> AsyncDatastore:
    """Module-level shortcut returning the shared default ``AsyncDatastore`` (see ``AsyncDatastore.get_client``)."""
    shared_client = AsyncDatastore.get_client()
    return shared_client
def transaction_with_retries(max_attempts=3, begin_transaction_retry=None, retry_delay: float = 0.050):
    """
    This is a transaction decorator that can be used to wrap a function that needs to be executed within a
    Datastore transaction. The name of the transaction parameter must be 'transaction'.
    This currently only works with the default client, i.e. get_client().
    You can use a function wrapped with this decorator within a larger transaction. However, the transaction
    ends on the first database write. i.e. you can use lots of gets but only one write. If you need multi-write
    transactions then you need to implement it yourself.

    To join an existing transaction, pass it as the keyword argument ``transaction=...``; a
    positionally-passed transaction is not detected.

    :param max_attempts: total attempts (>= 1) when the transaction is aborted
    :param begin_transaction_retry: AsyncRetry for the BeginTransaction RPC; defaults to the client's read retry
    :param retry_delay: seconds to wait between aborted attempts
    :raise ValueError: at decoration time if ``max_attempts`` < 1; at call time if the passed
        transaction is already closed
    """
    if max_attempts < 1:
        # With zero attempts the loop below would never run and the wrapper would silently return None.
        raise ValueError(f"max_attempts must be at least 1, got {max_attempts}")

    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            outer_transaction = kwargs.get('transaction')
            if outer_transaction:
                # Nested use: run inside the caller's transaction, no retries of our own.
                if outer_transaction.closed:
                    raise ValueError('Transaction is closed')
                return await func(*args, **kwargs)
            # Trace the actual transactional work (previously only the decorator factory was traced).
            with tracer.trace("datastore", resource="transaction"):
                for attempt in range(max_attempts):
                    try:
                        async with get_client()._transaction(retry=begin_transaction_retry) as transaction:
                            kwargs['transaction'] = transaction
                            return await func(*args, **kwargs)
                    except Aborted:
                        # Contention: retry with a fresh transaction unless this was the last attempt.
                        if attempt >= max_attempts - 1:
                            raise
                        await asyncio.sleep(retry_delay)
                    except LeakyAbort as e:
                        # NOTE(review): converted to Aborted for the caller but not retried here — confirm intent.
                        raise e.to_aborted() from e
        return wrapper
    return decorator
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment