Skip to content

Instantly share code, notes, and snippets.

@wpcarro
Created December 18, 2024 17:20
Show Gist options
  • Save wpcarro/895bd8fb07e1bcde6c5d3d1df50f498c to your computer and use it in GitHub Desktop.
Save wpcarro/895bd8fb07e1bcde6c5d3d1df50f498c to your computer and use it in GitHub Desktop.
More durability experiments: read state from disk at startup with data integrity checks
import signal
import time
import random
import json
import os
import threading
import hashlib
# Guards every read and write of `state`. It is re-entrant on purpose: Python
# runs signal handlers on the main thread, and the handlers in this file call
# persist_state(), which takes this lock. With a plain Lock, a signal arriving
# while the main thread already holds the lock would deadlock the process.
lock = threading.RLock()


def load_state(path="/tmp/dump.json"):
    """Return the state stored at *path*, or the default state if unusable.

    The file is expected to hold a JSON object with a "state" list and a
    "checksum" that is the SHA-256 hex digest of ``json.dumps(state)``.
    Lots of chaos is possible, so every failure falls back to the default:
    - File not found / unreadable
    - JSON parsing issues
    - Missing keys or a state that is not a list
    - Checksum mismatch (corrupted or modified at rest)

    NOTE: the checksum is unkeyed, so it detects accidental corruption; anyone
    able to rewrite the file can also recompute a matching checksum.

    Args:
        path: Location of the dump file to read.

    Returns:
        The verified state list, or ``[1, 2, 3, 4, 5]`` when loading fails.
    """
    default = [1, 2, 3, 4, 5]
    try:
        with open(path, "r") as f:
            print("Reading state from disk...")
            content = f.read()
        parsed = json.loads(content)
        loaded = parsed["state"]
        checksum = hashlib.sha256(json.dumps(loaded).encode("utf-8")).hexdigest()
        # Raise explicitly instead of using `assert`: assertions are removed
        # when Python runs with -O, which would silently skip the check.
        if parsed["checksum"] != checksum:
            raise ValueError("Data integrity check failed")
        # The rest of the program appends to the state, so it must be a list.
        if not isinstance(loaded, list):
            raise ValueError("State is not a list")
    except Exception as e:
        # Deliberately broad: startup must survive any kind of bad dump file.
        print(f"Failed to read state from disk: {e}")
        return default
    return loaded


# Refresh the state at startup and check its integrity.
with lock:
    state = load_state()
def persist_state():
    """Atomically write the current state and its checksum to /tmp/dump.json.

    The payload is a JSON object with the "state" and a "checksum" holding the
    SHA-256 hex digest of ``json.dumps(state)``. It is first written and
    fsynced to /tmp/buffer.json, then renamed over the destination, so a
    reader never observes a half-written dump.

    Raises:
        OSError: If the buffer file cannot be written or renamed.
    """
    print("Persisting state...")
    tmp = "/tmp/buffer.json"
    dst = "/tmp/dump.json"
    # Hold the lock for the whole snapshot -> write -> rename sequence. Every
    # caller shares the same buffer file, so this stops two callers from
    # interleaving their writes and from publishing an older snapshot after a
    # newer one.
    with lock:
        checksum = hashlib.sha256(json.dumps(state).encode("utf-8")).hexdigest()
        content = json.dumps({"state": state, "checksum": checksum})
        with open(tmp, "w") as f:
            f.write(content)
            # Push Python's buffer to the OS, then the OS cache to disk,
            # before the rename makes this file the live dump.
            f.flush()
            os.fsync(f.fileno())
        # os.replace overwrites an existing destination on every platform.
        os.replace(tmp, dst)
def handle_signal(signum, _frame):
    """Persist the state, then unwind the program for a termination signal.

    For SIGINT, SIGTERM and SIGHUP the state is persisted, the signal's
    default disposition is restored, and an exception is raised so the
    interpreter shuts down. Any other signal is only reported.

    Args:
        signum: Number of the signal being delivered.
        _frame: Current stack frame supplied by the signal machinery; unused.

    Raises:
        KeyboardInterrupt: After handling SIGINT.
        SystemExit: After handling SIGTERM or SIGHUP.
    """
    print(f"Received signal: {signum}")
    # Exception used to leave the program once the state has been saved.
    exit_exceptions = {
        signal.SIGINT: KeyboardInterrupt,
        signal.SIGTERM: SystemExit,
        signal.SIGHUP: SystemExit,
    }
    exc_type = exit_exceptions.get(signum)
    if exc_type is None:
        # No `logger` exists in this file; report with print like the rest of
        # the script instead of failing with a NameError.
        print(f"Unhandled signal: {signum}")
        return
    persist_state()
    # Restore the default behavior so a repeated signal is not intercepted
    # again while the program is shutting down.
    signal.signal(signum, signal.SIG_DFL)
    raise exc_type
# Route interrupt, terminate and hang-up signals through handle_signal.
for _signum in (signal.SIGINT, signal.SIGTERM, signal.SIGHUP):
    signal.signal(_signum, handle_signal)
def checkpoint(interval=5, duration=60 * 60):
    """Persist the state repeatedly, sleeping *interval* seconds in between.

    Args:
        interval: Seconds to sleep after each persist. Must be positive.
        duration: Approximate total seconds to keep checkpointing; the number
            of rounds is ``int(duration / interval)``. The defaults give 720
            rounds of 5 seconds, i.e. one hour.

    Raises:
        ValueError: If *interval* is not positive.
    """
    if interval <= 0:
        raise ValueError("interval must be positive")
    rounds = int(duration / interval)
    for _ in range(rounds):
        persist_state()
        print("Sleeping...")
        time.sleep(interval)
# Run the periodic checkpointer in the background. It is a daemon thread, so
# it does not keep the process alive once the main loop below finishes.
t = threading.Thread(target=checkpoint, daemon=True)
t.start()

# Main loop: roughly one hour of 5-second ticks. On each tick a coin flip
# decides whether a random value in [0, 100) is appended to the state.
sleep = 5
minute = 60 / sleep
ticks = int(60 * minute)
for _ in range(ticks):
    print("Awaiting signal...")
    should_mutate = random.choice([True, False])
    if should_mutate:
        new_value = random.choice(range(100))
        # Mutations to the state are made under the lock.
        with lock:
            state.append(new_value)
    time.sleep(sleep)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment