Last active
July 3, 2022 12:59
-
-
Save glaebhoerl/d62d2b19365ae0d7c29102d0a5a6ab03 to your computer and use it in GitHub Desktop.
Rust hash table with efficient support for nested scopes (save/restore)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Based on idea: https://twitter.com/pkhuong/status/1287510400372748290 | |
use hashbrown::raw::RawTable; | |
pub struct ScopeMap<K, V> { | |
last_scope_id: ScopeId, | |
scopes: Vec<ScopeId>, // values are zeroed instead of popped to save a check in get() / is_fresh() | |
current_scope: ScopeDepth, // index of innermost valid scope | |
values: RawTable<Entry<K, V>>, | |
shadowed: Vec<Shadowed<K, V>>, | |
hash_builder: hashbrown::hash_map::DefaultHashBuilder | |
} | |
/// Identifier of one scope instance. Ids start at 1 and count up with each entered scope;
/// `ScopeId(0)` is written into depth slots that are no longer live.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
struct ScopeId(u32);
/// Nesting depth of a scope (outermost = 0); also used as an index into `ScopeMap::scopes`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
struct ScopeDepth(u16);
struct Entry<K, V> { | |
key: K, | |
value: V, | |
scope_id: ScopeId, | |
scope_depth: ScopeDepth | |
} | |
struct Shadowed<K, V> { | |
key: K, | |
value: V, | |
original_scope_depth: ScopeDepth, | |
shadowing_bucket_index: usize // raw table index where it will need to be put back | |
} | |
impl<K, V> ScopeMap<K, V> { | |
fn scope_id(&self, depth: ScopeDepth) -> ScopeId { self.scopes[depth.0 as usize] } | |
fn is_fresh(&self, entry: &Entry<K, V>) -> bool { self.scope_id(entry.scope_depth) == entry.scope_id } | |
pub fn new() -> ScopeMap<K, V> { | |
ScopeMap { | |
last_scope_id: ScopeId(1), | |
scopes: vec![ScopeId(1)], | |
current_scope: ScopeDepth(0), | |
values: RawTable::new(), | |
shadowed: Vec::new(), | |
hash_builder: Default::default() | |
} | |
} | |
// Saves the current state of the map. | |
pub fn enter_scope(&mut self) { | |
assert!(self.current_scope.0 < u16::MAX); | |
self.last_scope_id.0 += 1; | |
self.current_scope.0 += 1; | |
if self.scopes.len() == self.current_scope.0 as usize { | |
self.scopes.push(self.last_scope_id); | |
} else { | |
self.scopes[self.current_scope.0 as usize] = self.last_scope_id; | |
} | |
} | |
// Restores the map to a saved state (LIFO). | |
pub fn exit_scope(&mut self) { | |
assert!(self.current_scope.0 > 0); | |
while self.shadowed.last().map(|shadowed| unsafe { self.values.bucket(shadowed.shadowing_bucket_index).as_ref().scope_depth }) == Some(self.current_scope) { | |
let Shadowed { key, value, original_scope_depth, shadowing_bucket_index } = self.shadowed.pop().unwrap(); | |
*unsafe { self.values.bucket(shadowing_bucket_index).as_mut() } = Entry { | |
key, | |
value, | |
scope_depth: original_scope_depth, | |
scope_id: self.scope_id(original_scope_depth) | |
}; | |
} | |
self.scopes[self.current_scope.0 as usize] = ScopeId(0); | |
self.current_scope.0 -= 1; | |
} | |
pub fn current_scope(&mut self) -> Scope<'_, K, V> { | |
Scope { map: self, scopes_to_pop: 0 } | |
} | |
pub fn new_scope(&mut self) -> Scope<'_, K, V> { | |
self.enter_scope(); | |
return Scope { map: self, scopes_to_pop: 1 }; | |
} | |
} | |
fn make_hash<K: std::hash::Hash>(hash_builder: &hashbrown::hash_map::DefaultHashBuilder, key: &K) -> u64 { | |
use std::hash::{Hasher, BuildHasher}; | |
let mut state = hash_builder.build_hasher(); | |
key.hash(&mut state); | |
return state.finish(); | |
} | |
impl<K: Eq + std::hash::Hash, V> ScopeMap<K, V> { | |
fn make_hash(&self, key: &K) -> u64 { make_hash(&self.hash_builder, key) } | |
pub fn get(&self, key: &K) -> Option<&V> { | |
let entry = unsafe { self.values.find(self.make_hash(key), |e| *key == e.key)?.as_ref() }; | |
if self.is_fresh(entry) { | |
return Some(&entry.value); | |
} else { | |
return None; | |
} | |
} | |
pub fn get_mut(&mut self, key: &K) -> Option<&mut V> { | |
let bucket = self.values.find(self.make_hash(key), |e| *key == e.key)?; | |
if self.is_fresh(unsafe { bucket.as_ref() }) { | |
return Some(unsafe { &mut bucket.as_mut().value }); | |
} else { | |
return None; | |
} | |
} | |
fn check_shadowed(&mut self, entry: Entry<K, V>, shadowing_bucket: hashbrown::raw::Bucket<Entry<K, V>>) -> Option<V> { | |
if self.is_fresh(&entry) { | |
if entry.scope_depth == self.current_scope { | |
return Some(entry.value); | |
} else { | |
debug_assert!(entry.scope_depth < self.current_scope); | |
self.shadowed.push(Shadowed { | |
key: entry.key, | |
value: entry.value, | |
original_scope_depth: entry.scope_depth, | |
// convert bucket to index just for better odds of remaining in stacked borrows's good graces... | |
shadowing_bucket_index: unsafe { self.values.bucket_index(&shadowing_bucket) } | |
}); | |
} | |
} | |
return None; | |
} | |
// In the case of shadowing *within the same scope*, the shadowed value is returned. | |
pub fn insert(&mut self, key: K, value: V) -> Option<V> { | |
let hash = self.make_hash(&key); | |
let new_entry = Entry { | |
key, | |
value, | |
scope_id: self.scope_id(self.current_scope), | |
scope_depth: self.current_scope | |
}; | |
let mut hits = unsafe { self.values.iter_hash(hash) }; | |
while let Some(bucket) = hits.next() { | |
let existing_entry = unsafe { bucket.as_mut() }; | |
if new_entry.key == existing_entry.key { | |
// Case 1: The first thing we find is an entry with a matching key. | |
// Replace it, then check whether it was stale or if we're shadowing it. | |
return self.check_shadowed(std::mem::replace(existing_entry, new_entry), bucket); | |
} else if !self.is_fresh(existing_entry) { | |
// Case 2: The first thing we find is a stale entry with a colliding hash. Overwrite it... | |
*existing_entry = new_entry; | |
// ...but we still need to check the remaining hits to see if a matching key also exists, which we then need to remove! | |
while let Some(other_bucket) = hits.next() { | |
// (At this point `existing_entry` is the one we've just inserted.) | |
if existing_entry.key == unsafe { other_bucket.as_ref() }.key { | |
let old_entry = unsafe { self.values.remove(other_bucket) }; | |
// (the shadowing bucket is `bucket`, not `other_bucket`, which we just removed!) | |
return self.check_shadowed(old_entry, bucket); | |
} | |
} | |
return None; | |
} | |
} | |
// Case 3: The key doesn't exist, nor does a stale entry with a colliding hash. We have to insert a new one. | |
if self.values.len() == self.values.capacity() { | |
// But first, if the table is full, sweep out any stale entries before growing it. | |
self.sweep_and_grow(); | |
// Note that since we're checking this ourselves, we lose out on the small additional optimization | |
// `hashbrown` does where it avoids growing the map when it's overwriting a tombstone. | |
// This is incidentally very similar to our own optimization when we overwrite a stale entry, above. | |
// (It provides no external API to recover this behavior.) | |
} | |
// Since we already handled the out-of-space condition ourselves just above, `insert()` itself will definitely not need to grow. | |
self.values.insert_no_grow(hash, new_entry); | |
return None; | |
} | |
#[cold] #[inline(never)] | |
fn sweep_and_grow(&mut self) { | |
// Sweep out stale entries | |
unsafe { | |
for item in self.values.iter() { | |
if !self.is_fresh(item.as_ref()) { | |
self.values.erase(item); | |
} | |
} | |
// Conceivably we could keep track of how many stale vs. non-stale entries are actually in the map, | |
// and use that to determine whether sweeping them is going to be worthwhile. | |
// But that would burden the "mutator", whereas "collection" is already rare & costly, so I lean against it. | |
} | |
// Unless that succeeded in freeing up more than half the map, grow it nonetheless to avoid thrashing | |
if self.values.len() >= self.values.capacity() / 2 { | |
let hash_builder = &self.hash_builder; | |
self.values.reserve(self.values.capacity() + 1, |e| make_hash(hash_builder, &e.key)); | |
// That invalidated the direct bucket indices we've been storing in shadowed entries, so we need to fix them manually: | |
for shadowed in &mut self.shadowed { | |
shadowed.shadowing_bucket_index = unsafe { self.values.bucket_index(&self.values.find(make_hash(hash_builder, &shadowed.key), |e| e.key == shadowed.key).unwrap()) }; | |
// Maybe it would be worthwhile to cache the hash in `Shadowed` to avoid recalculating it, | |
// but there's a similar tradeoff w.r.t. caching it in the table itself which `hashbrown` doesn't, | |
// so then again maybe not. | |
} | |
} | |
// `hashbrown` itself also does a very similar thing inside `reserve()` for clearing out | |
// tombstones ("DELETED" entries) and then deciding whether to grow or not. | |
// Maybe it would be better if the two checks could be combined, but alas, there is no API. | |
// Probably it doesn't matter very much though. | |
} | |
} | |
// What happens in a Scope, stays in the Scope. But it can see what's outside. | |
pub struct Scope<'a, K, V> { | |
map: &'a mut ScopeMap<K, V>, | |
scopes_to_pop: usize | |
} | |
impl<'a, K, V> Drop for Scope<'a, K, V> { | |
fn drop(&mut self) { | |
for _ in 0..self.scopes_to_pop { | |
self.map.exit_scope(); | |
} | |
} | |
} | |
impl<'a, K, V> Scope<'a, K, V> { | |
pub fn new_scope(&mut self) -> Scope<'_, K, V> { | |
self.map.enter_scope(); | |
return Scope { map: self.map, scopes_to_pop: 1 }; | |
} | |
pub fn borrow(&mut self) -> Scope<'_, K, V> { | |
Scope { map: self.map, scopes_to_pop: 0 } | |
} | |
pub fn enter_scope(&mut self) { // this is fine, it's only `exit_scope()` which would violate expectations. but is this useful?? | |
self.map.enter_scope(); | |
self.scopes_to_pop += 1; | |
} | |
} | |
impl<'a, K: Eq + std::hash::Hash, V> Scope<'a, K, V> { | |
// In the case of shadowing *within the same scope*, the shadowed value is returned. | |
pub fn insert(&mut self, key: K, value: V) -> Option<V> { | |
self.map.insert(key, value) | |
} | |
pub fn get(&self, key: &K) -> Option<&V> { | |
self.map.get(key) | |
} | |
pub fn get_mut(&mut self, key: &K) -> Option<&mut V> { | |
self.map.get_mut(key) | |
} | |
} | |
// Randomized differential tests: every operation is applied both to `ScopeMap` and to a naive
// reference model (`Simple`), and the two results are asserted equal.
#[cfg(test)]
mod test {
    use super::ScopeMap;
    use hashbrown::HashMap;
    use std::hash::Hash;
    // Reference model: one `HashMap` per scope; lookups search from the innermost scope outward.
    struct Simple<K, V> { scopes: Vec<HashMap<K, V>> }
    impl<K: Hash + Eq, V> Simple<K, V> {
        fn new() -> Simple<K, V> { Simple { scopes: vec![HashMap::new()] } }
        fn enter_scope(&mut self) { self.scopes.push(HashMap::new()) }
        fn exit_scope(&mut self) { self.scopes.pop(); }
        // Returns a previous value only if it was bound in the innermost scope.
        fn insert(&mut self, k: K, v: V) -> Option<V> { self.scopes.last_mut().unwrap().insert(k, v) }
        fn get(&self, k: &K) -> Option<&V> { self.scopes.iter().rev().find_map(|m| m.get(k)) }
    }
    // Pairs the reference model (`control`) with the map under test (`subject`).
    struct Test<K, V> {
        control: Simple<K, V>,
        subject: ScopeMap<K, V>
    }
    impl<K: Hash + Eq + Clone, V: Eq + Clone> Test<K, V> {
        fn new() -> Test<K, V> {
            Test {
                control: Simple::new(),
                subject: ScopeMap::new()
            }
        }
        fn enter_scope(&mut self) {
            self.control.enter_scope();
            self.subject.enter_scope();
        }
        fn exit_scope(&mut self) {
            self.control.exit_scope();
            self.subject.exit_scope();
        }
        // Inserts into both and asserts they report the same displaced value.
        fn insert(&mut self, k: K, v: V) -> Option<V> {
            let a = self.control.insert(k.clone(), v.clone());
            let b = self.subject.insert(k, v);
            assert!(a == b);
            return b;
        }
        // Looks up in both and asserts the results agree.
        fn get(&self, k: &K) -> Option<&V> {
            let a = self.control.get(k);
            let b = self.subject.get(k);
            assert!(a == b);
            return b;
        }
    }
    // Runs `num_ops` pseudo-random operations over the keys `0..=255`, prints a summary, then
    // exits every open scope one at a time, checking all keys after each exit.
    fn test<K, V>(num_ops: u64)
        where K: Hash + Eq + Clone + From<u8> + std::fmt::Display,
              V: Eq + Clone + From<u64> + std::fmt::Display
    {
        let mut test = Test::<K, V>::new();
        // Scopes currently open beyond the outermost one (inferred as `u16` from the
        // comparison with `current_scope` after the loop).
        let mut scopes = 0;
        // PRNG state: each step hashes the previous value.
        let mut random = 0;
        // NOTE(review): the hasher's seeding is up to hashbrown's default builder, so the
        // operation sequence may differ between runs — confirm if reproducibility matters.
        let hash_builder = Default::default();
        let keys = (0..=255u8).map(Into::into).collect::<Vec<K>>(); // no impl for [K; 256] :[
        for _ in 0..num_ops {
            random = super::make_hash(&hash_builder, &random);
            if random % 20 == 0 {
                // Leave a scope, or enter one if none is open.
                if scopes > 0 {
                    scopes -= 1;
                    test.exit_scope();
                } else {
                    scopes += 1;
                    test.enter_scope();
                }
            } else if random % 10 == 0 {
                // Enter a scope, or leave one at the maximum depth (`u16::MAX`).
                if scopes < 65535 {
                    scopes += 1;
                    test.enter_scope();
                } else {
                    scopes -= 1;
                    test.exit_scope();
                }
            } else if random % 256 == 0 {
                // Occasionally cross-check every key.
                for k in &keys {
                    test.get(k);
                }
            } else {
                // Otherwise bind key `random as u8` to the value `random`.
                test.insert((random as u8).into(), random.into());
            }
        }
        assert!(scopes == test.subject.current_scope.0);
        // Diagnostic output: sizes of the model and of the subject's internal structures.
        println!();
        println!(
            "scopes: {}, values: {}, entries: {}, shadowed: {}",
            scopes,
            test.control.scopes.iter().map(|m| m.len()).sum::<usize>(),
            test.subject.values.len(),
            test.subject.shadowed.len()
        );
        for k in &keys {
            if let Some(v) = test.get(k) {
                print!("{}=>{} ", k, v);
            }
        }
        println!();
        // Unwind all open scopes, verifying the restored bindings after each exit.
        while scopes > 0 {
            scopes -= 1;
            test.exit_scope();
            for k in &keys {
                test.get(k);
            }
        }
    }
    // The four instantiations cover plain and heap-allocated keys/values; Miri runs use far
    // fewer operations.
    #[test]
    fn test_unboxed_unboxed() {
        test::<u8, u64>(
            if cfg!(miri) { 256 } else { 1048576 }
        );
    }
    #[test]
    fn test_unboxed_boxed() {
        test::<u8, Box<u64>>(
            if cfg!(miri) { 256 } else { 1048576 }
        );
    }
    #[test]
    fn test_boxed_unboxed() {
        test::<Box<u8>, u64>(
            if cfg!(miri) { 256 } else { 1048576 }
        );
    }
    #[test]
    fn test_boxed_boxed() {
        test::<Box<u8>, Box<u64>>(
            if cfg!(miri) { 256 } else { 1048576 }
        );
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.