Last active
April 4, 2024 15:04
-
-
Save FinnWoelm/6d243fd0e65b008abccc5c009cfdbc49 to your computer and use it in GitHub Desktop.
PostgreSQL Task Queue (with SQLAlchemy)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime | |
from sqlalchemy import Column, Integer, String, DateTime | |
from sqlalchemy import not_, exists, alias, select | |
from sqlalchemy.ext.hybrid import hybrid_property | |
import models | |
def _utcnow():
    """Return the current UTC time as a naive ``datetime``.

    Yields the same value as the deprecated ``datetime.datetime.utcnow()``
    (naive, expressed in UTC), so the timestamp columns below keep receiving
    the same kind of value as before.
    """
    return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)


class Task(models.BaseModel):
    """A unit of work in a database-backed task queue.

    A task names a ``pipeline`` to run for a ``domain``. Tasks are picked up
    via :meth:`get_next`, which selects one queued task and row-locks it, and
    are finished with :meth:`mark_completed` or :meth:`mark_failed`.
    """

    id = Column(Integer, primary_key=True)

    # Domain name of the website
    domain = Column(String, nullable=False)

    # Name of the pipeline to run (e.g. scrape, extract, analyse, classify, ...)
    pipeline = Column(String, nullable=False)

    # Status: queued, completed or failed
    # We don't have status "running" because we cannot update
    # the record while it is locked for update
    status = Column(String, nullable=False, index=True)

    # Arbitrary lock name
    # All tasks with the same lock name will run in sequential order
    # Task with lock "ABC" and ID 1 will run before task with lock "ABC" and ID 5
    # Tasks with different locks can run in parallel
    lock = Column(String, nullable=True, index=True)

    # Name of the worker that worked on this task
    worker = Column(String, nullable=True)

    # Start and end timestamp (naive datetimes expressed in UTC)
    started_at = Column(DateTime, nullable=True)
    ended_at = Column(DateTime, nullable=True)

    # Optional column for logging errors that occur during task execution
    error = Column(String, nullable=True)

    # Allowed values for the ``status`` column
    STATUS_QUEUED = "queued"
    STATUS_COMPLETED = "completed"
    STATUS_FAILED = "failed"

    # The hybrid properties below work on instances (plain comparison) and on
    # the class (SQL expression usable in ``.where(...)``).

    @hybrid_property
    def is_queued(self):
        """Whether the status equals ``STATUS_QUEUED``."""
        return self.status == self.STATUS_QUEUED

    @hybrid_property
    def is_completed(self):
        """Whether the status equals ``STATUS_COMPLETED``."""
        return self.status == self.STATUS_COMPLETED

    @hybrid_property
    def is_failed(self):
        """Whether the status equals ``STATUS_FAILED``."""
        return self.status == self.STATUS_FAILED

    @classmethod
    def queue(cls, **attributes):
        """Build a new task in the ``queued`` state.

        Args:
            **attributes: Column values for the new task (e.g. ``domain``,
                ``pipeline``, ``lock``). Must not contain ``status``.

        Returns:
            A new, unsaved instance with ``status`` set to ``STATUS_QUEUED``.
        """
        return cls(**attributes, status=cls.STATUS_QUEUED)

    # Get the next task available for processing and lock it for the duration of
    # the transaction. This picks one task that is ready for processing, i.e.,
    # there is no dependency on another task or the dependency has been
    # fulfilled because the task was completed.
    # Inspired by Procrastinate: https://github.com/procrastinate-org/procrastinate/blob/30e06712c36597e5761c60dbc3161364d9049c17/procrastinate/sql/schema.sql#L148
    @classmethod
    def get_next(cls):
        """Build the statement that selects and row-locks the next ready task.

        Returns:
            A statement built on ``cls.select()`` (presumably a helper provided
            by ``models.BaseModel`` -- confirm) that yields at most one queued
            task, lowest ID first, using ``FOR UPDATE ... SKIP LOCKED`` so that
            rows already locked by another transaction are skipped.
        """
        preceding_task = alias(cls.__table__, "precedingtask")

        # True if an earlier task (lower ID) with the same lock name has not
        # been completed yet. Tasks without a lock never match.
        has_unfinished_predecessor = exists(
            select(1)
            .where(cls.lock.is_not(None))
            .where(preceding_task.c.id < cls.id)
            .where(preceding_task.c.lock == cls.lock)
            .where(preceding_task.c.status != cls.STATUS_COMPLETED)
        )

        return (
            cls.select()
            # Criteria: Status of this task is 'queued'
            .where(cls.is_queued)
            # Criteria: Task has no lock or all preceding tasks with lock have
            # been completed
            .where(not_(has_unfinished_predecessor))
            .order_by(cls.id)
            .limit(1)
            .with_for_update(skip_locked=True, of=cls)
        )

    def mark_running(self):
        """Record the current UTC time as the start time.

        The status is left unchanged: there is no "running" status.
        """
        self.started_at = _utcnow()

    def mark_completed(self):
        """Set the status to completed, stamp the end time, clear the error."""
        self.status = self.STATUS_COMPLETED
        self.ended_at = _utcnow()
        self.error = None

    def mark_failed(self, error=None):
        """Set the status to failed and stamp the end time.

        Args:
            error: Optional error description (e.g. a formatted traceback)
                stored in the ``error`` column.
        """
        self.status = self.STATUS_FAILED
        self.ended_at = _utcnow()
        self.error = error
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import time | |
import socket | |
import traceback | |
from models import Task | |
from lib.database import DatabaseSession | |
# Load config: We do not need it for the worker, but this makes sure that the
# config is valid. If the config is invalid, the worker will crash right away.
print("Loading config", "...")
from config import config

# Check if should work. When work is disabled, idle forever instead of
# exiting, so the process stays alive without picking up tasks.
if not config.PERFORM_WORK:
    print("PERFORM_WORK=FALSE: Not starting worker")
    while True:
        time.sleep(1)

# Get the Docker container ID of this worker (the hostname of the machine;
# presumably the container ID when running inside Docker)
DOCKER_CONTAINER_ID = socket.gethostname()
print("Starting worker", DOCKER_CONTAINER_ID, "...")

# Indicate whether the worker is currently waiting for a task. Used to print
# the "Waiting for task" message only once per idle period.
is_waiting_for_task = False

while True:
    with DatabaseSession() as db:
        # Get the next available task. The row stays locked until the
        # transaction ends.
        task = db.scalar(Task.get_next())

        if task is not None:
            # Task was found, begin execution
            print("Acquired task:", f"#{task.id} ({task.domain}, {task.pipeline})")
            is_waiting_for_task = False

            # Mark the task as running, store the worker name (container ID),
            # and reset error state
            task.mark_running()
            task.worker = DOCKER_CONTAINER_ID
            task.error = None

            # Establish a savepoint, so that we can roll back any changes made
            # by the pipeline, if it fails. We do not want to rollback the
            # entire transaction, otherwise we lose the lock on the task.
            savepoint = db.begin_nested()

            try:
                # Perform work
                # Do something with the job here, such as running the actual
                # pipeline, etc...
                print(task)

                # Mark task as completed
                task.mark_completed()
            except Exception:
                # Top-level boundary: a failing task must not take down the
                # worker, so record the failure and carry on.
                exc_trace = traceback.format_exc()

                # Rollback any changes made by the pipeline
                savepoint.rollback()

                # Mark task as failed
                task.mark_failed(error=exc_trace)

                # Print error
                print(exc_trace)
            else:
                # Release the savepoint explicitly, so that the commit below
                # always applies to the outer transaction (older SQLAlchemy
                # sessions commit only the innermost transaction otherwise).
                savepoint.commit()

            # Commit changes (this also releases the lock on the task)
            db.commit()
            continue

    # No task was found: wait for a short time and then try again. The sleep
    # happens after the session block has exited, so nothing is held open
    # while idle.
    if not is_waiting_for_task:
        is_waiting_for_task = True
        print("Waiting for task")

    time.sleep(1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment