Created October 25, 2017 22:11
-
-
Save nestyme/b8973bd7730fc0d6d141e616916a0c28 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| //#ifndef solution_h | |
| //#define solution_h | |
| //#pragma once | |
| #include <algorithm> | |
| #include <atomic> | |
| #include <forward_list> | |
| #include <functional> | |
| #include <iterator> | |
| #include <mutex> | |
| #include <vector> | |
/// Thread-safe hash set using lock striping.
///
/// Buckets are `std::forward_list`s. An element's hash selects both a bucket
/// (`hash % bucket_count`) and a stripe mutex (`hash % stripe_count`). The
/// bucket count starts equal to the stripe count and is only ever multiplied by
/// the growth factor, so it stays a multiple of the stripe count. Every element
/// of a given bucket therefore maps to the same stripe, and holding that one
/// stripe is enough to read or modify the bucket. Growing the table takes every
/// stripe lock.
///
/// @tparam T    Element type; must be copy-constructible and equality-comparable.
/// @tparam Hash Hash functor type, default-constructed on every use.
template <typename T, class Hash = std::hash<T> >
class StripedHashSet {
public:
    /// @param concurrency_level Number of stripe mutexes and the initial bucket
    ///        count. 0 is treated as 1, because a zero count would make every
    ///        modulo undefined.
    /// @param growth_factor Multiplier applied to the bucket count on rehash.
    ///        Values below 2 are raised to 2: 0 would divide by zero, and 1
    ///        would never grow the table.
    /// @param load_factor Maximum average number of elements per bucket before
    ///        the table grows. Expected to be positive.
    explicit StripedHashSet(const size_t concurrency_level,
                            const size_t growth_factor = 3,
                            const float load_factor = 0.75)
        : max_load_factor_(load_factor),
          growth_factor_(std::max<size_t>(growth_factor, 2)),
          // Size-constructing only default-constructs the mutexes. resize()
          // would require them to be movable, which std::mutex is not.
          stripes_(std::max<size_t>(concurrency_level, 1)),
          table_(stripes_.size()),
          size_(0) {}

    /// Adds @p element if no equal element is present.
    /// @return true if the element was inserted, false if it was already present.
    bool Insert(const T& element) {
        const size_t hash_value = Hash()(element);
        std::unique_lock<std::mutex> lock(stripes_[GetStripeIndex(hash_value)]);
        std::forward_list<T>& bucket = table_[GetBucketIndex(hash_value)];
        if (std::find(bucket.begin(), bucket.end(), element) != bucket.end()) {
            return false;
        }
        bucket.push_front(element);
        size_.fetch_add(1);
        const bool grow = ExceedsMaxLoad();
        // Rehash() acquires every stripe, starting from stripe 0, so our
        // stripe must be released first to avoid self-deadlock.
        lock.unlock();
        if (grow) {
            Rehash();
        }
        return true;
    }

    /// Removes the element equal to @p element, if present.
    /// @return true if an element was removed, false otherwise.
    bool Remove(const T& element) {
        const size_t hash_value = Hash()(element);
        std::lock_guard<std::mutex> lock(stripes_[GetStripeIndex(hash_value)]);
        std::forward_list<T>& bucket = table_[GetBucketIndex(hash_value)];
        // A singly linked list can only erase *after* a position, so track the
        // predecessor and erase in a single pass.
        for (auto prev = bucket.before_begin(), cur = bucket.begin();
             cur != bucket.end(); prev = cur, ++cur) {
            if (*cur == element) {
                bucket.erase_after(prev);
                size_.fetch_sub(1);
                return true;
            }
        }
        return false;
    }

    /// @return true if an element equal to @p element is present.
    bool Contains(const T& element) const {
        const size_t hash_value = Hash()(element);
        std::lock_guard<std::mutex> lock(stripes_[GetStripeIndex(hash_value)]);
        const std::forward_list<T>& bucket = table_[GetBucketIndex(hash_value)];
        return std::find(bucket.begin(), bucket.end(), element) != bucket.end();
    }

    /// @return Current number of elements. Under concurrent modification this
    ///         is a snapshot that may already be stale.
    size_t Size() const {
        return size_.load();
    }

private:
    size_t GetBucketIndex(const size_t element_hash_value) const {
        return element_hash_value % table_.size();
    }

    size_t GetStripeIndex(const size_t element_hash_value) const {
        return element_hash_value % stripes_.size();
    }

    /// True when the average bucket occupancy is above the maximum load factor.
    /// Floating-point division is used so fractional load factors are honored.
    /// The caller must hold at least one stripe lock so table_.size() is stable.
    bool ExceedsMaxLoad() const {
        return static_cast<double>(size_.load()) /
                   static_cast<double>(table_.size()) >
               max_load_factor_;
    }

    /// Grows the bucket table by growth_factor_ and redistributes all elements.
    void Rehash() {
        std::vector<std::unique_lock<std::mutex>> lockers;
        lockers.reserve(stripes_.size());
        // Stripe 0 is always taken first. This keeps the lock order consistent
        // and lets concurrent rehash attempts re-check the load without each
        // grabbing every stripe.
        lockers.emplace_back(stripes_[0]);
        if (!ExceedsMaxLoad()) {
            return;  // Another thread already grew the table.
        }
        for (size_t i = 1; i < stripes_.size(); ++i) {
            lockers.emplace_back(stripes_[i]);
        }
        const size_t new_size = table_.size() * growth_factor_;
        std::vector<std::forward_list<T>> new_table(new_size);
        // Relink the existing nodes instead of copying them. This avoids
        // allocations while every stripe is held. Assumes Hash does not throw
        // (true for std::hash); otherwise a throw would leave nodes split
        // across the old and new tables.
        for (std::forward_list<T>& old_bucket : table_) {
            while (!old_bucket.empty()) {
                std::forward_list<T>& target =
                    new_table[Hash()(old_bucket.front()) % new_size];
                target.splice_after(target.before_begin(), old_bucket,
                                    old_bucket.before_begin());
            }
        }
        table_.swap(new_table);
    }

    const float max_load_factor_;
    const size_t growth_factor_;
    // mutable so the read-only Contains() can still lock its stripe.
    mutable std::vector<std::mutex> stripes_;
    std::vector<std::forward_list<T>> table_;
    std::atomic<size_t> size_;
};
| template <typename T> using ConcurrentSet = StripedHashSet<T>; | |
| //#endif /* solution_h */ |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.