Last active
February 21, 2022 03:09
-
-
Save krayfaus/671c9432777b1d8bfca83f1ca850b7fe to your computer and use it in GitHub Desktop.
S+ Tree
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// ---------------------------------------------------------------- | |
// S+ Tree | |
// | |
// Created by: Sergey Slotin | |
// Documented at: | |
// https://en.algorithmica.org/hpc/data-structures/s-tree | |
// https://twitter.com/sergey_slotin/status/1494349254730690561 | |
// | |
// ---------------------------------------------------------------- | |
// | |
// Rewritten by: Ítalo Cadeu (@krayfaus) | |
// | |
// This is an initial attempt to transform the amazing S+ Tree data structure, | |
// created by Sergey Slotin into a templated header-only C++ library. | |
// | |
// I hope the code below is usable. I've tried my best to understand the original algorithm,
// but as I've never worked in production it may contain beginner mistakes.
// | |
// https://gist.github.com/krayfaus/671c9432777b1d8bfca83f1ca850b7fe | |
// ---------------------------------------------------------------- | |
// As Sergey has acknowledged before, 'online compilers' are not the best way
// to properly benchmark memory-intensive algorithms,
// so please take the output with a grain of salt.
// If possible, compile and run the code on your own machine.
// The code compiles on all major compilers: Clang, GCC and MSVC.
#if defined(_MSC_VER)
/* Microsoft C/C++-compatible compiler */
#include <intrin.h>
#define NOMINMAX
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#elif defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
/* GCC-compatible compiler, targeting x86/x86-64 */
#include <sys/mman.h>
#include <x86intrin.h>
#else
#error "Unknown or Unsupported Platform."
#endif
#include <time.h>
#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <random>
#include <type_traits>
#include <fmt/core.h>
// ---------------------------------------------------------------- | |
// Aliases: | |
// Shorthand names for the platform's signed and unsigned `int`.
using s32 = signed int;
using u32 = unsigned int;
// Built-in array of `Size` elements of type `T` (decays to `T *` when used as a function parameter).
template <typename T, std::size_t Size>
using c_array = T[Size];
// Largest representable `int`; used as the "infinity" key that pads unused tree slots.
constexpr auto k_Infinity = std::numeric_limits<int>::max();
// ---------------------------------------------------------------- | |
namespace
{
    /**
     * Allocate `size` bytes of raw, uninitialized storage for the tree.
     *
     * - MSVC:  the block comes from `VirtualAlloc` (reserve + commit, read/write).
     *          `alignment` is not used on this path; release with `VirtualFree(ptr, 0, MEM_RELEASE)`.
     * - Linux: the block comes from `std::aligned_alloc(alignment, size)` and, when the
     *          allocation succeeded, the kernel is advised to back it with huge pages.
     *          Release with `std::free`.
     *
     * @param alignment Requested alignment in bytes (Linux path only).
     * @param size      Number of bytes to allocate.
     * @return Pointer to the block, or `nullptr` if the allocation failed.
     */
    [[nodiscard]] auto allocate_memory(size_t alignment, size_t size) -> void *
    {
        void *data = nullptr;
#if defined(_MSC_VER)
        (void)alignment; // VirtualAlloc picks the placement itself.
        data = VirtualAlloc(nullptr, size, MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE);
#elif defined(__linux__)
        data = std::aligned_alloc(alignment, size);
        if (data != nullptr)
        {
            // Purely advisory: a failure here only means no huge pages, so the result is ignored.
            (void)madvise(data, size, MADV_HUGEPAGE);
        }
#else
#error "Unknown or Unsupported Platform."
#endif
        return data;
    }
} // namespace
// ---------------------------------------------------------------- | |
namespace | |
{ | |
/** Calculate the number of B-element blocks in a layer. */ | |
template <s32 BlockSize> | |
[[nodiscard]] constexpr auto calculate_block_count(s32 n) -> s32 | |
{ | |
return (n + BlockSize - 1) / BlockSize; | |
} | |
/** Calculate the number of keys on the layer previous to one with nth element. */ | |
template <s32 BlockSize> | |
[[nodiscard]] constexpr auto calculate_previous_layer_key_count(s32 n) -> s32 | |
{ | |
return (calculate_block_count<BlockSize>(n) + BlockSize) / (BlockSize + 1) * BlockSize; | |
} | |
/** Calculate the height of a balanced n-key B+ tree. */ | |
template <s32 BlockSize> | |
[[nodiscard]] constexpr auto calculate_height(s32 n) -> s32 | |
{ | |
return n <= BlockSize ? 1 : calculate_height<BlockSize>(calculate_previous_layer_key_count<BlockSize>(n)) + 1; | |
} | |
/** Calculate the offset of the h layer on a B+ tree (0 is the largest). */ | |
template <s32 BlockSize, s32 KeyCount> | |
[[nodiscard]] constexpr auto calculate_offset(s32 h) -> s32 | |
{ | |
// expect(h >= 0, "Invalid tree height."); | |
s32 k = 0; | |
s32 n = KeyCount; | |
while (h--) | |
{ | |
k += calculate_block_count<BlockSize>(n) * BlockSize; | |
n = calculate_previous_layer_key_count<BlockSize>(n); | |
} | |
return k; | |
} | |
} // namespace | |
// ---------------------------------------------------------------- | |
namespace | |
{ | |
using reg = __m256i; | |
void permute(s32 *node) | |
{ | |
reg const mask = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4); | |
reg *middle = (reg *)(node + 4); | |
reg x = _mm256_loadu_si256(middle); | |
x = _mm256_permutevar8x32_epi32(x, mask); | |
_mm256_storeu_si256(middle, x); | |
} | |
auto direct_rank(reg x, s32 *y) -> u32 | |
{ | |
reg a = _mm256_load_si256((reg *)y); | |
reg b = _mm256_load_si256((reg *)(y + 8)); | |
reg ca = _mm256_cmpgt_epi32(a, x); | |
reg cb = _mm256_cmpgt_epi32(b, x); | |
#if defined(_MSC_VER) | |
auto tca = _mm256_castsi256_ps(ca); | |
auto tcb = _mm256_castsi256_ps(cb); | |
s32 mb = _mm256_movemask_ps(tcb); | |
s32 ma = _mm256_movemask_ps(tca); | |
#else | |
s32 mb = _mm256_movemask_ps((__m256)cb); | |
s32 ma = _mm256_movemask_ps((__m256)ca); | |
#endif | |
u32 mask = (1 << 16); | |
mask |= static_cast<u32>(mb << 8); | |
mask |= static_cast<u32>(ma); | |
#if defined(_MSC_VER) | |
return _tzcnt_u32(mask); | |
#else | |
return __tzcnt_u32(mask); | |
#endif | |
} | |
auto permuted_rank(reg x, s32 *y) -> u32 | |
{ | |
reg a = _mm256_load_si256((reg *)y); | |
reg b = _mm256_load_si256((reg *)(y + 8)); | |
reg ca = _mm256_cmpgt_epi32(a, x); | |
reg cb = _mm256_cmpgt_epi32(b, x); | |
reg c = _mm256_packs_epi32(ca, cb); | |
u32 mask = static_cast<u32>(_mm256_movemask_epi8(c)); | |
#if defined(_MSC_VER) | |
return _tzcnt_u32(mask); | |
#else | |
return __tzcnt_u32(mask); | |
#endif | |
} | |
template <s32 BlockSize, s32 TreeHeight, s32 ElementCount> | |
auto lower_bound(s32 value, s32 *tree) -> s32 | |
{ | |
reg x = _mm256_set1_epi32(value - 1); | |
u32 k = 0; | |
for (s32 h = TreeHeight - 1; h > 0; h--) | |
{ | |
u32 const i = permuted_rank(x, tree + calculate_offset<BlockSize, ElementCount>(h) + k); | |
k = k * (BlockSize + 1) + (i << 3); | |
} | |
u32 i = direct_rank(x, tree + k); | |
return tree[k + i]; | |
} | |
} // namespace | |
// ---------------------------------------------------------------- | |
namespace aethelwerka | |
{ | |
template <typename ElementType, s32 ElementCount, s32 BlockSize> | |
class static_tree | |
{ | |
public: | |
// Aliases: | |
using value_type = ElementType; | |
using pointer_type = value_type *; | |
// Constants: | |
static constexpr auto value_size = sizeof(value_type); // Size in bytes of value_type. | |
static constexpr s32 block_size = BlockSize; // Cache line. | |
static constexpr s32 element_count = ElementCount; // Element count on input array, key count on btree. | |
// Height of a balanced n-key B+ tree. | |
static constexpr s32 tree_height = calculate_height<block_size>(element_count); | |
// Tree size is the offset of the (non-existent) layer H ("height"). | |
static constexpr s32 tree_size = calculate_offset<block_size, element_count>(tree_height); | |
public: | |
// Empty constructor. | |
static_tree() noexcept | |
: tree_data(nullptr) | |
, is_initialized(false) | |
{ | |
} | |
// Initialize tree. | |
[[nodiscard]] bool initialize(c_array<value_type, element_count> input_data) noexcept | |
{ | |
// expect(!is_initialized); | |
// expect(input_data); | |
// expect(element_count > 0); | |
s32 const page_alignment = 1 << 21; // Page size in bytes (2MB). | |
// We can only allocate a whole number of pages. | |
s32 const page_size = (value_size * tree_size + page_alignment - 1) / page_alignment * page_alignment; | |
// Allocate memory. | |
tree_data = (s32*) allocate_memory(page_alignment, page_size); | |
if (!tree_data) | |
{ | |
// Couldn't allocate memory. | |
return false; | |
} | |
// Pad the tree with infinities. | |
for (s32 i = element_count; i < tree_size; ++i) | |
{ | |
tree_data[i] = k_Infinity; | |
} | |
// Copy data from input_data array to tree_data. | |
memcpy(tree_data, input_data, value_size * element_count); | |
// Build the internal nodes, layer by layer. | |
for (s32 h = 1; h < tree_height; ++h) | |
{ | |
for (s32 i = 0; i < calculate_offset<block_size, element_count>(h + 1) - calculate_offset<block_size, element_count>(h); ++i) | |
{ | |
s32 k = i / block_size; | |
s32 const j = i - k * block_size; | |
k = k * (block_size + 1) + j + 1; // Compare to the right of the key. | |
for (s32 l = 0; l < h - 1; ++l) // And then always to the left. | |
{ | |
k *= block_size + 1; | |
} | |
// Pad the rest with infinities if the key doesn't exist: | |
tree_data[calculate_offset<block_size, element_count>(h) + i] = | |
k * block_size < element_count ? tree_data[k * block_size] : k_Infinity; | |
} | |
} | |
// Permute every tree node for faster query time (trick to avoid permuting avx2 later). | |
for (s32 i = calculate_offset<block_size, element_count>(1); i < tree_size; i += block_size) | |
{ | |
permute(tree_data + i); | |
} | |
// Tree is properly initialized now, and ready to use. | |
is_initialized = true; | |
return true; | |
} | |
[[nodiscard]] s32 search(value_type value) | |
{ | |
return lower_bound<block_size, tree_height, element_count>(value, tree_data); | |
} | |
private: | |
// Pointer to tree data. | |
pointer_type tree_data; | |
// Status of the tree data. | |
bool is_initialized; | |
}; | |
} // namespace aethelwerka | |
// ---------------------------------------------------------------- | |
template <s32 ArrayLenght> | |
auto baseline(s32 value, s32 array[ArrayLenght]) -> s32 | |
{ | |
auto const array_first = array; | |
auto const array_last = array + ArrayLenght; | |
return *std::lower_bound(array_first, array_last, value); | |
} | |
// ---------------------------------------------------------------- | |
int main(int, char **) | |
{ | |
using fmt::print; | |
using aethelwerka::static_tree; | |
// ---------------------------------------------------------------- | |
using ElementType = int; | |
constexpr auto ElementSize = sizeof(ElementType); | |
constexpr auto ElementCount = (1 << 16); | |
constexpr auto IterationCount = (1 << 22); | |
// ---------------------------------------------------------------- | |
// Data arrays (they're so big they need to go on the heap): | |
static c_array<ElementType, ElementCount> input_data; | |
static c_array<int, IterationCount> check_indices; | |
// ---------------------------------------------------------------- | |
print("Lenght = {}, Iterations = {}\n", ElementCount, IterationCount); | |
std::mt19937 rng(0); | |
// Fill input data. | |
input_data[0] = k_Infinity; | |
for (s32 i = 1; i < ElementCount; ++i) | |
{ | |
input_data[i] = rng() % (1 << 30); | |
} | |
// Random indices. | |
for (s32 i = 0; i < IterationCount; ++i) | |
{ | |
check_indices[i] = rng() % (1 << 30); | |
} | |
// ---------------------------------------------------------------- | |
// TreeBlockSize: 16 elements; sizeof(s32) * 16 = 64 bytes (cache line is typically 64 bytes). | |
constexpr auto TreeBlockSize = 64 / ElementSize; | |
auto tree = static_tree<ElementType, ElementCount, TreeBlockSize>{}; | |
if (!tree.initialize(input_data)) | |
{ | |
// Couldn't initialize tree. | |
return -1; | |
} | |
// ---------------------------------------------------------------- | |
// The measurement code bellow is ugly and may not give proper metrics. | |
// ---------------------------------------------------------------- | |
double x = 0.0; | |
{ | |
clock_t start = clock(); | |
s32 checksum = 0; | |
for (s32 i = 0; i < IterationCount; ++i) | |
{ | |
checksum ^= baseline<ElementCount>(check_indices[i], input_data); | |
} | |
double seconds = double(clock() - start) / CLOCKS_PER_SEC; | |
print("Checksum: {}\n", checksum); | |
x = 1e9 * seconds / IterationCount; | |
} | |
// ---------------------------------------------------------------- | |
double y = 0.0; | |
{ | |
clock_t start = clock(); | |
s32 checksum = 0; | |
for (s32 i = 0; i < IterationCount; ++i) | |
{ | |
checksum ^= tree.search(check_indices[i]); | |
} | |
double seconds = double(clock() - start) / CLOCKS_PER_SEC; | |
print("Checksum: {}\n", checksum); | |
y = 1e9 * seconds / IterationCount; | |
} | |
print("std::lower_bound: {:.2f}\n", x); | |
print("S+ tree: {:.2f}\n", y); | |
print("Speedup: {:.2f}\n", x / y); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment