Skip to content

Instantly share code, notes, and snippets.

@rakeshopensource
Last active September 15, 2023 09:24
Show Gist options
  • Save rakeshopensource/0c7f20a8a7725da386b662a72da5f132 to your computer and use it in GitHub Desktop.
Save rakeshopensource/0c7f20a8a7725da386b662a72da5f132 to your computer and use it in GitHub Desktop.
Explore how NoSQL databases work with a simple Java simulation of LSM trees. It shows how data is stored and retrieved, all running in memory for easy learning and experimentation. Get a hands-on understanding of LSM trees without the complexity of disk storage.
package org.rakeshopensource.systemdesign;
import java.util.*;
import java.util.function.Function;
/**
 * Mutable in-memory write buffer backed by a {@link HashMap}.
 *
 * <p>Entries are unordered; a later {@code put} for the same key replaces the earlier value.
 */
class MemTable<K, V> {
    private final Map<K, V> entries = new HashMap<>();

    /** Creates an empty memtable. */
    public MemTable() {
    }

    /** Stores {@code value} under {@code key}, replacing any existing mapping. */
    public void put(K key, V value) {
        entries.put(key, value);
    }

    /** Returns the value stored for {@code key}, or {@code null} if there is none (or it is {@code null}). */
    public V get(K key) {
        return entries.get(key);
    }

    /** Returns the live backing map (not a copy); changes made through it are visible to this memtable. */
    public Map<K, V> getData() {
        return entries;
    }
}
/**
 * Sorted table of entries, paired with a Bloom filter over its keys. Entries cannot be changed
 * after construction.
 *
 * <p>Entries live in a {@link TreeMap} using natural ordering, so keys must be non-null and
 * mutually comparable.
 */
class SSTable<K, V> {
    private final SortedMap<K, V> data;
    private final BloomFilter<K> bloomFilter; // Each SSTable has its own Bloom filter over its keys

    /**
     * Creates an SSTable holding a sorted copy of {@code data} and records every key in
     * {@code bloomFilter}.
     *
     * @param data        entries to copy; later changes to this map do not affect the table
     * @param bloomFilter filter to populate with this table's keys
     * @throws NullPointerException if {@code bloomFilter} is null, or if {@code data} contains a null key
     */
    public SSTable(Map<K, V> data, BloomFilter<K> bloomFilter) {
        this.data = new TreeMap<>(data);
        this.bloomFilter = Objects.requireNonNull(bloomFilter, "bloomFilter");
        // The filter must know every key: readers skip a table whenever mightContain() says no,
        // so an empty filter would hide this table's entries.
        this.data.keySet().forEach(this.bloomFilter::add);
    }

    /**
     * Merges {@code sstables} into a single table and records every resulting key in
     * {@code bloomFilter}.
     *
     * <p>Tables are applied in list order, so when a key appears in several tables the value from
     * the last one wins; pass them oldest first.
     *
     * @param sstables    tables to merge, oldest first
     * @param bloomFilter filter to populate with the merged table's keys
     * @throws NullPointerException if {@code bloomFilter} is null
     */
    public SSTable(List<SSTable<K, V>> sstables, BloomFilter<K> bloomFilter) {
        this.data = new TreeMap<>();
        this.bloomFilter = Objects.requireNonNull(bloomFilter, "bloomFilter");
        for (SSTable<K, V> sstable : sstables) {
            this.data.putAll(sstable.getData());
        }
        this.data.keySet().forEach(this.bloomFilter::add);
    }

    /** Returns the value stored for {@code key}, or {@code null} if absent (or mapped to {@code null}). */
    public V get(K key) {
        return data.get(key);
    }

    /** Returns a read-only, key-sorted view of this table's entries. */
    public Map<K, V> getData() {
        return Collections.unmodifiableSortedMap(data);
    }

    /** Returns the Bloom filter holding this table's keys. */
    public BloomFilter<K> getBloomFilter() {
        return bloomFilter;
    }
}
/**
 * Fixed-size Bloom filter: a bit set of {@code size} bits plus a list of hash functions.
 *
 * <p>{@link #add} sets one bit per hash function; {@link #mightContain} reports {@code true} only
 * if all of those bits are set. It can therefore return false positives but never false negatives
 * for elements that were added. With an empty hash-function list, {@code mightContain} always
 * returns {@code true}.
 */
class BloomFilter<T> {
    private final int size;
    private final BitSet bitSet;
    private final List<Function<T, Integer>> hashFunctions;

    /**
     * @param size          number of bits in the filter; must be positive
     * @param hashFunctions functions mapping an element to an int; each contributes one bit per element
     * @throws IllegalArgumentException if {@code size} is not positive
     * @throws NullPointerException     if {@code hashFunctions} is null
     */
    public BloomFilter(int size, List<Function<T, Integer>> hashFunctions) {
        if (size <= 0) {
            throw new IllegalArgumentException("size must be positive: " + size);
        }
        this.size = size;
        this.bitSet = new BitSet(size);
        this.hashFunctions = Objects.requireNonNull(hashFunctions, "hashFunctions");
    }

    /** Records {@code element} by setting the bit chosen by each hash function. */
    public void add(T element) {
        for (Function<T, Integer> hashFunction : hashFunctions) {
            bitSet.set(indexFor(hashFunction, element));
        }
    }

    /**
     * Returns {@code false} if {@code element} was definitely never added, {@code true} if it may
     * have been.
     */
    public boolean mightContain(T element) {
        for (Function<T, Integer> hashFunction : hashFunctions) {
            if (!bitSet.get(indexFor(hashFunction, element))) {
                return false;
            }
        }
        return true;
    }

    /** Maps the element's hash under {@code hashFunction} to a bit index in {@code [0, size)}. */
    private int indexFor(Function<T, Integer> hashFunction, T element) {
        // Reduce before taking abs(): Math.abs(Integer.MIN_VALUE) overflows and stays negative, so
        // abs-then-mod could yield a negative index. For all other hashes both orders agree.
        return Math.abs(hashFunction.apply(element) % size);
    }
}
/**
 * In-memory simulation of a log-structured merge (LSM) tree.
 *
 * <p>Writes go into a {@link MemTable}. When it reaches {@code maxMemTableSize} entries it is
 * flushed into a new {@link SSTable}; when {@code sstableThreshold} SSTables have accumulated they
 * are merged into one. Reads consult the memtable first, then SSTables from newest to oldest,
 * checking each table's Bloom filter before its data.
 *
 * <p>Keys must be non-null and mutually comparable, because SSTables keep them in a
 * {@link TreeMap}. A {@code null} value is stored like any other value and shadows older values
 * for the same key. The class has no internal synchronization.
 */
class LSMTree<K, V> {
    /** Minimum Bloom filter size in bits (kept a multiple of 64). */
    private static final int BLOOM_FILTER_SIZE = 640;
    /** Bits reserved per expected key, so filters of large merged tables do not fill up. */
    private static final int BLOOM_BITS_PER_KEY = 10;

    private MemTable<K, V> memTable;
    /** SSTables ordered oldest first; the newest flush is always appended at the end. */
    private final List<SSTable<K, V>> sstables = new ArrayList<>();
    private final int maxMemTableThreshold;
    private final int sstableThreshold;
    private final List<Function<K, Integer>> hashFunctions;

    /**
     * @param maxMemTableSize  number of memtable entries that triggers a flush to an SSTable
     * @param sstableThreshold number of SSTables that triggers merging them into one
     * @param hashFunctions    hash functions used by every Bloom filter this tree creates
     * @throws NullPointerException if {@code hashFunctions} is null
     */
    public LSMTree(int maxMemTableSize, int sstableThreshold, List<Function<K, Integer>> hashFunctions) {
        this.memTable = new MemTable<>();
        this.maxMemTableThreshold = maxMemTableSize;
        this.sstableThreshold = sstableThreshold;
        this.hashFunctions = Objects.requireNonNull(hashFunctions, "hashFunctions");
    }

    /**
     * Inserts or overwrites {@code key}, flushing the memtable once it reaches the size threshold.
     *
     * @throws NullPointerException if {@code key} is null; SSTables cannot hold null keys, so
     *     accepting one here would make the next flush fail
     */
    public void put(K key, V value) {
        Objects.requireNonNull(key, "key");
        memTable.put(key, value);
        if (memTable.getData().size() >= maxMemTableThreshold) {
            flush();
        }
    }

    /**
     * Returns the most recently written value for {@code key}, or {@code null} if it was never
     * written.
     *
     * <p>The memtable is checked first, then SSTables from newest to oldest. The first place that
     * actually holds the key wins, even when its value is {@code null}, so an older value can never
     * resurface.
     */
    public V get(K key) {
        V value = memTable.get(key);
        if (value != null || memTable.getData().containsKey(key)) {
            return value;
        }
        for (int i = sstables.size() - 1; i >= 0; i--) {
            SSTable<K, V> sstable = sstables.get(i);
            // A filter populated with all of its table's keys has no false negatives,
            // so a "no" lets us skip this table entirely.
            if (!sstable.getBloomFilter().mightContain(key)) {
                continue;
            }
            value = sstable.get(key);
            if (value != null || sstable.getData().containsKey(key)) {
                return value;
            }
        }
        return null;
    }

    /** Turns the current memtable into a new SSTable, merging SSTables when there are enough of them. */
    private void flush() {
        Map<K, V> entries = memTable.getData();
        sstables.add(new SSTable<>(entries, newBloomFilter(entries.size())));
        if (sstables.size() >= sstableThreshold) {
            SSTable<K, V> merged = compactSSTables(sstables);
            sstables.clear();
            sstables.add(merged);
        }
        // Everything stays in memory; the flushed entries now live only in sstables.
        memTable = new MemTable<>();
    }

    /** Merges {@code tables} (oldest first) into one SSTable in which newer values override older ones. */
    private SSTable<K, V> compactSSTables(List<SSTable<K, V>> tables) {
        int keyCount = 0;
        for (SSTable<K, V> table : tables) {
            keyCount += table.getData().size(); // upper bound: keys repeated across tables count twice
        }
        return new SSTable<>(tables, newBloomFilter(keyCount));
    }

    /** Creates an empty Bloom filter sized for about {@code expectedKeys} keys, never below the minimum size. */
    private BloomFilter<K> newBloomFilter(int expectedKeys) {
        long bits = Math.max(BLOOM_FILTER_SIZE, (long) expectedKeys * BLOOM_BITS_PER_KEY);
        return new BloomFilter<>((int) Math.min(bits, Integer.MAX_VALUE), hashFunctions);
    }
}
public class LSMTreeDriver {
    /**
     * Demo entry point: builds a small tree (flush after 2 entries, merge after 2 SSTables),
     * writes keys "1" through "5", then prints the lookup for "1" ("one") and for the
     * never-written key "6" ("null").
     */
    public static void main(String[] args) {
        List<Function<String, Integer>> hashers = List.of(
                String::hashCode,
                String::length,
                s -> s.hashCode() * 31);
        LSMTree<String, String> tree = new LSMTree<>(2, 2, hashers);

        String[] words = {"one", "two", "three", "four", "five"};
        for (int i = 0; i < words.length; i++) {
            tree.put(String.valueOf(i + 1), words[i]);
        }

        System.out.println(tree.get("1")); // Output: one
        System.out.println(tree.get("6")); // Output: null
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment