Last active
August 23, 2024 01:22
-
-
Save Alwinfy/b96bec5ed5def163b2c0a3ea3af670d8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.Arrays; | |
import java.util.Optional; | |
import java.util.Map; | |
import java.util.Iterator; | |
import java.util.AbstractMap.SimpleImmutableEntry; | |
public class Hamt<K, V> implements Iterable<Map.Entry<K, V>> {
    /** Key-equality hook; every key comparison in the trie goes through here. */
    static <K> boolean equals(K left, K right) {
        return left.equals(right);
    }
    /** Key-hash hook; every key hash in the trie goes through here. */
    static <K> int hash(K key) {
        return key.hashCode();
    }
    // This is a simple HAMT, see https://github.com/python/cpython/blob/main/Python/hamt.c
    // The only difference is that "collision nodes" are LeafNodes, which can be either
    // SingleNodes or CollisionNodes (the former for better performance, since Java arrays aren't inline).
    //
    // This demo impl uses Hamt.equals() for testing keys. Probably want to swap that for iota equality.
    //
    // All nodes are immutable: mutators return a fresh node (or `this` when nothing changed),
    // structurally sharing unchanged subtrees, so earlier versions of the map remain valid.
    sealed interface HamtNode<K, V> {
        /** Adds (key -> val); `hash` is the not-yet-consumed suffix of the key's hash. */
        HamtNode<K, V> assoc(int hash, K key, V val);
        /** Looks up key; `hash` is the not-yet-consumed suffix of the key's hash. */
        Optional<V> get(int hash, K key);
        /** Removes key; returns {@code null} when the subtree becomes empty, {@code this} when key is absent. */
        HamtNode<K, V> dissoc(int hash, K key);
        /** Number of entries in this subtree (computed by walking it). */
        int size();
    }
    // Dense interior node: a full 32-slot child array, used once there are >16 children.
    // `size` is the number of non-null slots.
    static record ArrayNode<K, V>(int size, HamtNode<K, V>[] children) implements HamtNode<K, V> {
        @Override
        public HamtNode<K, V> assoc(int hash, K key, V val) {
            int next = hash >>> 5; // hash suffix for the child level
            hash &= 0x1f;          // this level's 5-bit slot index
            var child = children[hash];
            if (child != null) {
                var newChild = child.assoc(next, key, val);
                if (newChild == child) {
                    return this; // nothing changed below; keep sharing this node
                }
                var newChildren = Arrays.copyOf(children, children.length);
                newChildren[hash] = newChild;
                return new ArrayNode<>(size, newChildren);
            }
            var newChildren = Arrays.copyOf(children, children.length);
            newChildren[hash] = new SingleNode<>(next, key, val);
            return new ArrayNode<>(size + 1, newChildren);
        }
        @Override
        public Optional<V> get(int hash, K key) {
            int next = hash >>> 5;
            var child = children[hash & 0x1f];
            return child == null ? Optional.empty() : child.get(next, key);
        }
        @Override
        public HamtNode<K, V> dissoc(int hash, K key) {
            int next = hash >>> 5;
            hash &= 0x1f;
            var child = children[hash];
            if (child == null) {
                return this; // slot empty: key not present
            }
            var newChild = child.dissoc(next, key);
            if (newChild == child) {
                return this; // key not present below
            }
            // A slot just emptied and we're at/below the sparse threshold:
            // downgrade to a bitmap HamNode, rebuilding `pop` from the surviving slots.
            if (size <= 16 && newChild == null) {
                int pop = 0, index = 0;
                @SuppressWarnings("unchecked")
                var newChildren = (HamtNode<K, V>[]) new HamtNode<?, ?>[size - 1];
                for (int i = 0; i < children.length; i++) {
                    if (i != hash && children[i] != null) {
                        pop |= 1 << i;
                        newChildren[index++] = children[i];
                    }
                }
                assert (size - 1 == index);
                return new HamNode<>(pop, newChildren);
            }
            var newChildren = Arrays.copyOf(children, children.length);
            newChildren[hash] = newChild;
            return new ArrayNode<>(size - (newChild == null ? 1 : 0), newChildren);
        }
        @Override
        public int size() {
            int count = 0;
            for (int i = 0; i < children.length; i++) {
                if (children[i] != null) {
                    count += children[i].size();
                }
            }
            return count;
        }
        @Override public String toString() { return "A[" + Arrays.toString(children) + "]"; }
    }
    // Sparse interior node (<16 children): `pop` is a 32-bit bitmap of occupied slots;
    // children are stored compactly in slot order (see indexOf).
    static record HamNode<K, V>(int pop, HamtNode<K, V>[] children) implements HamtNode<K, V> {
        @Override
        public HamtNode<K, V> assoc(int hash, K key, V val) {
            int next = hash >>> 5;
            hash &= 0x1f;
            int index = indexOf(pop, hash);
            if (hasHash(pop, hash)) {
                var child = children[index];
                var newChild = child.assoc(next, key, val);
                if (child == newChild) {
                    return this;
                }
                var newChildren = Arrays.copyOf(children, children.length);
                newChildren[index] = newChild;
                return new HamNode<>(pop, newChildren);
            }
            // Adding a 16th child: upgrade to a dense ArrayNode, scattering the
            // compact children out to their bitmap positions.
            if (children.length >= 15) {
                @SuppressWarnings("unchecked")
                var arrayEnts = (HamtNode<K, V>[]) new HamtNode<?, ?>[32];
                int work = pop, inputPos = 0;
                while (work != 0) {
                    int outputPos = Integer.numberOfTrailingZeros(work);
                    work &= work - 1; // clear lowest set bit
                    arrayEnts[outputPos] = children[inputPos++];
                }
                arrayEnts[hash] = new SingleNode<>(next, key, val);
                return new ArrayNode<>(1 + inputPos, arrayEnts);
            }
            @SuppressWarnings("unchecked")
            var newChildren = (HamtNode<K, V>[]) new HamtNode<?, ?>[children.length + 1];
            System.arraycopy(children, 0, newChildren, 0, index);
            System.arraycopy(children, index, newChildren, index + 1, children.length - index);
            newChildren[index] = new SingleNode<>(next, key, val);
            return new HamNode<>(pop | 1 << hash, newChildren);
        }
        @Override
        public Optional<V> get(int hash, K key) {
            return hasHash(pop, hash & 0x1f) ? children[indexOf(pop, hash & 0x1f)].get(hash >>> 5, key) : Optional.empty();
        }
        @Override
        public HamtNode<K, V> dissoc(int hash, K key) {
            int next = hash >>> 5;
            hash &= 0x1f;
            if (!hasHash(pop, hash)) {
                return this; // slot unoccupied: key not present
            }
            int index = indexOf(pop, hash);
            var child = children[index];
            var newChild = child.dissoc(next, key);
            if (child == newChild) {
                return this;
            }
            if (newChild != null) {
                var newChildren = Arrays.copyOf(children, children.length);
                newChildren[index] = newChild;
                return new HamNode<>(pop, newChildren);
            }
            if (children.length == 1) {
                return null; // last child removed: subtree is empty
            }
            int newPop = pop & ~(1 << hash);
            // Down to one child: if it's a leaf, collapse this level by pushing the
            // slot index back onto the leaf's tail hash. (Interior children must stay.)
            if (children.length == 2) {
                int remainingHash = Integer.numberOfTrailingZeros(newPop);
                var childNode = children[indexOf(pop, remainingHash)];
                if (childNode instanceof LeafNode<K, V> ln) {
                    return ln.withNewHash(ln.tailHash() << 5 | remainingHash);
                }
            }
            @SuppressWarnings("unchecked")
            var newChildren = (HamtNode<K, V>[]) new HamtNode<?, ?>[children.length - 1];
            System.arraycopy(children, 0, newChildren, 0, index);
            System.arraycopy(children, index + 1, newChildren, index, children.length - index - 1);
            return new HamNode<>(newPop, newChildren);
        }
        @Override
        public int size() {
            int count = 0;
            for (int i = 0; i < children.length; i++) {
                count += children[i].size();
            }
            return count;
        }
        /** Is 5-bit slot `hash` occupied in bitmap `pop`? */
        static boolean hasHash(int pop, int hash) {
            int offset = 1 << hash;
            return (pop & offset) != 0;
        }
        /** Compact-array index of slot `hash`: the number of occupied slots below it. */
        static int indexOf(int pop, int hash) {
            int offset = 1 << hash;
            return Integer.bitCount(pop & (offset - 1));
        }
        @Override public String toString() { return "H[" + Integer.toString(pop, 2) + ", " + Arrays.toString(children) + "]"; }
    }
    // Leaf node: stores data directly
    sealed interface LeafNode<K, V> extends HamtNode<K, V> {
        /** The "rest of the hash" that was leftover after reaching this point */
        int tailHash();
        /** add a key, val pair to this leaf node directly (hashes already known equal) */
        HamtNode<K, V> doAssoc(K key, V val);
        /** Reconstruct this node with a new "tail hash" */
        LeafNode<K, V> withNewHash(int newHash);
        @Override
        default HamtNode<K, V> assoc(int hash, K key, V val) {
            if (hash == tailHash()) {
                return doAssoc(key, val); // same remaining hash: update or collide
            }
            return assocRecursive(hash, tailHash(), key, val);
        }
        /**
         * Splits two distinct hash suffixes apart, building a chain of single-child
         * HamNodes for each level where the 5-bit chunks still agree, and a two-child
         * HamNode at the first level where they diverge.
         */
        default HamtNode<K, V> assocRecursive(int hash, int tailHash, K key, V val) {
            int nextHash = hash >>> 5;
            int nextTailHash = tailHash >>> 5;
            hash &= 0x1f;
            tailHash &= 0x1f;
            if (hash == tailHash) {
                @SuppressWarnings("unchecked")
                var child = (HamtNode<K, V>[]) new HamtNode<?, ?>[] {assocRecursive(nextHash, nextTailHash, key, val)};
                return new HamNode<>(1 << hash, child);
            }
            var existingNode = withNewHash(nextTailHash);
            var newNode = new SingleNode<>(nextHash, key, val);
            // Keep the compact child array in slot order.
            var left = hash < tailHash ? newNode : existingNode;
            var right = hash < tailHash ? existingNode : newNode;
            @SuppressWarnings("unchecked")
            var child = (HamtNode<K, V>[]) new HamtNode<?, ?>[] {left, right};
            return new HamNode<>(1 << hash | 1 << tailHash, child);
        }
        /** given an int [0..size()), return the key/val pair */
        Map.Entry<K, V> fetch(int index);
    }
    /** Leaf holding exactly one entry. */
    static record SingleNode<K, V>(@Override int tailHash, K key, V value) implements LeafNode<K, V> {
        @Override
        public SingleNode<K, V> withNewHash(int newHash) {
            return new SingleNode<>(newHash, key, value);
        }
        @Override
        public HamtNode<K, V> doAssoc(K key, V val) {
            // Same key: replace the value; different key with the same hash: collide.
            return Hamt.equals(this.key, key) ? new SingleNode<>(tailHash, key, val) :
                new CollisionNode<>(tailHash, new Object[] {this.key, this.value, key, val});
        }
        @Override
        public Optional<V> get(int hash, K key) {
            if (tailHash == hash && Hamt.equals(this.key, key)) {
                return Optional.of(value);
            }
            return Optional.empty();
        }
        @Override
        public HamtNode<K, V> dissoc(int hash, K key) {
            if (tailHash == hash && Hamt.equals(this.key, key)) {
                return null; // removing the only entry empties this leaf
            }
            return this;
        }
        @Override public int size() { return 1; }
        @Override public Map.Entry<K, V> fetch(int index) { return new SimpleImmutableEntry<>(key, value); }
    }
    // Leaf holding >=2 entries whose remaining hashes are identical.
    // Keys and vals stored as adjacent objects :/
    // unsexy and untypeful but it is what it is
    static record CollisionNode<K, V>(@Override int tailHash, Object[] entries) implements LeafNode<K, V> {
        @Override
        public CollisionNode<K, V> withNewHash(int newHash) {
            return new CollisionNode<>(newHash, entries);
        }
        @Override
        public HamtNode<K, V> doAssoc(K key, V val) {
            for (int i = 0; i < entries.length; i += 2) {
                @SuppressWarnings("unchecked")
                var thisKey = (K) entries[i];
                if (Hamt.equals(thisKey, key)) {
                    var newVals = Arrays.copyOf(entries, entries.length);
                    newVals[i + 1] = val; // existing key: overwrite value in place
                    return new CollisionNode<>(tailHash, newVals);
                }
            }
            var newVals = Arrays.copyOf(entries, entries.length + 2);
            newVals[entries.length] = key;
            newVals[entries.length + 1] = val;
            return new CollisionNode<>(tailHash, newVals);
        }
        @Override
        public Optional<V> get(int hash, K key) {
            if (hash == tailHash) {
                for (int i = 0; i < entries.length; i += 2) {
                    @SuppressWarnings("unchecked")
                    var thisKey = (K) entries[i];
                    if (Hamt.equals(thisKey, key)) {
                        @SuppressWarnings("unchecked")
                        var value = (V) entries[i + 1];
                        return Optional.of(value);
                    }
                }
            }
            return Optional.empty();
        }
        @Override
        public HamtNode<K, V> dissoc(int hash, K key) {
            if (hash == tailHash) {
                for (int i = 0; i < entries.length; i += 2) {
                    @SuppressWarnings("unchecked")
                    var thisKey = (K) entries[i];
                    if (Hamt.equals(thisKey, key)) {
                        // Down to one entry: shrink back to a SingleNode. With
                        // length 4, i is 0 or 2, so i^2 / i^3 pick the other pair.
                        if (entries.length == 4) {
                            @SuppressWarnings("unchecked")
                            var otherKey = (K) entries[i ^ 2];
                            @SuppressWarnings("unchecked")
                            var otherVal = (V) entries[i ^ 3];
                            return new SingleNode<>(tailHash, otherKey, otherVal);
                        }
                        var newEntries = new Object[entries.length - 2];
                        System.arraycopy(entries, 0, newEntries, 0, i);
                        System.arraycopy(entries, i + 2, newEntries, i, entries.length - i - 2);
                        return new CollisionNode<>(tailHash, newEntries);
                    }
                }
            }
            return this;
        }
        @Override public int size() { return entries.length / 2; }
        @SuppressWarnings("unchecked")
        @Override public Map.Entry<K, V> fetch(int index) {
            return new SimpleImmutableEntry<>((K) entries[2 * index], (V) entries[2 * index + 1]);
        }
    }
    // Root of the trie; null means the map is empty.
    private final HamtNode<K, V> root;
    private Hamt(HamtNode<K, V> root) {
        this.root = root;
    }
    /** An empty persistent map. */
    public static<K, V> Hamt<K, V> empty() {
        return new Hamt<>(null);
    }
    /** Returns a new map with (key -> value) added or replaced; this map is unchanged. */
    public Hamt<K, V> assoc(K key, V value) {
        return new Hamt<>(root != null ? root.assoc(Hamt.hash(key), key, value) : new SingleNode<>(Hamt.hash(key), key, value));
    }
    /** Returns a new map without key; this map is unchanged. */
    public Hamt<K, V> dissoc(K key) {
        return new Hamt<>(root == null ? null : root.dissoc(Hamt.hash(key), key));
    }
    /** Looks up key; empty when absent. */
    public Optional<V> get(K key) {
        return root == null ? Optional.empty() : root.get(Hamt.hash(key), key);
    }
    /** Number of entries (walks the trie; O(n)). */
    public int size() {
        return root == null ? 0 : root.size();
    }
    /** Returns a new map with every entry of `values` added in iteration order. */
    public Hamt<K, V> assocAll(Iterable<Map.Entry<? extends K, ? extends V>> values) {
        var out = this;
        for (var entry : values) {
            out = out.assoc(entry.getKey(), entry.getValue());
        }
        return out;
    }
    static<K, V> Hamt<K, V> ofIterable(Iterable<Map.Entry<? extends K, ? extends V>> values) {
        return Hamt.<K, V>empty().assocAll(values);
    }
    HamtNode<K, V> root() { return root; }
    // Iterator over a HAMT, via an explicit stack of (node, next-child) frames.
    @Override
    public Iterator<Map.Entry<K, V>> iterator() {
        return new Iterator<Map.Entry<K, V>>() {
            private int depth = 0;
            // BUGFIX: stacks sized 8, not 7. A 32-bit hash is consumed 5 bits per
            // interior level, so the trie can have ceil(32/5) = 7 interior levels
            // (depths 0..6) with a leaf pushed at depth 7; 7-slot stacks overflowed
            // on keys whose hashes agree in their low 30 bits.
            @SuppressWarnings("unchecked")
            private HamtNode<K, V>[] nodeStack = (HamtNode<K, V>[])new HamtNode<?, ?>[8];
            private int[] progressStack = new int[8];
            private LeafNode<K, V> leafNode; // leaf currently being drained
            private int leafProgress;        // entries remaining in leafNode (counts down)
            {
                if (root == null) {
                    depth = -1; // empty map: exhausted immediately
                } else {
                    nodeStack[0] = root;
                    advance();
                }
            }
            @Override
            public boolean hasNext() {
                return leafProgress > 0 || depth >= 0;
            }
            public Map.Entry<K, V> next() {
                var value = leafNode.fetch(--leafProgress);
                if (leafProgress == 0) {
                    advance(); // leaf drained: find the next leaf
                }
                return value;
            }
            /** Depth-first walk to the next leaf; sets leafNode/leafProgress or exhausts (depth < 0). */
            void advance() {
                loop: while (depth >= 0) switch (nodeStack[depth]) {
                    case ArrayNode<K, V> arr -> {
                        // Dense node: scan for the next non-null slot.
                        for (int i = progressStack[depth]; i < arr.children().length; i++) {
                            if (arr.children()[i] != null) {
                                progressStack[depth] = i + 1;
                                nodeStack[++depth] = arr.children()[i];
                                progressStack[depth] = 0;
                                continue loop;
                            }
                        }
                        depth--; // slots exhausted: pop
                    }
                    case HamNode<K, V> ham -> {
                        // Sparse node: children array is compact, no nulls to skip.
                        int i = progressStack[depth]++;
                        if (i < ham.children().length) {
                            nodeStack[++depth] = ham.children()[i];
                            progressStack[depth] = 0;
                            continue loop;
                        }
                        depth--;
                    }
                    case LeafNode<K, V> leaf -> {
                        leafNode = leaf;
                        leafProgress = leaf.size();
                        depth--; // leaf frame is consumed via leafNode/leafProgress
                        return;
                    }
                }
            }
        };
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.HashMap; | |
import java.util.Objects; | |
import java.util.Random; | |
public class HamtTest { | |
public Hamt<Integer, Integer> node = Hamt.empty(); | |
public HashMap<Integer, Integer> stress = new HashMap<>(); | |
static final int[] pool; | |
static { | |
var poolIn = new int[5000]; | |
var random = new Random(); | |
for (int i = 0; i < poolIn.length; i++) { | |
poolIn[i] = random.nextInt(); | |
} | |
pool = poolIn; | |
} | |
public void verify(int iter) { | |
for (int i = 0; i < pool.length; i++) { | |
var one = node.get(pool[i]).orElse(null); | |
var two = stress.get(pool[i]); | |
if (!Objects.equals(one, two)) { | |
throw new IllegalStateException("Iter " + iter + ": mismatch for key " + pool[i] + ": expected " + two + ", got " + one); | |
} | |
} | |
int size = node.size(), computedSize = 0; | |
for (var entry : node) { | |
computedSize++; | |
if (!entry.getValue().equals(stress.get(entry.getKey()))) { | |
throw new IllegalStateException("Iter " + iter + ": entry mismatch"); | |
} | |
} | |
if (size != computedSize) { | |
throw new IllegalStateException("Iter " + iter + ": mismatch computed size: want " + computedSize + ", got " + size); | |
} | |
if (size != stress.size()) { | |
throw new IllegalStateException("Iter " + iter + ": mismatch size: want " + stress.size() + ", got " + size); | |
} | |
} | |
public void doMain() { | |
try { | |
for (int i = 1000; i < 3000; i++) { | |
node = node.assoc(pool[i], i); | |
stress.put(pool[i], i); | |
verify(i); | |
} | |
for (int i = 0000; i < 2000; i++) { | |
node = node.dissoc(pool[i]); | |
stress.remove(pool[i]); | |
verify(i); | |
} | |
{ | |
var node2 = node; | |
var stress2 = new HashMap<>(stress); | |
for (int i = 2000; i < 3000; i++) { | |
node = node.dissoc(pool[i]); | |
stress.remove(pool[i]); | |
verify(i); | |
} | |
if (node.root() != null) { | |
throw new IllegalStateException("teardown not perfect: " + node.root()); | |
} | |
stress = stress2; | |
node = node2; | |
} | |
var random = new Random(); | |
for (int i = 0; i < 10000; i++) { | |
int j = pool[random.nextInt(5000)]; | |
int val = random.nextInt(); | |
node = node.assoc(j, val); | |
stress.put(j, val); | |
verify(i); | |
int k = pool[random.nextInt(5000)]; | |
node = node.dissoc(k); | |
stress.remove(k); | |
verify(i); | |
} | |
} catch (Exception exn) { | |
//System.out.println(stress); | |
//System.out.println(node.root()); | |
throw exn; | |
} | |
} | |
public static void main(String[] args) { | |
new HamtTest().doMain(); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment