PyRocks: Database Workshop Example Code

PyRocks

A RocksDB-inspired proof of concept, for educational purposes only. It is used in a workshop to teach database internals and demonstrate the core concepts of LSM-tree based storage engines: SSTables, the WAL, MemTables, and Bloom filters.
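A minimal usage sketch of the public API, distilled from the demo script below (nothing here goes beyond what pyrocks.py actually exposes):

from pyrocks import PyRocks

db = PyRocks("./my_workshop_db")   # opens (or recovers) the database folder
db.put("user1", "Alice")           # written to the WAL first, then the MemTable
print(db.get("user1"))             # "Alice"
db.delete("user1")                 # writes a tombstone; space is reclaimed on compaction
db.compact()                       # merges all SSTables into one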

#!/usr/bin/env python3
"""
Teaching Tips for the Workshop:
1. The "Cat" Test: While the script is running (specifically after
Scenario 2), pause and have the students open the wal.log file
in a text editor. They will see the JSON lines. This demystifies
the log.
2. The "Hex" Test: After Scenario 3, have them use a hex editor (or
hexdump -C on Mac/Linux) to view an .sst file. They will see the
raw strings "Alice", "Bob" interspersed with binary length headers.
3. The "Delete" Trick: Ask them "Where did 'user2' go?" after Scenario 4.
If they look at the SSTable or WAL, they will see user2 is still
there, but with a __DELETED__ marker. This explains why deletes take
up disk space initially.
"""
import shutil
import os
import time
from pyrocks import PyRocks
DB_FOLDER = "./my_workshop_db"
def clean_start():
    """Helper to wipe the DB folder for a fresh test"""
    if os.path.exists(DB_FOLDER):
        shutil.rmtree(DB_FOLDER)
    print(f"\n--- Starting Fresh in {DB_FOLDER} ---")
def print_separator(title):
    print(f"\n{'='*10} {title} {'='*10}")
# ===============================================
# Scenario 1: Basic Put and Get (MemTable logic)
# ===============================================
clean_start()
print_separator("Scenario 1: In-Memory Operations")
db = PyRocks(DB_FOLDER)
print("Putting keys 'user1' and 'user2'...")
db.put("user1", "Alice")
db.put("user2", "Bob")
print(f"Get 'user1': {db.get('user1')}")
print(f"Get 'user2': {db.get('user2')}")
print(f"Get 'user3' (non-existent): {db.get('user3')}")
# ===============================================
# Scenario 2: Persistence & Recovery (WAL logic)
# ===============================================
print_separator("Scenario 2: Crash Recovery (WAL)")
print("Simulating a crash (closing DB without flushing)...")
# Note: We haven't hit the threshold (5), so data is only in MemTable and WAL.
del db
print("Restarting Database...")
db_recovered = PyRocks(DB_FOLDER)
print("Reading data after recovery:")
print(f"Get 'user1': {db_recovered.get('user1')}") # Should exist due to WAL
# ===============================================
# Scenario 3: Flushing to SSTables
# ===============================================
print_separator("Scenario 3: Flushing to Disk (SSTables)")
# We trigger a flush by adding enough keys to hit the threshold (set to 5 in pyrocks.py)
print("Adding more keys to trigger flush...")
db_recovered.put("user3", "Charlie")
db_recovered.put("user4", "Dave")
db_recovered.put("user5", "Eve")
# Threshold hit! Check console output for "[Flush]..."
print("Checking disk for .sst files...")
sst_files = os.listdir(DB_FOLDER)
print("Files in DB folder:", sst_files)
# ===============================================
# Scenario 4: Updates and Tombstones
# ===============================================
print_separator("Scenario 4: Updates & Deletes")
print("Updating 'user1' to 'Alice_Updated'...")
db_recovered.put("user1", "Alice_Updated")
print("Deleting 'user2'...")
db_recovered.delete("user2")
print(f"Get 'user1': {db_recovered.get('user1')}") # Should be new value
print(f"Get 'user2': {db_recovered.get('user2')}") # Should be None
# ===============================================
# Scenario 5: Compaction
# ===============================================
print_separator("Scenario 5: Compaction")
# Force another flush to ensure we have multiple SSTables
db_recovered.put("user6", "Frank")
db_recovered.put("user7", "Grace")
db_recovered.put("user8", "Heidi")
db_recovered.put("user9", "Ivan")
db_recovered.put("user10", "Judy") # Flush 2 triggers here
print("Files before compaction:", [f for f in os.listdir(DB_FOLDER) if f.endswith('.sst')])
print("Running Compaction...")
db_recovered.compact()
print("Files after compaction:", [f for f in os.listdir(DB_FOLDER) if f.endswith('.sst')])
# Verify data still exists after compaction
print(f"Get 'user1' (updated): {db_recovered.get('user1')}")
print(f"Get 'user2' (deleted): {db_recovered.get('user2')}")
print(f"Get 'user10' (new): {db_recovered.get('user10')}")
print("\nWorkshop Demo Complete!")
#!/usr/bin/env python3
import os
import glob
import json
import struct
import time
import hashlib
from typing import Optional, List, Dict, Tuple
# ==========================================
# Constants & Configuration
# ==========================================
TOMBSTONE = "__DELETED__" # Marker for deleted keys
MEMTABLE_THRESHOLD = 5 # Max keys in memory before flushing to disk (kept low for demo)
class BloomFilter:
    """
    A probabilistic data structure to test if an element is NOT in a set.
    '90% Truth': Real DBs use MurmurHash and bit-arrays. We use MD5 and a simple integer list
    to make the logic readable for students.
    """
    def __init__(self, size=100, hash_count=3):
        self.size = size
        self.hash_count = hash_count
        self.bit_array = [0] * size
    def add(self, string):
        for seed in range(self.hash_count):
            result = int(hashlib.md5((string + str(seed)).encode()).hexdigest(), 16)
            self.bit_array[result % self.size] = 1
    def might_contain(self, string):
        for seed in range(self.hash_count):
            result = int(hashlib.md5((string + str(seed)).encode()).hexdigest(), 16)
            if self.bit_array[result % self.size] == 0:
                return False
        return True
    def to_json(self):
        return json.dumps(self.bit_array)
    @staticmethod
    def from_json(json_str):
        bf = BloomFilter()
        bf.bit_array = json.loads(json_str)
        return bf
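# Behaviour worth demoing in the workshop (illustrative, not executed here):
#   bf = BloomFilter()
#   bf.add("user1")
#   bf.might_contain("user1")   # always True  -- a Bloom filter never gives false negatives
#   bf.might_contain("ghost")   # usually False, occasionally True (false positive)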
class PyRocks:
    """
    A simplified RocksDB clone.
    Architecture:
    1. WAL (Write-Ahead Log): Append-only file for durability.
    2. MemTable: In-memory key/value map (a plain dict here; keys are sorted at flush time).
    3. SSTables: Immutable disk files (LSM-Tree levels).
    """
    def __init__(self, db_path):
        self.db_path = db_path
        if not os.path.exists(db_path):
            os.makedirs(db_path)
        self.wal_path = os.path.join(db_path, "wal.log")
        self.memtable = {}             # The in-memory data
        self.immutable_memtables = []  # Queue for flushing
        # Recovery: Replay WAL to restore MemTable on startup
        self._recover_from_wal()
    def get(self, key: str) -> Optional[str]:
        """
        Public API: Read a value.
        Order: MemTable -> SSTables (Newest -> Oldest)
        """
        # 1. Check MemTable (Fastest)
        if key in self.memtable:
            val = self.memtable[key]
            if val == TOMBSTONE:
                return None  # Key was deleted
            return val
        # 2. Check SSTables on disk (Newest first)
        sst_files = sorted(glob.glob(os.path.join(self.db_path, "*.sst")), reverse=True)
        for sst_file in sst_files:
            value = self._search_sstable(sst_file, key)
            if value is not None:
                if value == TOMBSTONE:
                    return None
                return value
        return None
    def put(self, key: str, value: str):
        """
        Public API: Writes a key-value pair.
        1. Write to WAL (Durability)
        2. Write to MemTable (Speed)
        3. Check if MemTable needs flushing
        """
        # 1. Write-Ahead Log
        self._append_to_wal(key, value)
        # 2. Update In-Memory Data
        self.memtable[key] = value
        # 3. Check Threshold
        if len(self.memtable) >= MEMTABLE_THRESHOLD:
            print(f"[MemTable] Threshold reached ({len(self.memtable)} keys). Flushing...")
            self.flush()
    def delete(self, key: str):
        """
        Soft Delete: We don't remove data immediately.
        We write a 'Tombstone' record. The space is reclaimed during compaction.
        """
        self.put(key, TOMBSTONE)
    def flush(self):
        """
        Freezes the MemTable and writes it to an SSTable on disk.
        """
        if not self.memtable:
            return
        # 1. Create a filename based on timestamp
        filename = f"{int(time.time() * 1000)}.sst"
        filepath = os.path.join(self.db_path, filename)
        # 2. Sort keys (Sorted String Table requirement)
        sorted_keys = sorted(self.memtable.keys())
        # 3. Build Bloom Filter
        bf = BloomFilter()
        for k in sorted_keys:
            bf.add(k)
        # 4. Write to disk
        # Format:
        # [Index Offset (8b)][BloomFilter Len (4b)][BloomFilter JSON][Data...][Index...]
        with open(filepath, "wb") as f:
            # Write Data Block
            data_offsets = {}  # Keep track of where each key starts
            # Placeholder for header (we will jump back and write this later).
            # We need to know where the Index starts to write it in the header.
            f.write(b'\x00' * 8)  # Reserve 8 bytes for the Index Offset
            # Write Bloom Filter
            bf_json = bf.to_json().encode('utf-8')
            f.write(struct.pack("I", len(bf_json)))  # 4 bytes for length
            f.write(bf_json)
            # Write Key-Values
            for k in sorted_keys:
                data_offsets[k] = f.tell()  # Record current file position
                val = self.memtable[k]
                # Write: len_key(4b), key, len_val(4b), val
                k_bytes = k.encode('utf-8')
                v_bytes = val.encode('utf-8')
                f.write(struct.pack("I", len(k_bytes)))
                f.write(k_bytes)
                f.write(struct.pack("I", len(v_bytes)))
                f.write(v_bytes)
            # Write Index Block (Sparse Index)
            # In a real DB, we might only write every 10th key. Here we write all for simplicity.
            index_start = f.tell()
            for k, offset in data_offsets.items():
                # Write: len_key(4b), key, offset(8b)
                k_bytes = k.encode('utf-8')
                f.write(struct.pack("I", len(k_bytes)))
                f.write(k_bytes)
                f.write(struct.pack("Q", offset))
            # Go back to the beginning and write the Index Offset
            f.seek(0)
            f.write(struct.pack("Q", index_start))
        print(f"[Flush] Wrote {len(sorted_keys)} keys to {filename}")
        # 5. Clear WAL and MemTable
        self.memtable.clear()
        open(self.wal_path, 'w').close()  # Truncate WAL
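    # Illustrative byte layout of a flushed .sst file (derived from the code above;
    # struct "I" is 4 bytes and "Q" is 8 bytes on typical platforms):
    #   [0:8]    Q -> offset where the index block starts
    #   [8:12]   I -> length of the Bloom filter JSON
    #   [...]      -> Bloom filter JSON (a list of 0/1 ints)
    #   data block  -> repeated: key_len (I), key bytes, val_len (I), value bytes
    #   index block -> repeated: key_len (I), key bytes, data offset (Q)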
    def compact(self):
        """
        Major Compaction: Merges all SSTables into one.
        Eliminates overwritten keys and Tombstones.
        """
        sst_files = glob.glob(os.path.join(self.db_path, "*.sst"))
        if not sst_files:
            return
        print("[Compaction] Starting major compaction...")
        all_data = {}
        # 1. Read all SSTables (Oldest to Newest).
        # This ensures newer keys overwrite older ones in our dict.
        sst_files.sort(key=os.path.getmtime)
        for sst in sst_files:
            # We could reuse the search logic, but a full-scan helper is simpler here.
            self._scan_sstable_into_dict(sst, all_data)
        # 2. Filter out Tombstones
        final_data = {k: v for k, v in all_data.items() if v != TOMBSTONE}
        # 3. Write new SSTable
        # Hack: we temporarily hijack the MemTable and reuse flush() to write one big
        # new file. In a real DB, compaction happens in the background and never
        # touches the live MemTable.
        old_mem = self.memtable.copy()
        self.memtable = final_data
        # Force flush to create one big new file
        self.flush()
        # Restore old memtable
        self.memtable = old_mem
        # 4. Cleanup old files
        # In real DBs, we only delete after the new one is successfully written
        for sst in sst_files:
            os.remove(sst)
        print("[Compaction] Finished. Old files removed.")
    def _recover_from_wal(self):
        """
        Reads the WAL line by line to rebuild the MemTable state.
        """
        if not os.path.exists(self.wal_path):
            return
        print(f"[Recovery] Replaying WAL from {self.wal_path}...")
        with open(self.wal_path, "r") as f:
            for line in f:
                try:
                    # Parse simplified JSON log entries
                    entry = json.loads(line)
                    key, val = entry['k'], entry['v']
                    self.memtable[key] = val
                except (ValueError, KeyError):
                    continue  # Skip corrupted or incomplete lines
    def _append_to_wal(self, key, value):
        """
        Writes to the append-only log file.
        """
        with open(self.wal_path, "a") as f:
            entry = json.dumps({"k": key, "v": value})
            f.write(entry + "\n")
            f.flush()  # Push to the OS; a real DB would also fsync for true durability
    def _search_sstable(self, filepath, key):
        """
        Helper: Searches a specific SSTable file.
        Uses the Bloom Filter first to avoid reading the file if possible.
        """
        with open(filepath, "rb") as f:
            # Read Index Offset
            index_offset = struct.unpack("Q", f.read(8))[0]
            # Read Bloom Filter
            bf_len = struct.unpack("I", f.read(4))[0]
            bf_json = f.read(bf_len).decode('utf-8')
            bf = BloomFilter.from_json(bf_json)
            # Optimization: Bloom Filter check
            if not bf.might_contain(key):
                # print(f"  [Bloom] Skipped {filepath} (Key definitely not here)")
                return None
            # Jump to Index
            f.seek(index_offset)
            # Linear scan of the Index (in real DBs, this would be a binary search)
            target_offset = None
            while True:
                # Read Key Length
                chunk = f.read(4)
                if not chunk:
                    break  # End of file
                k_len = struct.unpack("I", chunk)[0]
                k_read = f.read(k_len).decode('utf-8')
                offset = struct.unpack("Q", f.read(8))[0]
                if k_read == key:
                    target_offset = offset
                    break
            if target_offset is None:
                return None  # Bloom filter false positive: key is not actually in the index
            # Retrieve Data
            f.seek(target_offset)
            # Skip key (we know it)
            k_len = struct.unpack("I", f.read(4))[0]
            f.read(k_len)
            # Read Value
            v_len = struct.unpack("I", f.read(4))[0]
            return f.read(v_len).decode('utf-8')
    def _scan_sstable_into_dict(self, filepath, data_dict):
        """Helper to read an entire SSTable into a dict."""
        with open(filepath, "rb") as f:
            # Read Index Offset (8 bytes) so we know where the data block ends
            index_offset = struct.unpack("Q", f.read(8))[0]
            # Skip Bloom Filter
            bf_len = struct.unpack("I", f.read(4))[0]
            f.read(bf_len)
            # Read Data Block until we hit the index offset
            while f.tell() < index_offset:
                try:
                    k_len = struct.unpack("I", f.read(4))[0]
                    k = f.read(k_len).decode('utf-8')
                    v_len = struct.unpack("I", f.read(4))[0]
                    v = f.read(v_len).decode('utf-8')
                    data_dict[k] = v
                except struct.error:
                    break