// lru_cache.h (gist by @grahamwren, April 1, 2020)
#pragma once

#include <algorithm>
#include <cassert>
#include <functional>
#include <list>
#include <memory>
#include <unordered_map>

using namespace std;

template <typename K, typename V, int cache_size = 32> class LRUCache {
  static_assert(cache_size > 0, "Cache must be at least one element");

public:
  typedef K key_t;
  typedef V value_t;
  typedef function<unique_ptr<value_t>(const key_t &)> fetch_fun_t;

  /**
   * fetch a value via the cache. Returns a weak_ptr because the cache could
   * evict and delete the value at any time.
   *
   * if the key is already in the cache
   *   - fn is not executed, a weak_ptr to the value is returned out of the
   *     cache, and the key is promoted to first in the cache
   *
   * else
   *   - fn is executed to retrieve the required value, the resulting value is
   *     stored in the cache, the key is set as the first in the cache, and a
   *     weak_ptr to it is returned
   */
  weak_ptr<value_t> fetch(const key_t &k, fetch_fun_t fn) {
    if (has(k))
      return get_from_cache(k);
    else
      return add_to_cache(k, fn(k));
  }

private:
  mutable list<reference_wrapper<const key_t>> keys;
  unordered_map<key_t, shared_ptr<value_t>> cache;

  /**
   * returns whether the cache contains the requested key
   */
  bool has(const key_t &k) const { return cache.find(k) != cache.end(); }

  /**
   * get a weak_ptr to a value stored in the cache; assumes the key is in the
   * cache
   */
  weak_ptr<value_t> get_from_cache(const key_t &key) {
    auto it = find_if(keys.begin(), keys.end(),
                      [&key](auto k) { return k.get() == key; });
    assert(it != keys.end()); // assert found value
    if (it != keys.begin()) {
      // move key to front of list
      const key_t &move_key = it->get();
      keys.erase(it);
      keys.push_front(move_key);
    }
    return weak_ptr<value_t>(cache.at(key));
  }

  /**
   * adds the passed value to the cache as a shared_ptr and returns a weak_ptr
   * to that object
   */
  weak_ptr<value_t> add_to_cache(const key_t &key, shared_ptr<value_t> val) {
    assert(!has(key));                   // assert does not yet have key
    assert(keys.size() == cache.size()); // assert sizes are in sync
    if (keys.size() >= cache_size) {
      /* evict the least-recently-used key */
      const key_t &key_to_evict = keys.back();
      keys.pop_back();
      cache.erase(key_to_evict);
    }
    const auto &[it, res] = cache.emplace(key, move(val)); // copies key
    assert(res);                // assert was created
    keys.push_front(it->first); // push reference to key stored in the map
    return weak_ptr<value_t>(it->second);
  }
};
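
// Usage sketch (illustrative, not part of the original gist): the loader is
// only invoked on a miss, and the returned weak_ptr must be lock()ed before
// use because a later insertion may evict and free the value. The key/value
// types and the cache size here are arbitrary.
//
//   #include "lru_cache.h"
//   #include <string>
//
//   LRUCache<int, string, 4> names;
//   auto ref = names.fetch(7, [](const int &k) {
//     return unique_ptr<string>(new string(to_string(k)));
//   });
//   if (auto sp = ref.lock()) {
//     // *sp is valid while sp holds a strong reference
//   }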
// test file for lru_cache.h (Google Test)
#include "lru_cache.h"
#include <cstdio>
#include <gtest/gtest.h>
#include <memory>
#include <string>

using namespace std;

static constexpr int CSIZ = 128;

TEST(TestLRUCache, test_fetch) {
  LRUCache<int, string, CSIZ> scache;
  int alloc_count = 0;
  int last_alloc_count = alloc_count; // 0
  auto i_to_s = [&](const int k) {
    alloc_count++; // captured by reference, counts loader invocations
    char buf[10];
    sprintf(buf, "%d", k);
    return unique_ptr<string>(new string(buf));
  };
  char buf[BUFSIZ];

  /* fill cache up to cache_size */
  for (int i = 0; i < CSIZ; i++) {
    auto s_ptr = scache.fetch(i, i_to_s).lock();
    sprintf(buf, "%d", i);
    EXPECT_STREQ(s_ptr->c_str(), buf);
  }
  /* expect as many allocations as are in cache */
  EXPECT_EQ(alloc_count, last_alloc_count + CSIZ);
  last_alloc_count = alloc_count;
  /* cache order now: [ CSIZ - 1, CSIZ - 2, ... , 0 ] */

  /* ensure alloc_count does not change for same keys within cache_size */
  for (int i = 0; i < CSIZ; i++) {
    auto s_ptr = scache.fetch(i, i_to_s).lock();
    sprintf(buf, "%d", i);
    EXPECT_STREQ(s_ptr->c_str(), buf);
  }
  /* expect alloc_count to still be the same */
  EXPECT_EQ(alloc_count, last_alloc_count);
  /* cache order unchanged: promoting 0..CSIZ-1 in ascending order leaves
     [ CSIZ - 1, ... , 0 ] again */

  /* ensure alloc_count increases when past cache_size */
  for (int i = CSIZ; i < CSIZ + 2; i++) {
    auto s_ptr = scache.fetch(i, i_to_s).lock();
    sprintf(buf, "%d", i);
    EXPECT_STREQ(s_ptr->c_str(), buf);
  }
  /* expect an extra two allocations */
  EXPECT_EQ(alloc_count, last_alloc_count + 2);
  last_alloc_count = alloc_count;
  /* cache order now: [ CSIZ + 1, CSIZ, CSIZ - 1, ... , 2 ] */

  /* ensure first two elements were evicted from cache */
  for (int i = 0; i < 2; i++) {
    auto s_ptr = scache.fetch(i, i_to_s).lock();
    sprintf(buf, "%d", i);
    EXPECT_STREQ(s_ptr->c_str(), buf);
  }
  /* expect two more allocations to account for evicted elements */
  EXPECT_EQ(alloc_count, last_alloc_count + 2);
  last_alloc_count = alloc_count;
  /* cache order now: [ 1, 0, CSIZ + 1, CSIZ, ... , 4 ] */

  /* promote two earlier elements to front of cache */
  for (int i = CSIZ; i < CSIZ + 2; i++) {
    auto s_ptr = scache.fetch(i, i_to_s).lock();
    sprintf(buf, "%d", i);
    EXPECT_STREQ(s_ptr->c_str(), buf);
  }
  /* expect alloc_count unchanged */
  EXPECT_EQ(alloc_count, last_alloc_count);
  /* cache order now: [ CSIZ + 1, CSIZ, 1, 0, ... , 4 ] */

  const int NS = CSIZ * 2; // new start, guaranteed greater than any key so far
  /* fill cache with new values except last two */
  for (int i = 0; i < CSIZ - 2; i++) {
    auto s_ptr = scache.fetch(NS + i, i_to_s).lock();
    sprintf(buf, "%d", NS + i);
    EXPECT_STREQ(s_ptr->c_str(), buf);
  }
  /* expect CSIZ - 2 more allocations */
  EXPECT_EQ(alloc_count, last_alloc_count + (CSIZ - 2));
  last_alloc_count = alloc_count;
  /* cache order now: [ NS + CSIZ - 3, ... , NS, CSIZ + 1, CSIZ ] */

  /* request last two values which should still be in the cache */
  for (int i = CSIZ; i < CSIZ + 2; i++) {
    auto s_ptr = scache.fetch(i, i_to_s).lock();
    sprintf(buf, "%d", i);
    EXPECT_STREQ(s_ptr->c_str(), buf);
  }
  /* expect alloc_count to be unchanged */
  EXPECT_EQ(alloc_count, last_alloc_count);
}
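
/* A possible follow-up test (a sketch, not in the original gist): the header's
 * doc comment promises that a weak_ptr may expire once its key is evicted, but
 * the test above never checks expiry directly. This uses a tiny cache so
 * eviction is easy to force; the test name is hypothetical. */
TEST(TestLRUCache, test_weak_ptr_expires_on_eviction) {
  LRUCache<int, string, 2> scache;
  auto i_to_s = [](const int k) {
    char buf[10];
    sprintf(buf, "%d", k);
    return unique_ptr<string>(new string(buf));
  };
  weak_ptr<string> first = scache.fetch(0, i_to_s);
  EXPECT_FALSE(first.expired()); // key 0 is still resident
  scache.fetch(1, i_to_s);       // cache now full: [ 1, 0 ]
  scache.fetch(2, i_to_s);       // evicts key 0, the least recently used
  EXPECT_TRUE(first.expired());  // its value was freed on eviction
  EXPECT_FALSE(scache.fetch(1, i_to_s).expired()); // key 1 survived
}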