Created
April 1, 2020 21:27
-
-
Save grahamwren/360030a38dd4fb27bda0ec0ed1181d09 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once
#include <algorithm>
#include <cassert>
#include <functional>
#include <list>
#include <memory>
#include <unordered_map>
using namespace std;
/**
 * Fixed-capacity least-recently-used cache mapping K -> V.
 *
 * Values are produced on demand by a caller-supplied fetch function and are
 * owned by the cache through shared_ptr. Callers only receive weak_ptr handles,
 * because any later cache miss may evict (and destroy) a value.
 *
 * Recency is tracked by a list of keys (front = most recently used). Each map
 * entry stores an iterator to its key's list node, so a hit is promoted with
 * list::splice and the least-recently-used entry is found at list.back(). That
 * makes lookup, promotion and eviction O(1) on average instead of a linear scan.
 *
 * K must be usable as a std::unordered_map key and copyable: one copy lives in
 * the recency list and one in the map. The cache does no internal locking.
 */
template <typename K, typename V, int cache_size = 32> class LRUCache {
  static_assert(cache_size > 0, "Cache must be at least one element");

public:
  typedef K key_t;
  typedef V value_t;
  typedef std::function<std::unique_ptr<value_t>(const key_t &)> fetch_fun_t;

private:
  // Recency order of keys, most recently used first.
  typedef std::list<key_t> order_t;

  // What the map stores per key: the owned value plus the key's position in
  // the recency list (so it can be moved without searching).
  struct Entry {
    std::shared_ptr<value_t> value;
    typename order_t::iterator order_pos;
  };

  typedef std::unordered_map<key_t, Entry> map_t;

  // cache_size as the list's size type, avoiding a signed/unsigned comparison.
  static constexpr typename order_t::size_type capacity_ =
      static_cast<typename order_t::size_type>(cache_size);

public:
  LRUCache() = default;

  /**
   * Copies share the cached values (the shared_ptrs are copied). The recency
   * bookkeeping is rebuilt so every stored iterator points into this cache's
   * own list: a member-wise copy would keep iterators into `other`.
   */
  LRUCache(const LRUCache &other) : order_(other.order_) {
    entries_.reserve(order_.size());
    for (auto it = order_.begin(); it != order_.end(); ++it)
      entries_.emplace(*it, Entry{other.entries_.at(*it).value, it});
  }

  /** Takes over `other`'s contents; swapping containers keeps iterators valid. */
  LRUCache(LRUCache &&other) : LRUCache() { swap_state(other); }

  /** Copy-and-swap assignment; also serves as move assignment. */
  LRUCache &operator=(LRUCache other) {
    swap_state(other);
    return *this;
  }

  ~LRUCache() = default;

  /**
   * fetch value via the cache. Returns a weak_ptr because the cache could evict
   * and delete the value at any time.
   *
   * if key already in the cache
   * - fn is not executed, a weak_ptr to the cached value is returned, and the
   *   key is promoted to first (most recently used) in the cache
   *
   * else
   * - fn is executed to produce the value, the result is stored in the cache
   *   with the key set as first, and a weak_ptr to it is returned. If the
   *   cache was full, the least recently used entry is evicted. If fn returns
   *   an empty unique_ptr, an empty value is cached and the returned weak_ptr
   *   locks to nullptr.
   *
   * fn is called before the cache is modified, so an exception from fn leaves
   * the cache unchanged.
   */
  std::weak_ptr<value_t> fetch(const key_t &k, fetch_fun_t fn) {
    auto pos = entries_.find(k); // single lookup for both hit and miss
    if (pos != entries_.end())
      return get_from_cache(pos);
    return add_to_cache(k, fn(k));
  }

private:
  order_t order_; // declared before entries_: the copy ctor relies on this
  map_t entries_;

  /**
   * returns whether the cache contains the requested key
   */
  bool has(const key_t &k) const { return entries_.count(k) != 0; }

  /**
   * promote the entry at `pos` to most recently used and return a weak_ptr to
   * its value. `pos` must be a valid iterator into entries_.
   */
  std::weak_ptr<value_t> get_from_cache(typename map_t::iterator pos) {
    // splice relinks the existing node: no allocation, and the stored
    // iterator keeps referring to the same (now front) element.
    order_.splice(order_.begin(), order_, pos->second.order_pos);
    return pos->second.value;
  }

  /**
   * add `val` under `key` (which must not already be cached) as the most
   * recently used entry, evicting the least recently used entry if that puts
   * the cache over capacity. Returns a weak_ptr to the stored value.
   */
  std::weak_ptr<value_t> add_to_cache(const key_t &key,
                                      std::shared_ptr<value_t> val) {
    assert(!has(key));                        // assert does not yet have key
    assert(order_.size() == entries_.size()); // assert sizes are in sync
    // Build the new list node off to the side: if allocating it, copying the
    // key, or inserting into the map throws, the cache is left untouched.
    order_t node;
    node.push_back(key);
    auto pos = entries_.emplace(key, Entry{std::move(val), node.begin()}).first;
    // splice cannot fail, and the iterator stored above now points into order_.
    order_.splice(order_.begin(), node);
    if (order_.size() > capacity_)
      evict_lru();
    return pos->second.value;
  }

  /**
   * drop the least recently used entry (the back of the recency list).
   */
  void evict_lru() {
    // The lookup key lives in the list node, not in the map node being
    // erased, so erase(key) does not alias the element it destroys.
    entries_.erase(order_.back());
    order_.pop_back();
  }

  /**
   * exchange all state with `other`. Container swap never invalidates element
   * iterators, so the stored list iterators stay correct on both sides.
   */
  void swap_state(LRUCache &other) {
    order_.swap(other.order_);
    entries_.swap(other.entries_);
  }
};
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once | |
#include "lru_cache.h" | |
#include <cstring> | |
#include <string> | |
using namespace std; | |
// Capacity (cache_size template argument) of the LRUCache exercised below.
static constexpr int CSIZ{128};
TEST(TestLRUCache, test_fetch) {
  LRUCache<int, string, CSIZ> scache;
  int alloc_count = 0;
  int last_alloc_count = alloc_count; // 0
  // Fetch function handed to the cache; alloc_count counts cache misses.
  auto i_to_s = [&](const int k) {
    alloc_count++; // captured by reference
    // to_string sizes its own storage; the old sprintf into char[10] could
    // overflow for negative or 10-digit ints.
    return make_unique<string>(to_string(k));
  };
  // Fetch every key in [first, last) through the cache and check that each
  // returned value is the decimal string of its key.
  auto fetch_range = [&](int first, int last) {
    for (int i = first; i < last; i++) {
      auto s_ptr = scache.fetch(i, i_to_s).lock();
      ASSERT_TRUE(s_ptr != nullptr);
      EXPECT_EQ(*s_ptr, to_string(i));
    }
  };

  /* fill cache up to cache_size */
  fetch_range(0, CSIZ);
  /* expect as many allocations as are in cache */
  EXPECT_EQ(alloc_count, last_alloc_count + CSIZ);
  last_alloc_count = alloc_count;
  /* cache order now: [ CSIZ - 1, CSIZ - 2, ... , 0 ] */

  /* ensure alloc_count does not change for same keys within cache_size */
  fetch_range(0, CSIZ);
  /* expect alloc_count to still be the same */
  EXPECT_EQ(alloc_count, last_alloc_count);
  /* keys were re-promoted in ascending order, so the order is unchanged */

  /* ensure alloc_count increases when past cache_size */
  fetch_range(CSIZ, CSIZ + 2);
  /* expect an extra two allocations */
  EXPECT_EQ(alloc_count, last_alloc_count + 2);
  last_alloc_count = alloc_count;
  /* 0 and 1 evicted; cache order now: [ CSIZ + 1, CSIZ, ... , 2 ] */

  /* ensure first two elements were evicted from cache */
  fetch_range(0, 2);
  /* expect two more allocations to account for evicted elements */
  EXPECT_EQ(alloc_count, last_alloc_count + 2);
  last_alloc_count = alloc_count;
  /* 2 and 3 evicted; cache order now: [ 1, 0, CSIZ + 1, CSIZ, ... , 4 ] */

  /* promote two earlier elements to front of cache */
  fetch_range(CSIZ, CSIZ + 2);
  /* expect alloc_count unchanged */
  EXPECT_EQ(alloc_count, last_alloc_count);
  /* cache order now: [ CSIZ + 1, CSIZ, 1, 0, CSIZ - 1, ... , 4 ] */

  const int NS = CSIZ * 2; // new start, guaranteed greater than CSIZ + 1
  /* fill cache with new values except last two */
  fetch_range(NS, NS + (CSIZ - 2));
  /* expect CSIZ - 2 more allocations */
  EXPECT_EQ(alloc_count, last_alloc_count + (CSIZ - 2));
  last_alloc_count = alloc_count;
  /* cache order now: [ NS + (CSIZ - 3), ... , NS, CSIZ + 1, CSIZ ] */

  /* request last two values which should still be in the cache */
  fetch_range(CSIZ, CSIZ + 2);
  /* expect alloc_count to be unchanged */
  EXPECT_EQ(alloc_count, last_alloc_count);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment