A RocksDB-inspired POC, for educational purposes only, used in a workshop to teach databases and demonstrate the core concepts of LSM-based storage engines: SSTables, the WAL, MemTables, and Bloom filters.
PyRocks: Database Workshop Example Code
#!/usr/bin/env python3
"""
Teaching Tips for the Workshop:

1. The "Cat" Test: While the script is running (specifically after
   Scenario 2), pause and have the students open the wal.log file
   in a text editor. They will see the JSON lines. This demystifies
   the log.

2. The "Hex" Test: After Scenario 3, have them use a hex editor (or
   hexdump -C on Mac/Linux) to view an .sst file. They will see the
   raw strings "Alice", "Bob" interspersed with binary length headers.

3. The "Delete" Trick: Ask them "Where did 'user2' go?" after Scenario 4.
   If they look at the SSTable or WAL, they will see user2 is still
   there, but with a __DELETED__ marker. This explains why deletes take
   up disk space initially.
"""
import shutil
import os
import time

from pyrocks import PyRocks

DB_FOLDER = "./my_workshop_db"


def clean_start():
    """Helper to wipe the DB folder for a fresh test"""
    if os.path.exists(DB_FOLDER):
        shutil.rmtree(DB_FOLDER)
    print(f"\n--- Starting Fresh in {DB_FOLDER} ---")


def print_separator(title):
    print(f"\n{'='*10} {title} {'='*10}")


# ===============================================
# Scenario 1: Basic Put and Get (MemTable logic)
# ===============================================
clean_start()
print_separator("Scenario 1: In-Memory Operations")

db = PyRocks(DB_FOLDER)

print("Putting keys 'user1' and 'user2'...")
db.put("user1", "Alice")
db.put("user2", "Bob")

print(f"Get 'user1': {db.get('user1')}")
print(f"Get 'user2': {db.get('user2')}")
print(f"Get 'user3' (non-existent): {db.get('user3')}")
# ===============================================
# Scenario 2: Persistence & Recovery (WAL logic)
# ===============================================
print_separator("Scenario 2: Crash Recovery (WAL)")

print("Simulating a crash (closing DB without flushing)...")
# Note: We haven't hit the threshold (5), so data is only in MemTable and WAL.
del db

print("Restarting Database...")
db_recovered = PyRocks(DB_FOLDER)

print("Reading data after recovery:")
print(f"Get 'user1': {db_recovered.get('user1')}")  # Should exist due to WAL

# ===============================================
# Scenario 3: Flushing to SSTables
# ===============================================
print_separator("Scenario 3: Flushing to Disk (SSTables)")

# We trigger a flush by adding enough keys to hit the threshold (set to 5 in pyrocks.py)
print("Adding more keys to trigger flush...")
db_recovered.put("user3", "Charlie")
db_recovered.put("user4", "Dave")
db_recovered.put("user5", "Eve")
# Threshold hit! Check console output for "[Flush]..."

print("Checking disk for .sst files...")
sst_files = os.listdir(DB_FOLDER)
print("Files in DB folder:", sst_files)
# ===============================================
# Scenario 4: Updates and Tombstones
# ===============================================
print_separator("Scenario 4: Updates & Deletes")

print("Updating 'user1' to 'Alice_Updated'...")
db_recovered.put("user1", "Alice_Updated")

print("Deleting 'user2'...")
db_recovered.delete("user2")

print(f"Get 'user1': {db_recovered.get('user1')}")  # Should be new value
print(f"Get 'user2': {db_recovered.get('user2')}")  # Should be None
# ===============================================
# Scenario 5: Compaction
# ===============================================
print_separator("Scenario 5: Compaction")

# Force another flush to ensure we have multiple SSTables
db_recovered.put("user6", "Frank")
db_recovered.put("user7", "Grace")
db_recovered.put("user8", "Heidi")   # Second flush triggers here (the user1 update and user2 tombstone are already in the MemTable)
db_recovered.put("user9", "Ivan")
db_recovered.put("user10", "Judy")   # user9 and user10 stay in the MemTable (below threshold)

print("Files before compaction:", [f for f in os.listdir(DB_FOLDER) if f.endswith('.sst')])
print("Running Compaction...")
db_recovered.compact()
print("Files after compaction:", [f for f in os.listdir(DB_FOLDER) if f.endswith('.sst')])

# Verify data still exists after compaction
print(f"Get 'user1' (updated): {db_recovered.get('user1')}")
print(f"Get 'user2' (deleted): {db_recovered.get('user2')}")
print(f"Get 'user10' (new): {db_recovered.get('user10')}")

print("\nWorkshop Demo Complete!")
pyrocks.py
#!/usr/bin/env python3
import os
import glob
import json
import struct
import time
import hashlib
from typing import Optional

# ==========================================
# Constants & Configuration
# ==========================================
TOMBSTONE = "__DELETED__"   # Marker for deleted keys
MEMTABLE_THRESHOLD = 5      # Max keys in memory before flushing to disk (kept low for demo)


class BloomFilter:
    """
    A probabilistic data structure to test if an element is NOT in a set.

    '90% Truth': Real DBs use MurmurHash and bit-arrays. We use MD5 and a simple
    integer list to make the logic readable for students.
    """

    def __init__(self, size=100, hash_count=3):
        self.size = size
        self.hash_count = hash_count
        self.bit_array = [0] * size

    def add(self, string):
        for seed in range(self.hash_count):
            result = int(hashlib.md5((string + str(seed)).encode()).hexdigest(), 16)
            self.bit_array[result % self.size] = 1

    def might_contain(self, string):
        for seed in range(self.hash_count):
            result = int(hashlib.md5((string + str(seed)).encode()).hexdigest(), 16)
            if self.bit_array[result % self.size] == 0:
                return False
        return True

    def to_json(self):
        return json.dumps(self.bit_array)

    @staticmethod
    def from_json(json_str):
        bf = BloomFilter()
        bf.bit_array = json.loads(json_str)
        return bf
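
# Usage sketch for BloomFilter (added illustration, not used directly by the demo):
#   bf = BloomFilter()
#   bf.add("user1")
#   bf.might_contain("user1")   # Always True for added keys (no false negatives)
#   bf.might_contain("user99")  # Usually False; a True here would be a false positive
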
class PyRocks:
    """
    A simplified RocksDB clone.

    Architecture:
      1. WAL (Write Ahead Log): Append-only file for durability.
      2. MemTable: In-memory dictionary (sorted at flush time).
      3. SSTables: Immutable disk files (LSM-Tree levels).
    """
    def __init__(self, db_path):
        self.db_path = db_path
        if not os.path.exists(db_path):
            os.makedirs(db_path)

        self.wal_path = os.path.join(db_path, "wal.log")
        self.memtable = {}             # The in-memory data
        self.immutable_memtables = []  # Queue for flushing (unused in this simplified version)

        # Recovery: Replay WAL to restore MemTable on startup
        self._recover_from_wal()

    def get(self, key: str) -> Optional[str]:
        """
        Public API: Read a value.
        Order: MemTable -> SSTables (Newest -> Oldest)
        """
        # 1. Check MemTable (Fastest)
        if key in self.memtable:
            val = self.memtable[key]
            if val == TOMBSTONE:
                return None  # Key was deleted
            return val

        # 2. Check SSTables on disk (Newest first; filenames are millisecond timestamps)
        sst_files = sorted(glob.glob(os.path.join(self.db_path, "*.sst")), reverse=True)
        for sst_file in sst_files:
            value = self._search_sstable(sst_file, key)
            if value is not None:
                if value == TOMBSTONE:
                    return None
                return value

        return None

    def put(self, key: str, value: str):
        """
        Public API: Write a key-value pair.
        1. Write to WAL (Durability)
        2. Write to MemTable (Speed)
        3. Check if MemTable needs flushing
        """
        # 1. Write Ahead Log
        self._append_to_wal(key, value)

        # 2. Update In-Memory Data
        self.memtable[key] = value

        # 3. Check Threshold
        if len(self.memtable) >= MEMTABLE_THRESHOLD:
            print(f"[MemTable] Threshold reached ({len(self.memtable)} keys). Flushing...")
            self.flush()

    def delete(self, key: str):
        """
        Soft Delete: We don't remove data immediately.
        We write a 'Tombstone' record. The space is reclaimed during compaction.
        """
        self.put(key, TOMBSTONE)
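
    # Example (added note): after delete("user2"), the WAL gains the line
    #   {"k": "user2", "v": "__DELETED__"}
    # and the next flush writes that same tombstone value into the SSTable;
    # the key only disappears physically during compact().
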
    def flush(self):
        """
        Freezes the MemTable and writes it to an SSTable on disk.
        """
        if not self.memtable:
            return

        # 1. Create a filename based on timestamp
        filename = f"{int(time.time() * 1000)}.sst"
        filepath = os.path.join(self.db_path, filename)

        # 2. Sort keys (Sorted String Table requirement)
        sorted_keys = sorted(self.memtable.keys())

        # 3. Build Bloom Filter
        bf = BloomFilter()
        for k in sorted_keys:
            bf.add(k)

        # 4. Write to disk
        # File format (in write order):
        # [Index Offset (8B)][BloomFilter Len (4B)][BloomFilter JSON][Data...][Index...]
        with open(filepath, "wb") as f:
            # Write Data Block
            data_offsets = {}  # Keep track of where each key starts

            # Placeholder for the header (we will jump back and write it later):
            # we only know where the Index starts after the data has been written.
            f.write(b'\x00' * 8)  # Reserve 8 bytes for the Index Offset

            # Write Bloom Filter
            bf_json = bf.to_json().encode('utf-8')
            f.write(struct.pack("I", len(bf_json)))  # 4 bytes for length
            f.write(bf_json)

            # Write Key-Values
            for k in sorted_keys:
                data_offsets[k] = f.tell()  # Record current file position
                val = self.memtable[k]
                # Write: len_key(4b), key, len_val(4b), val
                k_bytes = k.encode('utf-8')
                v_bytes = val.encode('utf-8')
                f.write(struct.pack("I", len(k_bytes)))
                f.write(k_bytes)
                f.write(struct.pack("I", len(v_bytes)))
                f.write(v_bytes)

            # Write Index Block (Sparse Index)
            # In a real DB, we might only write every 10th key. Here we write all for simplicity.
            index_start = f.tell()
            for k, offset in data_offsets.items():
                # Write: len_key(4b), key, offset(8b)
                k_bytes = k.encode('utf-8')
                f.write(struct.pack("I", len(k_bytes)))
                f.write(k_bytes)
                f.write(struct.pack("Q", offset))

            # Go back to the beginning and write the Index Offset
            f.seek(0)
            f.write(struct.pack("Q", index_start))

        print(f"[Flush] Wrote {len(sorted_keys)} keys to {filename}")

        # 5. Clear WAL and MemTable
        self.memtable.clear()
        open(self.wal_path, 'w').close()  # Truncate WAL
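
    # Record layout example (added illustration; struct uses native byte order, assumed
    # little-endian here): put("user1", "Alice") is flushed as the bytes
    #   05 00 00 00  75 73 65 72 31  05 00 00 00  41 6c 69 63 65
    #   (len=5)      "user1"         (len=5)      "Alice"
    # which is exactly what the hexdump in Teaching Tip 2 shows.
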
    def compact(self):
        """
        Major Compaction: Merges all SSTables into one.
        Eliminates overwritten keys and Tombstones.
        """
        sst_files = glob.glob(os.path.join(self.db_path, "*.sst"))
        if not sst_files:
            return

        print("[Compaction] Starting major compaction...")
        all_data = {}

        # 1. Read all SSTables (Oldest to Newest).
        # This ensures newer keys overwrite older ones in our dict.
        sst_files.sort(key=os.path.getmtime)
        for sst in sst_files:
            # We could reuse the search logic, but a full scan is simpler,
            # so we use a dedicated full-scan helper.
            self._scan_sstable_into_dict(sst, all_data)

        # 2. Filter out Tombstones
        final_data = {k: v for k, v in all_data.items() if v != TOMBSTONE}

        # 3. Write a new SSTable.
        # Simplification: we temporarily hijack the MemTable and reuse flush() to write
        # the merged file, then restore the old MemTable afterwards. In a real DB,
        # compaction happens in the background and writes the new file directly.
        # Note: flush() also truncates the WAL, so any unflushed MemTable writes lose
        # their WAL entries here (acceptable for this demo).
        old_mem = self.memtable.copy()
        self.memtable = final_data

        # Force a flush to create one big new file
        self.flush()

        # Restore the old MemTable
        self.memtable = old_mem

        # 4. Cleanup old files
        # In real DBs, we only delete after the new one is successfully written
        for sst in sst_files:
            os.remove(sst)

        print("[Compaction] Finished. Old files removed.")
    def _recover_from_wal(self):
        """
        Reads the WAL line-by-line to rebuild the MemTable state.
        """
        if not os.path.exists(self.wal_path):
            return

        print(f"[Recovery] Replaying WAL from {self.wal_path}...")
        with open(self.wal_path, "r") as f:
            for line in f:
                try:
                    # Parse simplified JSON log entries
                    entry = json.loads(line)
                    key, val = entry['k'], entry['v']
                    self.memtable[key] = val
                except (ValueError, KeyError):
                    continue  # Skip corrupted lines

    def _append_to_wal(self, key, value):
        """
        Writes to the append-only log file.
        """
        with open(self.wal_path, "a") as f:
            entry = json.dumps({"k": key, "v": value})
            f.write(entry + "\n")
            f.flush()  # Push to the OS buffers; a real DB would also call os.fsync() for durability
    def _search_sstable(self, filepath, key):
        """
        Helper: Searches a specific SSTable file.
        Uses the Bloom Filter first to avoid reading the file if possible.
        """
        with open(filepath, "rb") as f:
            # Read Index Offset
            index_offset = struct.unpack("Q", f.read(8))[0]

            # Read Bloom Filter
            bf_len = struct.unpack("I", f.read(4))[0]
            bf_json = f.read(bf_len).decode('utf-8')
            bf = BloomFilter.from_json(bf_json)

            # Optimization: Bloom Filter Check
            if not bf.might_contain(key):
                # print(f"  [Bloom] Skipped {filepath} (Key definitely not here)")
                return None

            # Jump to Index
            f.seek(index_offset)

            # Linear Scan of Index (in real DBs, this would be a binary search)
            target_offset = None
            while True:
                # Read Key Length
                chunk = f.read(4)
                if not chunk:
                    break  # End of file
                k_len = struct.unpack("I", chunk)[0]
                k_read = f.read(k_len).decode('utf-8')
                offset = struct.unpack("Q", f.read(8))[0]
                if k_read == key:
                    target_offset = offset
                    break

            if target_offset is None:
                return None  # Key passed the bloom filter (false positive) but is not in the index

            # Retrieve Data
            f.seek(target_offset)
            # Skip the key (we already know it)
            k_len = struct.unpack("I", f.read(4))[0]
            f.read(k_len)
            # Read the Value
            v_len = struct.unpack("I", f.read(4))[0]
            return f.read(v_len).decode('utf-8')
    def _scan_sstable_into_dict(self, filepath, data_dict):
        """Helper to read an entire SSTable into a dict (used by compaction)."""
        with open(filepath, "rb") as f:
            # Header: Index Offset (8 bytes), then Bloom Filter length (4 bytes) + JSON
            index_offset = struct.unpack("Q", f.read(8))[0]
            bf_len = struct.unpack("I", f.read(4))[0]
            f.read(bf_len)  # Skip the Bloom Filter; we are scanning every record anyway

            # Read Data Block records until we reach the Index Block
            while f.tell() < index_offset:
                try:
                    k_len = struct.unpack("I", f.read(4))[0]
                    k = f.read(k_len).decode('utf-8')
                    v_len = struct.unpack("I", f.read(4))[0]
                    v = f.read(v_len).decode('utf-8')
                    data_dict[k] = v
                except struct.error:
                    break