Skip to content

Instantly share code, notes, and snippets.

@jacky860226
Created November 8, 2025 02:37
Show Gist options
  • Select an option

  • Save jacky860226/48e3ca4189d6c95690896a18e14b0b61 to your computer and use it in GitHub Desktop.

Select an option

Save jacky860226/48e3ca4189d6c95690896a18e14b0b61 to your computer and use it in GitHub Desktop.
Thread Local Storage
#pragma once
#include <atomic>
#include <cstdint>
#include <cstring>
#include <deque>
#include <functional>
#include <memory>
#include <thread>
namespace TLS::detail {
// 類型轉換工具函數:在不相關的指針類型之間進行轉換
// 這是處理底層記憶體操作時的最後手段,破壞了 C++ 的嚴格別名規則
template <typename T, typename U> inline T punned_cast(U *ptr) {
std::uintptr_t x = reinterpret_cast<std::uintptr_t>(ptr);
return reinterpret_cast<T>(x);
}
// 填充基類:當需要填充時繼承此類
template <class T, size_t S, size_t R> struct padded_base : T {
char pad[S - R]; // 添加填充字節以達到指定大小
};
// 填充基類的特化:當不需要填充時(R=0)繼承此類
template <class T, size_t S> struct padded_base<T, S, 0> : T {};
// 將類型 T 填充到快取行大小的倍數,避免 false sharing
// S: 目標大小(默認為硬體破壞性干擾大小,通常是快取行大小 64 字節)
// sizeof(T) % S: 計算需要填充的字節數
template <class T, size_t S = std::hardware_destructive_interference_size>
struct padded : padded_base<T, S, sizeof(T) % S> {};
template <typename T, std::size_t N = 1> class aligned_space {
// 按照 T 的對齊要求對齊的字節陣列,用於存儲 N 個 T 類型的物件
alignas(alignof(T)) std::uint8_t aligned_array[N * sizeof(T)];
public:
// 返回指向陣列開始位置的指針,用於構造物件
T *begin() const { return punned_cast<T *>(&aligned_array); }
// 返回指向陣列結束位置(最後一個元素的下一個位置)的指針
T *end() const { return begin() + N; }
};
template <typename U> struct tls_element {
aligned_space<U> my_space; // 對齊的記憶體空間,用於存放類型 U 的物件
bool is_built; // 標記物件是否已經被構造
// 構造函數:初始化為未構造狀態
tls_element() { is_built = false; }
// 返回指向存儲空間的指針(物件可能尚未構造)
U *value() { return my_space.begin(); }
// 標記物件為已構造並返回指針(用於構造完成後)
U *value_committed() {
is_built = true;
return my_space.begin();
}
// 解構函數:如果物件已構造,則手動調用其解構函數
~tls_element() {
if (is_built) {
my_space.begin()->~U();
is_built = false;
}
}
};
}; // namespace TLS::detail
template <typename T, typename Allocator = std::allocator<T>>
class ThreadLocalStorage {
// 使用執行緒 ID 作為雜湊表的鍵
using key_type = std::thread::id;
// 填充的執行緒本地元素,防止偽共享
using padded_element = TLS::detail::padded<TLS::detail::tls_element<T>>;
// 分配器特性類型,用於記憶體管理
using allocator_traits_type = std::allocator_traits<Allocator>;
// 重新綁定分配器以分配填充後的元素
using padded_allocator_type =
typename allocator_traits_type::template rebind_alloc<padded_element>;
// 重新綁定分配器以分配陣列元素
using array_allocator_type =
typename allocator_traits_type::template rebind_alloc<uintptr_t>;
// 內部容器類型,使用雙端佇列存儲填充的元素
using internal_collection_type =
std::deque<padded_element, padded_allocator_type>;
struct slot;
// 雜湊表陣列結構,使用開放尋址法解決衝突
struct array {
array *next; // 指向下一個陣列的指針(用於動態擴展)
std::size_t lg_size; // 陣列大小的對數值(2的lg_size次方)
// 訪問指定索引的槽位
slot &at(std::size_t k) {
return (reinterpret_cast<slot *>(reinterpret_cast<void *>(this + 1)))[k];
}
// 返回陣列的大小(2的lg_size次方)
std::size_t size() const { return std::size_t(1) << lg_size; }
// 返回陣列的掩碼,用於快速取模運算
std::size_t mask() const { return size() - 1; }
// 計算雜湊值的起始位置,使用高位元進行二次雜湊
std::size_t start(std::size_t h) const {
return h >> (8 * sizeof(std::size_t) - lg_size);
}
};
// 雜湊表槽位,使用原子操作保證執行緒安全
struct slot {
std::atomic<key_type> key; // 執行緒 ID 鍵,使用原子操作
void *ptr; // 指向執行緒本地資料的指針
// 檢查槽位是否為空
bool empty() const {
return key.load(std::memory_order_relaxed) == key_type();
}
// 檢查槽位是否匹配指定的鍵
bool match(key_type k) const {
return key.load(std::memory_order_relaxed) == k;
}
// 嘗試佔用空槽位,使用比較交換操作保證原子性
bool claim(key_type k) {
// TODO: 可能需要佔用 ptr,因為 key_type 不保證適合字長
key_type expected = key_type();
return key.compare_exchange_strong(expected, k);
}
};
// 原子指針,指向雜湊表陣列的根節點
std::atomic<array *> my_root;
// 原子計數器,記錄當前執行緒本地物件的數量
std::atomic<std::size_t> my_count;
// 內部容器,存儲所有執行緒本地物件,使用填充防止偽共享
internal_collection_type my_locals;
// 原子布林值,標記是否正在進行容器插入操作(防止競爭條件)
std::atomic_bool my_locals_is_emplacing;
// 建構回調函數,用於建立執行緒本地物件的實例
std::function<T()> my_construct_callback;
// 創建指定大小的雜湊表陣列,使用對齊記憶體分配
void *create_array(std::size_t _size) {
// 計算需要的元素數量,按指針大小對齊
std::size_t nelements = (_size + sizeof(uintptr_t) - 1) / sizeof(uintptr_t);
// 使用陣列分配器分配記憶體
return array_allocator_type().allocate(nelements);
}
// 釋放雜湊表陣列的記憶體
void free_array(void *_ptr, std::size_t _size) {
// 計算需要釋放的元素數量,按指針大小對齊
std::size_t nelements = (_size + sizeof(uintptr_t) - 1) / sizeof(uintptr_t);
// 使用陣列分配器釋放記憶體
array_allocator_type().deallocate(reinterpret_cast<uintptr_t *>(_ptr),
nelements);
}
// 分配指定大小的雜湊表陣列
array *allocate(std::size_t lg_size) {
// 計算陣列中槽位的數量(2的lg_size次方)
std::size_t n = std::size_t(1) << lg_size;
// 分配記憶體:陣列結構 + n個槽位
array *a =
static_cast<array *>(create_array(sizeof(array) + n * sizeof(slot)));
a->lg_size = lg_size;
// 將槽位記憶體初始化為零
std::memset(a + 1, 0, n * sizeof(slot));
return a;
}
// 釋放雜湊表陣列
void deallocate(array *a) {
// 計算陣列中槽位的數量
std::size_t n = std::size_t(1) << (a->lg_size);
// 釋放整個陣列的記憶體
free_array(static_cast<void *>(a),
std::size_t(sizeof(array) + n * sizeof(slot)));
}
// 雜湊表查找函數,查找當前執行緒的本地儲存
void *table_lookup(bool &exists) {
// 獲取當前執行緒的 ID 作為查找鍵
const key_type k = std::this_thread::get_id();
void *found = nullptr;
// 計算執行緒 ID 的雜湊值
std::size_t h = std::hash<key_type>{}(k);
// 從根陣列開始,遍歷所有雜湊表陣列
for (array *r = my_root.load(std::memory_order_acquire); r; r = r->next) {
std::size_t mask = r->mask();
// 使用開放尋址法進行探測
for (std::size_t i = r->start(h);; i = (i + 1) & mask) {
slot &s = r->at(i);
// 如果槽位為空,說明鍵不存在
if (s.empty())
break;
// 如果找到匹配的鍵
if (s.match(k)) {
// 如果在頂層陣列中找到,直接返回
if (r == my_root.load(std::memory_order_acquire)) {
// 在頂層成功找到
exists = true;
return s.ptr;
} else {
// 在其他層級找到,需要插入到頂層以提高存取效率
exists = true;
found = s.ptr;
goto insert;
}
}
}
}
// 鍵尚不存在。表中槽位的密度不會超過 0.5,
// 如果即將超過此密度,將分配一個大小為當前表兩倍的新表,
// 並將其作為新的根表進行交換。因此保證能找到空槽位。
exists = false;
found = [&] {
padded_element *new_element = nullptr;
{
bool expected = false;
while (!my_locals_is_emplacing.compare_exchange_weak(
expected, true, std::memory_order_acquire,
std::memory_order_relaxed)) {
expected = false;
std::this_thread::yield();
}
// 在鎖內獲取新元素的指標
my_locals.emplace_back();
new_element = &my_locals.back(); // 在鎖內獲取指標
my_locals_is_emplacing.store(false, std::memory_order_release);
}
// 在鎖外進行初始化
if (my_construct_callback) {
new (new_element->value()) T(my_construct_callback());
}
return new_element->value_committed();
}();
{
// 增加計數器並檢查是否需要擴展雜湊表
std::size_t c = ++my_count;
array *r = my_root.load(std::memory_order_acquire);
// 如果根陣列不存在或當前元素數量超過陣列大小的一半,需要擴展
if (!r || c > r->size() / 2) {
// 計算新陣列的大小(對數值)
std::size_t s = r ? r->lg_size : 2;
while (c > std::size_t(1) << (s - 1))
++s;
// 分配新的更大的陣列
array *a = allocate(s);
// 嘗試將新陣列設為根陣列
for (;;) {
a->next = r;
array *new_r = r;
// 使用原子比較交換更新根指針
if (my_root.compare_exchange_strong(new_r, a))
break;
if (new_r->lg_size >= s) {
// 其他執行緒已插入了同等或更大的陣列,
// 所以我們的陣列是多餘的
deallocate(a);
break;
}
r = new_r;
}
}
}
insert:
// 無論是在舊表中找到槽位,還是在此層級插入,
// 都已在總計中考慮。保證有空間容納,且它不存在,
// 所以搜索空槽位並使用它。
array *ir = my_root.load(std::memory_order_acquire);
std::size_t mask = ir->mask();
// 在頂層陣列中搜索空槽位進行插入
for (std::size_t i = ir->start(h);; i = (i + 1) & mask) {
slot &s = ir->at(i);
// 如果找到空槽位,嘗試佔用它
if (s.empty()) {
if (s.claim(k)) {
// 成功佔用槽位,設置指針並返回
s.ptr = found;
return found;
}
}
}
}
public:
// 公開類型定義,符合標準容器介面
using value_type = T;
using allocator_type = Allocator;
using size_type = typename internal_collection_type::size_type;
using difference_type = typename internal_collection_type::difference_type;
using reference = value_type &;
using const_reference = const value_type &;
using pointer = typename allocator_traits_type::pointer;
using const_pointer = typename allocator_traits_type::const_pointer;
// 建構函數:可選的建構回調函數用於自訂物件初始化
ThreadLocalStorage(std::function<T()> construct_callback = nullptr)
: my_root{nullptr}, my_count{0},
my_construct_callback{std::move(construct_callback)},
my_locals_is_emplacing{false} {}
// 解構函數:清理所有資源
~ThreadLocalStorage() { clear(); }
// 獲取當前執行緒的本地物件引用(簡化版本)
reference local() {
bool exists;
return local(exists);
}
// 獲取當前執行緒的本地物件引用,並返回是否已存在
reference local(bool &exists) {
void *ptr = this->table_lookup(exists);
return *(T *)ptr;
}
// 清理所有資源,釋放雜湊表和執行緒本地物件
void clear() {
// 釋放所有雜湊表陣列
while (array *r = my_root.load(std::memory_order_relaxed)) {
my_root.store(r->next, std::memory_order_relaxed);
deallocate(r);
}
// 重置計數器
my_count.store(0, std::memory_order_relaxed);
// 清空本地物件容器(自動調用解構函數)
my_locals.clear();
}
// 對所有已建構的執行緒本地物件執行指定函數
void for_each(std::function<void(const T &)> func) const {
for (auto &elem : my_locals) {
if (elem.is_built) {
func(*(elem.value()));
}
}
}
// 對所有已建構的執行緒本地物件執行指定函數(可修改版本)
void for_each(std::function<void(T &)> func) {
for (auto &elem : my_locals) {
if (elem.is_built) {
func(*(elem.value()));
}
}
}
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment