Skip to content

Instantly share code, notes, and snippets.

@nestyme
Created October 25, 2017 22:11
Show Gist options
  • Select an option

  • Save nestyme/b8973bd7730fc0d6d141e616916a0c28 to your computer and use it in GitHub Desktop.

Select an option

Save nestyme/b8973bd7730fc0d6d141e616916a0c28 to your computer and use it in GitHub Desktop.
//#ifndef solution_h
//#define solution_h
//#pragma once
#include <algorithm>
#include <atomic>
#include <forward_list>
#include <functional>
#include <iterator>
#include <mutex>
#include <vector>
/// A concurrent hash set built on lock striping.
///
/// Elements live in a vector of singly linked buckets (`table_`). Access is
/// guarded by a fixed vector of mutexes (`stripes_`). An element with hash
/// `h` lives in bucket `h % table_.size()` and is protected by stripe
/// `h % stripes_.size()`. The bucket count starts equal to the stripe count
/// and is only ever multiplied by `growth_factor_`, so it always stays a
/// multiple of the stripe count. That keeps each bucket covered by exactly
/// one stripe.
///
/// Insert / Remove / Contains each lock a single stripe. Rehash locks every
/// stripe in ascending index order, so concurrent rehash attempts cannot
/// deadlock with one another.
///
/// @tparam T    element type; copied into the set on Insert and compared
///              with operator==.
/// @tparam Hash default-constructible hasher for T.
template <typename T, class Hash = std::hash<T> >
class StripedHashSet {
public:
    /// @param concurrency_level number of stripes, and the initial bucket
    ///        count. 0 is treated as 1 (a zero count would make every index
    ///        computation a modulo-by-zero).
    /// @param growth_factor multiplier applied to the bucket count on each
    ///        rehash. Values below 2 are raised to 2: 0 would empty the table
    ///        and 1 would never grow it.
    /// @param load_factor the table is rehashed once
    ///        size / bucket_count exceeds this ratio.
    explicit StripedHashSet(const size_t concurrency_level,
                            const size_t growth_factor = 3,
                            const float load_factor = 0.75)
        // std::mutex is neither copyable nor movable. The vector must therefore
        // be sized at construction; vector::resize would need MoveInsertable.
        : stripes_(std::max<size_t>(concurrency_level, 1)),
          table_(std::max<size_t>(concurrency_level, 1)),
          max_load_factor_(load_factor),
          growth_factor_(std::max<size_t>(growth_factor, 2)) {}

    /// Adds `element` if it is not already present.
    /// May grow the table afterwards if the load factor is exceeded.
    /// @return true if the element was inserted, false if it was already present.
    bool Insert(const T& element) {
        const size_t hash_value = Hash()(element);
        bool needs_rehash = false;
        {
            std::lock_guard<std::mutex> lock(stripes_[GetStripeIndex(hash_value)]);
            std::forward_list<T>& bucket = table_[GetBucketIndex(hash_value)];
            if (std::find(bucket.begin(), bucket.end(), element) != bucket.end()) {
                return false;
            }
            bucket.push_front(element);
            size_.fetch_add(1);
            // Evaluated under the stripe lock so table_.size() cannot change meanwhile.
            needs_rehash = LoadFactorExceeded();
        }
        // The stripe must be released first: Rehash acquires every stripe itself.
        if (needs_rehash) {
            Rehash();
        }
        return true;
    }

    /// Removes `element` if present.
    /// @return true if an element was removed, false if it was not found.
    bool Remove(const T& element) {
        const size_t hash_value = Hash()(element);
        std::lock_guard<std::mutex> lock(stripes_[GetStripeIndex(hash_value)]);
        std::forward_list<T>& bucket = table_[GetBucketIndex(hash_value)];
        // Single pass: forward_list can only erase *after* a node, so track the predecessor.
        for (auto prev = bucket.before_begin(), it = bucket.begin(); it != bucket.end();
             prev = it, ++it) {
            if (*it == element) {
                bucket.erase_after(prev);
                size_.fetch_sub(1);
                return true;
            }
        }
        return false;
    }

    /// @return true if `element` is currently in the set.
    bool Contains(const T& element) const {
        const size_t hash_value = Hash()(element);
        std::lock_guard<std::mutex> lock(stripes_[GetStripeIndex(hash_value)]);
        const std::forward_list<T>& bucket = table_[GetBucketIndex(hash_value)];
        return std::find(bucket.begin(), bucket.end(), element) != bucket.end();
    }

    /// @return the number of elements. This is a snapshot; it may be stale
    ///         immediately if other threads are modifying the set.
    size_t Size() const {
        return size_.load();
    }

private:
    size_t GetBucketIndex(const size_t element_hash_value) const {
        return element_hash_value % table_.size();
    }

    size_t GetStripeIndex(const size_t element_hash_value) const {
        return element_hash_value % stripes_.size();
    }

    /// True when size / bucket_count > max_load_factor_, computed in floating point
    /// (integer division would truncate the ratio and ignore the configured factor).
    /// Caller must hold at least one stripe lock so table_ is not swapped concurrently.
    bool LoadFactorExceeded() const {
        return static_cast<double>(size_.load()) >
               static_cast<double>(max_load_factor_) * static_cast<double>(table_.size());
    }

    /// Grows the table by growth_factor_ and redistributes all elements.
    void Rehash() {
        std::vector<std::unique_lock<std::mutex> > locks;
        locks.reserve(stripes_.size());
        // Every table swap happens while holding stripe 0. Taking it first and
        // re-checking therefore lets racing rehash callers bail out once
        // another thread has already grown the table.
        locks.emplace_back(stripes_[0]);
        if (!LoadFactorExceeded()) {
            return;
        }
        // Ascending order matches every other rehasher, so no lock cycle can form.
        for (size_t i = 1; i < stripes_.size(); ++i) {
            locks.emplace_back(stripes_[i]);
        }
        const size_t new_size = table_.size() * growth_factor_;
        std::vector<std::forward_list<T> > grown(new_size);
        // Copy (rather than move) elements. If Hash throws midway, table_ is
        // still intact and the partially built `grown` is simply discarded.
        for (const std::forward_list<T>& bucket : table_) {
            for (const T& item : bucket) {
                grown[Hash()(item) % new_size].push_front(item);
            }
        }
        table_.swap(grown);
    }

    // Mutable so const readers (Contains) can still lock their stripe.
    mutable std::vector<std::mutex> stripes_;
    std::vector<std::forward_list<T> > table_;
    const float max_load_factor_;
    const size_t growth_factor_;
    std::atomic<size_t> size_{0};
};
/// Shorthand for a StripedHashSet that hashes elements with std::hash<T>.
template <class T>
using ConcurrentSet = StripedHashSet<T, std::hash<T> >;
//#endif /* solution_h */
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment