Last active
September 15, 2023 09:24
-
-
Save rakeshopensource/0c7f20a8a7725da386b662a72da5f132 to your computer and use it in GitHub Desktop.
Explore how NoSQL databases work with a simple Java simulation of LSM trees. It shows how data is stored and retrieved, all running in memory for easy learning and experimentation. Get a hands-on understanding of LSM trees without the complexity of disk storage.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.rakeshopensource.systemdesign; | |
import java.util.*; | |
import java.util.function.Function; | |
/**
 * In-memory write buffer holding the most recent key/value pairs.
 *
 * <p>Backed by a {@link HashMap}, so iteration order is unspecified and a later {@code put}
 * for the same key replaces the earlier value.
 *
 * @param <K> key type
 * @param <V> value type
 */
class MemTable<K, V> {
    private final Map<K, V> data;

    public MemTable() {
        data = new HashMap<>();
    }

    /** Stores {@code value} under {@code key}, replacing any previous value. */
    public void put(K key, V value) {
        data.put(key, value);
    }

    /** Returns the value stored for {@code key}, or {@code null} if there is none. */
    public V get(K key) {
        return data.get(key);
    }

    /**
     * Returns a read-only, live view of the buffered entries.
     *
     * <p>The view is unmodifiable so callers cannot alter the buffer behind the table's back.
     * All writes must go through {@link #put}.
     */
    public Map<K, V> getData() {
        return Collections.unmodifiableMap(data);
    }
}
class SSTable<K, V> { | |
private final SortedMap<K, V> data; | |
private final BloomFilter<K> bloomFilter; // Each ssTable will have its own Bloom filter | |
public SSTable(Map<K, V> data, BloomFilter<K> bloomFilter) { | |
this.data = new TreeMap<>(data); | |
this.bloomFilter = bloomFilter; | |
} | |
public SSTable(List<SSTable<K, V>> sstables, BloomFilter<K> bloomFilter ) { | |
this.data = new TreeMap<>(); | |
this.bloomFilter = bloomFilter; | |
for (SSTable<K, V> sstable : sstables) { | |
this.data.putAll(sstable.getData()); | |
sstable.getData().keySet().forEach(this.bloomFilter::add); | |
} | |
} | |
public V get(K key) { | |
return data.get(key); | |
} | |
public Map<K, V> getData() { | |
return data; | |
} | |
public BloomFilter<K> getBloomFilter() { | |
return bloomFilter; | |
} | |
} | |
/**
 * Fixed-size Bloom filter over a {@link BitSet}.
 *
 * <p>{@link #mightContain} never returns {@code false} for an element that was {@link #add}ed,
 * but may return {@code true} for one that was not (false positive).
 *
 * @param <T> element type
 */
class BloomFilter<T> {
    private final int size;
    private final BitSet bitSet;
    private final List<Function<T, Integer>> hashFunctions;

    /**
     * Creates an empty filter.
     *
     * @param size number of bits; each hash result is reduced modulo this value
     * @param hashFunctions hash functions applied to every element; the list is copied
     * @throws IllegalArgumentException if {@code size} is not positive
     * @throws NullPointerException if {@code hashFunctions} is null
     */
    public BloomFilter(int size, List<Function<T, Integer>> hashFunctions) {
        if (size <= 0) {
            throw new IllegalArgumentException("size must be positive: " + size);
        }
        this.size = size;
        this.bitSet = new BitSet(size);
        this.hashFunctions = new ArrayList<>(Objects.requireNonNull(hashFunctions, "hashFunctions"));
    }

    /** Sets the bit selected by each hash function for {@code element}. */
    public void add(T element) {
        for (Function<T, Integer> hashFunction : hashFunctions) {
            bitSet.set(indexFor(hashFunction, element));
        }
    }

    /**
     * Returns {@code false} if {@code element} was definitely never added, and {@code true}
     * if it possibly was.
     */
    public boolean mightContain(T element) {
        for (Function<T, Integer> hashFunction : hashFunctions) {
            if (!bitSet.get(indexFor(hashFunction, element))) {
                return false;
            }
        }
        return true;
    }

    /** Maps one hash of {@code element} to a bit index in {@code [0, size)}. */
    private int indexFor(Function<T, Integer> hashFunction, T element) {
        // Math.abs(h) % size is negative when h == Integer.MIN_VALUE (abs overflows), which
        // made BitSet throw IndexOutOfBoundsException; floorMod always yields [0, size).
        return Math.floorMod(hashFunction.apply(element), size);
    }
}
class LSMTree<K, V> { | |
private MemTable<K, V> memTable; | |
private List<SSTable<K, V>> sstables; | |
private final int maxMemTableThreashold; | |
private final int sstableThreshold; | |
private final List<Function<K, Integer>> hashFunctions; | |
private final static int BLOOM_FILTER_SIZE = 640; // multiple of 64 | |
public LSMTree(int maxMemTableSize, int sstableThreshold, List<Function<K, Integer>> hashFunctions) { | |
this.memTable = new MemTable<>(); | |
this.sstables = new ArrayList<>(); | |
this.maxMemTableThreashold = maxMemTableSize; | |
this.sstableThreshold = sstableThreshold; | |
this.hashFunctions = hashFunctions; | |
} | |
public void put(K key, V value) { | |
memTable.put(key, value); | |
if (memTable.getData().size() >= maxMemTableThreashold) { | |
compact(); | |
} | |
} | |
public V get(K key) { | |
V value = memTable.get(key); | |
if (value != null) { | |
return value; | |
} | |
//Start scan from latest to old. Use bloom filter first | |
for (int i = sstables.size() - 1; i >= 0; i--) { | |
SSTable<K, V> sstable = sstables.get(i); | |
if(sstable.getBloomFilter().mightContain(key)) { | |
value = sstable.get(key); | |
if (value != null) { | |
return value; | |
} | |
} | |
} | |
return null; | |
} | |
private void compact() { | |
BloomFilter<K> newBloomFilter = new BloomFilter<>(BLOOM_FILTER_SIZE, hashFunctions); | |
sstables.add(new SSTable<>(memTable.getData(), newBloomFilter)); | |
if (sstables.size() >= sstableThreshold) { | |
List<SSTable<K, V>> newSSTables = compactSSTables(sstables); | |
sstables.clear(); // clear all sstable | |
sstables.addAll(newSSTables); | |
} | |
memTable = new MemTable<>(); // clean up after dumping ssTables to disk | |
} | |
private List<SSTable<K, V>> compactSSTables(List<SSTable<K, V>> sstables) { | |
BloomFilter<K> newBloomFilter = new BloomFilter<>(BLOOM_FILTER_SIZE, hashFunctions); | |
SSTable<K, V> compactedSSTable = new SSTable<>(sstables, newBloomFilter); | |
List<SSTable<K, V>> newSSTables = new ArrayList<>(); | |
newSSTables.add(compactedSSTable); | |
return newSSTables; | |
} | |
} | |
public class LSMTreeDriver { | |
public static void main(String[] args) { | |
List<Function<String, Integer>> hashFunctions = List.of( | |
String::hashCode, | |
String::length, | |
s -> s.hashCode() * 31 | |
); | |
LSMTree<String, String> lsmTree = new LSMTree<>(2, 2, hashFunctions); | |
lsmTree.put("1", "one"); | |
lsmTree.put("2", "two"); | |
lsmTree.put("3", "three"); | |
lsmTree.put("4", "four"); | |
lsmTree.put("5", "five"); | |
System.out.println(lsmTree.get("1")); // Output: one | |
System.out.println(lsmTree.get("6")); // Output: null | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment