# FastAPI background tasks queue.
# Originally published as GitHub gist h0rn3t/dfd7a5f89249447e443ab5b57a971763.
# GitHub flagged the gist as possibly containing bidirectional Unicode text;
# review it in an editor that reveals hidden Unicode characters.
import itertools
import logging
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta, datetime
from typing import Any, Dict
class SimpleStorageEntry:
    """One value held by SimpleStorage, plus the bookkeeping for its expiry.

    Every entry receives a process-wide unique integer ``id``; SimpleStorage
    uses that id as the dictionary key for the entry.
    """

    # Shared id source. In CPython ``next()`` on an itertools.count is a single
    # C-level call, so entries created concurrently (e.g. from several request
    # threads calling SimpleTasksQueue.add) cannot receive the same id. The
    # previous ``_counter += 1`` followed by a separate read of ``_counter``
    # could hand the same id to two entries.
    _ids = itertools.count(1)
    id: int  # unique key under which SimpleStorage files this entry
    timestamp: datetime  # creation time (naive local time from datetime.now())
    ttl: int  # lifetime in seconds, counted from ``timestamp``
    data: Any  # the stored value

    def __init__(self, data: Any, ttl: int = 86400):
        self.id = next(SimpleStorageEntry._ids)
        self.timestamp = datetime.now()
        self.ttl = ttl
        self.data = data
class SimpleStorage:
    """In-memory key/value store whose entries expire after a per-entry TTL.

    Keys are the integer ids generated by SimpleStorageEntry. Expired entries
    are purged lazily, at the start of every ``get`` and ``update``.
    """

    _map: Dict[int, SimpleStorageEntry]

    def __init__(self):
        self._map = {}

    def set(self, data: Any, ttl: int = 60 * 60 * 24):
        """Store ``data`` for ``ttl`` seconds and return its new integer key."""
        entry = SimpleStorageEntry(data=data, ttl=ttl)
        self._map[entry.id] = entry
        return entry.id

    def get(self, task_token: int) -> Any:
        """Return the data stored under ``task_token``.

        Raises:
            MissingEntryRequestedException: the key is unknown or has expired.
        """
        self._check_ttls()
        try:
            entry = self._map[task_token]
        except KeyError as ke:
            raise MissingEntryRequestedException from ke
        return entry.data

    def update(self, task_token, new_data):
        """Replace the data stored under ``task_token``.

        The entry keeps its original timestamp and TTL, so updating does not
        extend its lifetime.

        Raises:
            MissingEntryRequestedException: the key is unknown or has expired.
        """
        self._check_ttls()
        try:
            entry = self._map[task_token]
        except KeyError as ke:
            raise MissingEntryRequestedException from ke
        entry.data = new_data

    def _check_ttls(self):
        """Delete every entry whose ``timestamp + ttl`` lies in the past."""
        now = datetime.now()
        # Iterate over a snapshot, and tolerate keys that vanish meanwhile.
        # SimpleTasksQueue calls into this store from its worker thread as
        # well as from callers' threads, so another caller may already have
        # purged the same key. The old ``self._map[k]`` / ``del self._map[k]``
        # raised KeyError in that case.
        for key, entry in self._map.copy().items():
            if entry.timestamp + timedelta(seconds=entry.ttl) < now:
                self._map.pop(key, None)
class BackgroundExecutor:
    """Runs submitted callables, one at a time, on a single worker thread."""

    executor = None  # replaced with a per-instance pool in __init__

    def __init__(self):
        # The pool lives as long as this object, so it is not used in a
        # ``with`` block.
        # pylint: disable=consider-using-with
        pool = ThreadPoolExecutor(max_workers=1)
        self.executor = pool

    def add_task(self, fn):
        """Schedule ``fn()`` on the worker thread; the Future is not kept."""
        submit = self.executor.submit
        submit(fn)
class TaskFailedException(Exception):
    """Stored in place of a task's result when the task raised.

    ``caused_by`` keeps the original exception so it can be re-raised when
    the result is requested. Extra positional/keyword arguments are passed
    straight to ``Exception``.
    """

    caused_by: Exception

    def __init__(self, caused_by, *args, **kwargs):
        self.caused_by = caused_by
        super().__init__(*args, **kwargs)
class NoOperationForTokenException(Exception):
    """Signals that no operation is associated with a task token.

    NOTE(review): not raised by any of the code shown in this module; kept
    as part of the public exception set.
    """
class RequestResultsForInProgressOperation(Exception):
    """Raised when a task's result is requested before the task has finished
    (it is still pending or running)."""
class MissingEntryRequestedException(Exception):
    """Raised by SimpleStorage when a key is unknown or its entry expired."""
class SimpleTasksQueue:
    """In-memory queue that runs callables in the background and keeps results.

    ``add`` schedules a callable and returns an integer token. ``is_ready``
    and ``get`` look the outcome up by that token. Outcomes live in a
    SimpleStorage created with its default TTL (one day), so they expire
    after that.
    """

    kv_storage: SimpleStorage = None
    bg_tasks: BackgroundExecutor = None

    def __init__(self):
        # Per-instance marker values stored while a task has no result yet.
        # They are compared by identity (see _is_marker), so a task whose
        # result merely equals one of these strings is still seen as done.
        self._pending_value = str(id(self)) + '_pending'
        self._running_value = str(id(self)) + '_running'
        # Previously misspelled ``self.kv_staorage``. The real attribute then
        # stayed at the class-level ``None``, and every add()/get() raised
        # AttributeError.
        self.kv_storage = SimpleStorage()
        self.bg_tasks = BackgroundExecutor()

    def add(self, func, *args, **kwargs):
        """Schedule ``func(*args, **kwargs)`` and return its task token."""
        task_token = self.kv_storage.set(self._pending_value)
        wait_task = self._wrap_task(task_token, func, *args, **kwargs)
        self.bg_tasks.add_task(wait_task)
        return task_token

    def get(self, task_token) -> Any:
        """Return the result of the task identified by ``task_token``.

        Raises:
            RequestResultsForInProgressOperation: the task has not finished.
            MissingEntryRequestedException: the token is unknown or expired.
            Exception: whatever the task itself raised, re-raised and chained
                from the stored TaskFailedException.
        """
        # A single lookup: checking readiness and then reading again scanned
        # the TTLs twice, and the entry could expire between the two reads.
        result = self.kv_storage.get(task_token)
        if self._is_marker(result):
            raise RequestResultsForInProgressOperation()
        if isinstance(result, TaskFailedException):
            raise result.caused_by from result
        return result

    def is_ready(self, task_token):
        """Return True once the task has finished (successfully or not).

        Raises:
            MissingEntryRequestedException: the token is unknown or expired.
        """
        return not self._is_marker(self.kv_storage.get(task_token))

    def _is_marker(self, value):
        """Return True if ``value`` is this queue's pending/running marker."""
        # Identity instead of ``in [...]``: ``in`` uses ``==``, which breaks
        # for results with element-wise ``__eq__`` (e.g. arrays) and
        # misreports a result that is merely an equal string.
        return value is self._pending_value or value is self._running_value

    def _wrap_task(self, task_token, fn, *args, **kwargs):
        """Return a zero-argument callable that runs ``fn`` in the background
        and stores its outcome under ``task_token``."""
        def update_result_when_finished():
            try:
                logging.debug("Task %i started", task_token)
                self.kv_storage.update(task_token, self._running_value)
                result = fn(*args, **kwargs)
                logging.debug("Task %i finished", task_token)
            except Exception as e:  # pylint: disable=broad-except
                # Any failure is recorded and re-raised later from get().
                logging.exception("Task %i failed", task_token, exc_info=e)
                result = TaskFailedException(caused_by=e)
            logging.debug("Saving result for %i", task_token)
            try:
                self.kv_storage.update(task_token, result)
            except MissingEntryRequestedException:
                # The entry expired before the result could be stored.
                # BackgroundExecutor discards the Future, so without this
                # handler the error would disappear silently.
                logging.warning(
                    "Result for task %i discarded: entry expired", task_token)
        return update_result_when_finished
# End of gist. Discussion of it takes place on the GitHub gist page
# (a GitHub account is required to comment).