Skip to content

Instantly share code, notes, and snippets.

@sawa2d2
Last active February 25, 2026 00:30
Show Gist options
  • Select an option

  • Save sawa2d2/a0b4215e05c856bddba8f5d3c34428fb to your computer and use it in GitHub Desktop.

Select an option

Save sawa2d2/a0b4215e05c856bddba8f5d3c34428fb to your computer and use it in GitHub Desktop.
from dataclasses import fields, MISSING
import logging
logger = logging.getLogger(__name__)


def from_dict(cls, data: dict):
    """Build an instance of the dataclass ``cls`` from the mapping ``data``.

    Only fields that the dataclass ``__init__`` accepts (``init=True``) are
    passed to the constructor. Keys in ``data`` that do not name such a
    field -- including ``init=False`` fields -- are ignored and logged as a
    warning.

    Args:
        cls: A dataclass type.
        data: Mapping from field name to value.

    Returns:
        A new instance of ``cls``.

    Raises:
        TypeError: if ``cls`` is not a dataclass (from ``dataclasses.fields``),
            or if a required constructor field is absent from ``data``. In
            that case the missing names are logged first, and the
            constructor call raises.
    """
    # init=False fields are not constructor parameters: forwarding them would
    # raise TypeError, and the caller can never be required to supply them.
    init_fields = {f.name: f for f in fields(cls) if f.init}

    extra_keys = data.keys() - init_fields.keys()
    if extra_keys:
        logger.warning("Extra keys ignored: %s", extra_keys)

    missing_required = {
        name
        for name, f in init_fields.items()
        if name not in data
        and f.default is MISSING
        and f.default_factory is MISSING
    }
    if missing_required:
        # Logged for diagnostics; the constructor call below raises TypeError.
        logger.error("Missing required fields: %s", missing_required)

    return cls(**{name: data[name] for name in init_fields.keys() & data.keys()})
import os
import time
import sqlite3
import tempfile
import multiprocessing as mp
from typing import List
from queue_impl import SqliteQueue, run_batch_consumer # ← 実装ファイル名に合わせる
# ----------------------------------------------------------------------
# Helpers
# ----------------------------------------------------------------------
def init_db(db_path: str, n_jobs: int) -> None:
    """Create the queue schema at ``db_path`` and enqueue ``n_jobs`` jobs.

    Payloads are ``"job-0"`` through ``f"job-{n_jobs - 1}"``, enqueued in
    that order.

    The connection is closed even if schema creation or an enqueue raises.
    Otherwise an open handle could keep the database file in use and break
    cleanup of the surrounding ``TemporaryDirectory`` (notably on Windows).
    """
    q = SqliteQueue(db_path)
    conn = q.connect()
    try:
        q.init_schema(conn)
        for i in range(n_jobs):
            q.enqueue(conn, f"job-{i}")
    finally:
        conn.close()
def collect_done_payloads(db_path: str) -> List[str]:
    """Return the payloads of all ``queue`` rows whose status is ``'done'``.

    The result is sorted as strings (lexicographically), so ``"job-10"``
    sorts before ``"job-2"``. Callers comparing against an expected list
    must sort that list the same way.
    """
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT payload FROM queue WHERE status='done'"
        ).fetchall()
    finally:
        # Using a sqlite3 connection as a context manager only commits or
        # rolls back; it does not close. So close explicitly, even on error.
        conn.close()
    return sorted(payload for (payload,) in rows)
# ----------------------------------------------------------------------
# Test 1: single consumer processes all jobs
# ----------------------------------------------------------------------
def test_single_consumer():
    """A single consumer drains every job and each one ends up marked done."""
    n_jobs = 10
    with tempfile.TemporaryDirectory() as d:
        db_path = os.path.join(d, "queue.db")
        init_db(db_path, n_jobs=n_jobs)

        def process(payload: str):
            pass  # no-op handler: only the queue bookkeeping is under test

        run_batch_consumer(db_path, process, max_jobs=100)

        done = collect_done_payloads(db_path)
        # collect_done_payloads sorts as strings; sort the expectation the
        # same way so this stays correct if n_jobs is raised past 10.
        assert done == sorted(f"job-{i}" for i in range(n_jobs))
# ----------------------------------------------------------------------
# Test 2: multiple consumers concurrently (no double consume)
# ----------------------------------------------------------------------
def _run_consumer(db_path: str):
    """Child-process entry point: run one batch consumer with a slow handler.

    Calls ``run_batch_consumer`` on ``db_path`` with ``max_jobs=100``.
    Each job's handler sleeps briefly, so concurrently started consumers
    overlap in time.
    """

    def slow_process(payload: str):
        time.sleep(0.01)  # stand-in for real per-job work

    run_batch_consumer(db_path, slow_process, max_jobs=100)
def test_concurrent_consumers():
    """Several consumer processes together leave every job marked done.

    This checks that every row ends in status ``'done'`` and that no consumer
    process crashed or hung. The handler records nothing, so this test alone
    cannot detect a job whose handler ran more than once.
    """
    n_jobs = 50
    n_consumers = 4
    join_timeout = 60  # seconds per process; fail instead of hanging forever

    with tempfile.TemporaryDirectory() as d:
        db_path = os.path.join(d, "queue.db")
        init_db(db_path, n_jobs=n_jobs)

        procs = [
            mp.Process(target=_run_consumer, args=(db_path,))
            for _ in range(n_consumers)
        ]
        for p in procs:
            p.start()
        for p in procs:
            p.join(join_timeout)

        hung = [p for p in procs if p.is_alive()]
        for p in hung:
            p.terminate()
            p.join()
        assert not hung, f"{len(hung)} consumer(s) still running after {join_timeout}s"

        # An exception inside a child is only visible as a non-zero exit code.
        exit_codes = [p.exitcode for p in procs]
        assert exit_codes == [0] * n_consumers, exit_codes

        done = collect_done_payloads(db_path)
        assert len(done) == n_jobs
        # collect_done_payloads sorts as strings ("job-10" < "job-2"), so the
        # expectation must be sorted the same way, not listed in numeric order.
        assert done == sorted(f"job-{i}" for i in range(n_jobs))
# ----------------------------------------------------------------------
# Test 3: crash during processing + reclaim
# ----------------------------------------------------------------------
def test_reclaim_after_crash():
    """A job claimed by a consumer that dies before acking can be reclaimed.

    The crash is simulated by claiming a job and closing the connection
    without calling ack or fail. A fresh queue handle then reclaims it with a
    zero timeout, and a normal consumer processes it to completion.
    """
    with tempfile.TemporaryDirectory() as d:
        db_path = os.path.join(d, "queue.db")
        init_db(db_path, n_jobs=1)

        # Consumer claims a job, then "crashes": no ack or fail is sent.
        q = SqliteQueue(db_path)
        conn = q.connect()
        try:
            job = q.claim_one(conn)
            assert job is not None
        finally:
            conn.close()

        # A new queue handle, as a restarted process would create, reclaims it.
        q = SqliteQueue(db_path)
        conn = q.connect()
        try:
            reclaimed = q.reclaim_stale(conn, timeout_seconds=0)
        finally:
            # Release the handle before the consumer runs and before the
            # temporary directory is removed.
            conn.close()
        assert reclaimed == 1

        # Now it should be processable again.
        processed = []

        def process(payload: str):
            processed.append(payload)

        run_batch_consumer(db_path, process, max_jobs=1)
        assert processed == ["job-0"]
        assert collect_done_payloads(db_path) == ["job-0"]
# ----------------------------------------------------------------------
# Test 4: stale ack is rejected by lease fencing
# ----------------------------------------------------------------------
def test_stale_ack_is_rejected():
    """An ack carrying an outdated lease is refused once the job is re-leased.

    Consumer A claims the job. Consumer B force-reclaims it (timeout 0) and
    claims it again, getting a larger lease value. A's ack must return False
    and B's must return True.
    """
    with tempfile.TemporaryDirectory() as d:
        db_path = os.path.join(d, "queue.db")
        init_db(db_path, n_jobs=1)

        q = SqliteQueue(db_path)
        conn1 = q.connect()
        conn2 = q.connect()
        try:
            # Consumer A claims.
            job_a = q.claim_one(conn1)
            assert job_a is not None

            # Consumer B reclaims the in-flight job, then claims it itself.
            reclaimed = q.reclaim_stale(conn2, timeout_seconds=0)
            assert reclaimed == 1
            job_b = q.claim_one(conn2)
            assert job_b is not None
            assert job_b.lease > job_a.lease

            # A tries to ack with its stale lease.
            assert q.ack(conn1, job_a) is False

            # B, holding the current lease, acks successfully.
            assert q.ack(conn2, job_b) is True
        finally:
            # Close even when an assertion fails, so open handles cannot
            # block removal of the temporary directory.
            conn1.close()
            conn2.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment