Skip to content

Instantly share code, notes, and snippets.

@h0rn3t
Forked from horodchukanton/simple_tasks_queue.py
Created August 1, 2022 10:03
Show Gist options
  • Save h0rn3t/dfd7a5f89249447e443ab5b57a971763 to your computer and use it in GitHub Desktop.
Save h0rn3t/dfd7a5f89249447e443ab5b57a971763 to your computer and use it in GitHub Desktop.
FastAPI Background tasks queue
import logging
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta, datetime
from typing import Any, Dict
class SimpleStorageEntry:
    """One value held by SimpleStorage, stamped with its creation time and a TTL.

    Each new instance takes the next integer from a counter shared by the whole
    class, so ids are assigned sequentially starting at 1.
    """

    _counter: int = 0  # last id handed out, shared across all instances

    id: int  # sequential id; SimpleStorage uses it as the lookup key
    timestamp: datetime  # naive local time taken with datetime.now() at construction
    ttl: int  # lifetime in seconds, measured from ``timestamp``
    data: Any  # the stored payload

    def __init__(self, data: Any, ttl: int = 86400):
        next_id = SimpleStorageEntry._counter + 1
        SimpleStorageEntry._counter = next_id
        self.id = next_id
        self.data = data
        self.ttl = ttl
        self.timestamp = datetime.now()
class SimpleStorage:
    """In-memory key/value store whose entries expire after a per-entry TTL.

    Keys are the integer ids generated by SimpleStorageEntry. Expired entries are
    purged lazily at the start of every ``get`` and ``update`` call. The store may
    be called concurrently (e.g. a background worker updating an entry while a
    request thread reads it), so the purge tolerates another thread having
    removed an entry first.
    """

    # String (forward-reference) annotation: not evaluated at class creation.
    _map: "Dict[int, SimpleStorageEntry]"

    def __init__(self):
        self._map = {}

    def set(self, data: Any, ttl: int = 60 * 60 * 24) -> int:
        """Store ``data`` under a fresh token and return that token.

        :param data: value to store.
        :param ttl: lifetime in seconds, counted from now (default: one day).
        :returns: the new entry's integer id, used as the lookup token.
        """
        entry = SimpleStorageEntry(data=data, ttl=ttl)
        self._map[entry.id] = entry
        return entry.id

    def get(self, task_token: int) -> Any:
        """Return the data stored under ``task_token``.

        :raises MissingEntryRequestedException: if the token is unknown or its
            entry has expired.
        """
        self._check_ttls()
        try:
            entry = self._map[task_token]
        except KeyError as ke:
            raise MissingEntryRequestedException from ke
        return entry.data

    def update(self, task_token: int, new_data: Any) -> None:
        """Replace the data stored under ``task_token``.

        The entry's timestamp is not touched, so its expiry time is unchanged.

        :raises MissingEntryRequestedException: if the token is unknown or its
            entry has expired.
        """
        self._check_ttls()
        try:
            entry = self._map[task_token]
        except KeyError as ke:
            raise MissingEntryRequestedException from ke
        entry.data = new_data

    def _check_ttls(self) -> None:
        """Drop every entry whose ``timestamp + ttl`` lies in the past."""
        now = datetime.now()
        # Walk a snapshot of (key, entry) pairs so the live dict can be mutated in
        # the loop. Use pop(..., None): if another thread purged the same key
        # after the snapshot was taken, indexing or ``del`` would raise KeyError.
        for key, entry in self._map.copy().items():
            if entry.timestamp + timedelta(seconds=entry.ttl) < now:
                self._map.pop(key, None)
class BackgroundExecutor:
    """Thin wrapper around a ThreadPoolExecutor that runs callables in the background.

    With the default single worker, submitted tasks run one at a time in
    submission order.
    """

    executor = None  # replaced per instance in __init__

    def __init__(self, max_workers: int = 1):
        """Create the underlying thread pool.

        :param max_workers: number of worker threads (default 1, which
            serialises tasks in submission order).
        """
        # pylint: disable=consider-using-with
        # The pool lives as long as this object, so it is not a context manager.
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    def add_task(self, fn, *args, **kwargs):
        """Schedule ``fn(*args, **kwargs)`` on the pool.

        :returns: the ``concurrent.futures.Future`` for the call. Callers can use
            it to observe the result or any exception ``fn`` raised; an exception
            in a Future nobody inspects is never reported.
        """
        return self.executor.submit(fn, *args, **kwargs)
class TaskFailedException(Exception):
    """Wrap the exception raised by a background task so it can be stored as its result.

    The original exception is kept on ``caused_by``. Any further positional or
    keyword arguments are forwarded unchanged to ``Exception.__init__``.
    """

    caused_by: Exception  # the exception the task actually raised

    def __init__(self, caused_by, *args, **kwargs):
        self.caused_by = caused_by
        super().__init__(*args, **kwargs)
class NoOperationForTokenException(Exception):
    """Exception type for a task token that has no associated operation.

    NOTE(review): nothing in this module raises it; presumably kept as part of
    the public API for callers - confirm before removing.
    """
class RequestResultsForInProgressOperation(Exception):
    """Raised by SimpleTasksQueue.get when the task's result is not available yet."""
class MissingEntryRequestedException(Exception):
    """Raised by SimpleStorage when a token is unknown or its entry has expired."""
class SimpleTasksQueue:
    """Run callables in the background and let callers poll for results by token.

    ``add`` returns an integer token immediately. Until the task finishes, the
    value stored under that token is one of this queue's private "pending" /
    "running" markers. Afterwards it is the callable's return value, or a
    TaskFailedException wrapping the exception it raised. Entries live in a
    SimpleStorage and expire after its TTL, counted from the ``add`` call (not
    from task completion).
    """

    kv_storage: SimpleStorage = None
    bg_tasks: BackgroundExecutor = None

    def __init__(self):
        # Fresh sentinel objects, compared by identity. No task result (an equal
        # string, or an object with an unusual __eq__ such as a NumPy array) can
        # be mistaken for an in-progress marker.
        self._pending_value = object()
        self._running_value = object()
        self.kv_storage = SimpleStorage()
        self.bg_tasks = BackgroundExecutor()

    def add(self, func, *args, **kwargs) -> int:
        """Schedule ``func(*args, **kwargs)`` in the background and return its token."""
        task_token = self.kv_storage.set(self._pending_value)
        self.bg_tasks.add_task(self._wrap_task(task_token, func, *args, **kwargs))
        return task_token

    def get(self, task_token) -> Any:
        """Return the result of a finished task.

        :raises RequestResultsForInProgressOperation: if the task has not finished.
        :raises MissingEntryRequestedException: if the token is unknown or expired.
        :raises Exception: if the task failed, the task's own exception is
            re-raised, chained from the TaskFailedException that stored it.
        """
        # Single storage read, so the readiness check and the returned value
        # come from the same lookup.
        result = self.kv_storage.get(task_token)
        if self._is_in_progress(result):
            raise RequestResultsForInProgressOperation()
        if isinstance(result, TaskFailedException):
            raise result.caused_by from result
        return result

    def is_ready(self, task_token) -> bool:
        """Return True once the task has finished, successfully or not.

        :raises MissingEntryRequestedException: if the token is unknown or expired.
        """
        return not self._is_in_progress(self.kv_storage.get(task_token))

    def _is_in_progress(self, value) -> bool:
        """Return True if ``value`` is this queue's pending or running marker."""
        return value is self._pending_value or value is self._running_value

    def _wrap_task(self, task_token, fn, *args, **kwargs):
        """Build a zero-argument callable that runs ``fn`` and records its outcome.

        The callable does three things:

        1. Marks the entry as running.
        2. Calls ``fn(*args, **kwargs)``.
        3. Stores either the return value or a TaskFailedException wrapping what
           ``fn`` raised.

        It runs on the background executor's thread, and ``add`` does not keep
        the Future. Any exception escaping it would therefore go unreported, so
        storage misses (the entry expired before or during the task) are logged
        instead.
        """
        def update_result_when_finished():
            logging.debug("Task %i started", task_token)
            try:
                self.kv_storage.update(task_token, self._running_value)
            except MissingEntryRequestedException:
                # The token expired while queued; nobody can fetch a result.
                logging.warning(
                    "Task %i expired before it started; skipping", task_token)
                return
            try:
                result = fn(*args, **kwargs)
                logging.debug("Task %i finished", task_token)
            except Exception as e:
                # Deliberate boundary: any failure becomes the stored result,
                # re-raised to whoever calls get().
                logging.exception("Task %i failed", task_token)
                result = TaskFailedException(caused_by=e)
            logging.debug("Saving result for %i", task_token)
            try:
                self.kv_storage.update(task_token, result)
            except MissingEntryRequestedException:
                logging.warning(
                    "Task %i finished after its entry expired; result discarded",
                    task_token)

        return update_result_when_finished
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment