Created October 21, 2018 18:36
-
-
Save felix-d/69878611b7e04159eb20be33d6706da2 to your computer and use it in GitHub Desktop.
LRU Cache in Rust
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#![feature(box_syntax)] | |
#![feature(box_into_raw_non_null)] | |
#![feature(nll)] | |
use std::collections::HashMap; | |
use std::hash::Hash; | |
use std::ptr::NonNull; | |
use std::cell::RefCell; | |
/// A single cache entry: heap-allocated and threaded into the cache's
/// doubly linked recency list.
#[derive(Debug)]
struct Node<K, V> {
    /// Key this entry is stored under; kept so eviction can also remove the
    /// entry from the lookup map.
    key: K,
    /// The cached value.
    value: V,
    /// Neighbour one step closer to the tail (less recently used).
    prev: Option<NonNull<Node<K, V>>>,
    /// Neighbour one step closer to the head (more recently used).
    next: Option<NonNull<Node<K, V>>>,
}
impl<K, V> Node<K, V> { | |
fn new(key: K, value: V) -> Box<Self> { | |
box Node { | |
key, | |
value, | |
prev: None, | |
next: None, | |
} | |
} | |
fn into_non_null(&mut self) -> NonNull<Node<K, V>> { | |
NonNull::new(&mut*self).unwrap() | |
} | |
// ptr must point to a valid Node<K, V>. | |
unsafe fn from_non_null(ptr: NonNull<Node<K, V>>) -> Box<Self> { | |
Box::from_raw(ptr.as_ptr()) | |
} | |
} | |
#[derive(Debug)] | |
pub struct LRUCache<K, V> | |
where | |
K: Hash + Eq + Clone, | |
{ | |
map: RefCell<HashMap<K, NonNull<Node<K, V>>>>, | |
head: RefCell<Option<NonNull<Node<K, V>>>>, | |
tail: RefCell<Option<NonNull<Node<K, V>>>>, | |
max_size: usize, | |
} | |
impl<K, V> LRUCache<K, V> | |
where | |
K: Hash + Eq + Clone, | |
{ | |
pub fn new(max_size: usize) -> Self { | |
LRUCache { | |
max_size, | |
map: RefCell::new(HashMap::new()), | |
head: RefCell::new(None), | |
tail: RefCell::new(None), | |
} | |
} | |
pub fn write(&mut self, key: K, value: V) { | |
self.delete(&key); | |
let mut node = Node::new(key.clone(), value); | |
self.bump(&mut node); | |
if self.len() >= self.max_size { | |
self.pop(); | |
} | |
self.map.borrow_mut().insert(key, Box::into_raw_non_null(node)); | |
} | |
pub fn read(&self, key: &K) -> Option<&V> { | |
let node = self.map.borrow_mut().get_mut(key).map(|node| node.as_ptr())?; | |
let node = unsafe { &mut*node }; | |
self.unlink(node); | |
self.bump(node); | |
Some(&node.value) | |
} | |
pub fn len(&self) -> usize { | |
self.map.borrow().len() | |
} | |
pub fn delete(&mut self, key: &K) -> Option<V> { | |
let node_ptr = self.map.borrow_mut().remove(&key)?; | |
let mut node = unsafe { Node::from_non_null(node_ptr) }; | |
self.unlink(&mut node); | |
if *self.head.borrow() == Some(node_ptr) { | |
self.unshift(); | |
} | |
if *self.tail.borrow() == Some(node_ptr) { | |
self.dequeue(); | |
} | |
Some(node.value) | |
} | |
pub fn clear(&self) { | |
*self.head.borrow_mut() = None; | |
*self.tail.borrow_mut() = None; | |
self.map.borrow_mut().drain().for_each(|(_, node)| { | |
unsafe { Node::from_non_null(node) }; | |
}); | |
} | |
fn bump(&self, node: &mut Node<K, V>) { | |
let node_ptr = node.into_non_null(); | |
if *self.head.borrow() != Some(node_ptr) { | |
self.make_head(node); | |
} | |
if self.tail.borrow().is_none() { | |
*self.tail.borrow_mut() = Some(node_ptr); | |
} else if *self.tail.borrow() == Some(node_ptr) && self.len() > 1 { | |
self.dequeue(); | |
} | |
} | |
fn pop(&self) { | |
if let Some(node) = self.dequeue() { | |
let node = unsafe { Node::from_non_null(node) }; | |
self.map.borrow_mut().remove(&node.key); | |
} | |
} | |
fn unshift(&self) -> Option<()> { | |
let head = (*self.head.borrow())?; | |
let head = unsafe { head.as_ref() }; | |
if let Some(mut prev) = head.prev { | |
let prev = unsafe { prev.as_mut() }; | |
prev.next = None; | |
} | |
*self.head.borrow_mut() = head.prev; | |
Some(()) | |
} | |
fn dequeue(&self) -> Option<NonNull<Node<K, V>>> { | |
let tail_ptr = (*self.tail.borrow_mut())?; | |
let tail = unsafe { tail_ptr.as_ref() }; | |
if let Some(mut next) = tail.next { | |
let next = unsafe { next.as_mut() }; | |
next.prev = None; | |
} | |
*self.tail.borrow_mut() = tail.next; | |
Some(tail_ptr) | |
} | |
fn make_head(&self, node: &mut Node<K, V>) { | |
node.prev = *self.head.borrow(); | |
let node_ptr = node.into_non_null(); | |
if let Some(mut head) = *self.head.borrow() { | |
let head = unsafe { head.as_mut() }; | |
head.next = Some(node_ptr); | |
} | |
*self.head.borrow_mut() = Some(node_ptr); | |
} | |
fn unlink(&self, node: &Node<K, V>) { | |
if let Some(mut prev) = node.prev { | |
let prev = unsafe { prev.as_mut() }; | |
prev.next = node.next; | |
} | |
if let Some(mut next) = node.next { | |
let next = unsafe { next.as_mut() }; | |
next.prev = node.prev; | |
} | |
} | |
} | |
impl<K, V> Drop for LRUCache<K, V> | |
where K: Eq + Hash + Clone | |
{ | |
fn drop(&mut self) { | |
self.clear(); | |
} | |
} | |
#[cfg(test)]
mod tests {
    use super::LRUCache;

    #[test]
    fn sets_max_size() {
        let cache: LRUCache<(), ()> = LRUCache::new(10);
        assert_eq!(10, cache.max_size);
    }

    #[test]
    fn reads_are_consistent() {
        let mut cache = LRUCache::new(2);
        cache.write("foo", "bar");
        assert_eq!(Some(&"bar"), cache.read(&"foo"));
        assert_eq!(Some(&"bar"), cache.read(&"foo"));
    }

    #[test]
    fn read_returns_none_when_key_is_not_found() {
        let cache: LRUCache<&str, ()> = LRUCache::new(2);
        assert_eq!(None, cache.read(&"hello"));
    }

    #[test]
    fn cache_has_max_length() {
        let mut cache: LRUCache<&str, &str> = LRUCache::new(2);
        cache.write("foo", "bar");
        cache.write("fooo", "baar");
        cache.write("foooo", "baaar");
        assert_eq!(2, cache.len());
        assert_eq!(None, cache.read(&"foo"));
        assert_eq!(Some(&"baar"), cache.read(&"fooo"));
        assert_eq!(Some(&"baaar"), cache.read(&"foooo"));
    }

    #[test]
    fn read_bumps_the_value() {
        let mut cache: LRUCache<&str, &str> = LRUCache::new(2);
        cache.write("foo", "bar");
        cache.write("fooo", "baar");
        cache.read(&"foo");
        cache.write("foooo", "baaar");
        assert_eq!(2, cache.len());
        assert_eq!(Some(&"bar"), cache.read(&"foo"));
        assert_eq!(None, cache.read(&"fooo"));
        assert_eq!(Some(&"baaar"), cache.read(&"foooo"));
    }

    #[test]
    fn delete_removes_the_value() {
        let mut cache: LRUCache<&str, &str> = LRUCache::new(2);
        cache.write("foo", "bar");
        assert_eq!(Some("bar"), cache.delete(&"foo"));
        // Check the deleted key itself: reading a key that was never written
        // would pass even if `delete` did nothing.
        assert_eq!(None, cache.read(&"foo"));
        assert_eq!(0, cache.len());
    }

    #[test]
    fn delete_of_missing_key_returns_none() {
        let mut cache: LRUCache<&str, &str> = LRUCache::new(2);
        cache.write("foo", "bar");
        assert_eq!(None, cache.delete(&"nope"));
        assert_eq!(1, cache.len());
    }

    #[test]
    fn write_replaces_existing_value() {
        let mut cache: LRUCache<&str, &str> = LRUCache::new(2);
        cache.write("foo", "bar");
        cache.write("foo", "baz");
        assert_eq!(1, cache.len());
        assert_eq!(Some(&"baz"), cache.read(&"foo"));
    }

    #[test]
    fn len_returns_the_length() {
        let mut cache: LRUCache<&str, &str> = LRUCache::new(2);
        cache.write("foo", "bar");
        assert_eq!(1, cache.len());
    }
}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment