Last active January 1, 2018 17:14
-
-
Save wwylele/0bc1ca527fa54ecfb4c5bf2e78a6c2a5 to your computer and use it in GitHub Desktop.
"Bit trie" a trie-like structure
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm> | |
#include <climits> | |
#include <cstdlib> | |
#include <iostream> | |
#include <stdexcept> | |
#include <string> | |
#include <vector> | |
/*
bit_trie

A bit_trie is an associative container built once from a range of
(key, value) pairs; it has no insertion or removal operations. Elements
are retrieved by key through a binary decision tree: every internal node
tests one bit of the key, and the search goes left when the bit is clear
and right when it is set, until it reaches an element.

Unlike a classic trie it does not consume key prefixes: at each level the
builder picks whichever bit address splits the remaining keys most evenly.

Key
    Type of the key values. Each element in a bit_trie is uniquely
    identified by its key value. at() compares keys with operator==.
Value
    Type of the mapped values.
BitTester
    A binary function object type that takes an object of type Key
    and a std::size_t (bit address) as arguments and returns a bool.
    The bit address is always less than the bit length passed to the
    constructor. The function must return a consistent result for the
    same key and the same bit address, and any two different keys must
    give different results for at least one bit address.
*/
template <
    typename Key,
    typename Value,
    typename BitTester
>
class bit_trie {
    // A link to another node, stored as an offset relative to the node that
    // owns the link. `end` marks that the target is an element (leaf) rather
    // than another decision node.
    struct Branch {
        int offset;
        bool end;
    };

    // Every element occupies exactly one Node. Besides holding its key/value
    // (leaf role), a node may also act as a decision node: it then tests
    // `bit_address` and follows `left` (bit clear) or `right` (bit set).
    // nodes[0].left is the entry link to the root of the whole tree.
    struct Node {
        std::size_t bit_address;
        Branch left;
        Branch right;
        Key key;     // read while building and by at(); at_no_verify() never reads it
        Value value;
    };

    std::vector<Node> nodes;
    BitTester tester;

    // Walks from the entry link to the element selected by the bits of
    // `key`. Always returns some element; callers decide whether to verify
    // that its key actually matches.
    Node& at_node(const Key& key) {
        std::size_t cur_pos = 0;
        Branch next = nodes[0].left;
        while (true) {
            // Offsets may be negative; the unsigned wrap-around of the
            // addition still lands on the intended index.
            cur_pos += next.offset;
            if (next.end)
                return nodes[cur_pos];
            next = tester(key, nodes[cur_pos].bit_address)
                ? nodes[cur_pos].right
                : nodes[cur_pos].left;
        }
    }

    // Recursively arranges [begin, end) into a subtree. On return,
    // begin->left is the link (relative to begin) to that subtree's root.
    // Throws std::invalid_argument if the keys cannot be told apart by any
    // bit address below bit_length (e.g. duplicated keys).
    void build(typename std::vector<Node>::iterator begin,
               typename std::vector<Node>::iterator end,
               std::size_t bit_length) {
        const std::size_t count = static_cast<std::size_t>(std::distance(begin, end));
        if (count == 1)
            return; // a lone element is its own leaf; begin->left stays {0, true}

        // For every bit address, count how many keys in the range have it set.
        std::vector<std::size_t> pass_count(bit_length);
        std::for_each(begin, end, [&](const Node& node) {
            for (std::size_t bit = 0; bit < bit_length; ++bit) {
                if (tester(node.key, bit))
                    ++pass_count[bit];
            }
        });

        // Pick the bit whose set/clear split is closest to half. Bits on
        // which every key (or no key) agrees cannot split the range and are
        // skipped: for odd-sized ranges such a bit would otherwise tie with
        // a valid split and get chosen, wrongly rejecting distinct keys.
        const std::size_t half = count / 2;
        bool found = false;
        std::size_t best_badness = 0;
        std::size_t best_address = 0;
        for (std::size_t address = 0; address < bit_length; ++address) {
            const std::size_t passed = pass_count[address];
            if (passed == 0 || passed == count)
                continue;
            const std::size_t badness = passed > half ? passed - half : half - passed;
            if (!found || badness < best_badness) {
                found = true;
                best_badness = badness;
                best_address = address;
            }
        }
        if (!found)
            throw std::invalid_argument("Duplicated keys.");

        // Keys with the chosen bit clear go first (left subtree).
        auto partition_pos = std::partition(begin, end, [&](const Node& node) {
            return !tester(node.key, best_address);
        });
        // Only reachable with an inconsistent BitTester.
        if (partition_pos == begin || partition_pos == end)
            throw std::invalid_argument("Duplicated keys.");
        const int partition_distance = static_cast<int>(std::distance(begin, partition_pos));

        // Build the trie for each partition.
        build(begin, partition_pos, bit_length);
        build(partition_pos, end, bit_length);

        // partition_pos is the first element of the right range, so until now
        // only its `left` (the right subtree's entry link) was in use. It now
        // becomes this level's decision node: that link moves to `right`, and
        // begin's entry link (to the left subtree) is rebased onto it.
        partition_pos->right = partition_pos->left;
        partition_pos->left = begin->left;
        partition_pos->left.offset -= partition_distance;
        partition_pos->bit_address = best_address;
        // begin's entry link now points at the new decision node.
        begin->left.offset = partition_distance;
        begin->left.end = false;
    }

public:
    // Builds the trie from a range of pair-like elements (`.first` is the
    // key, `.second` the value), traversed exactly once, so single-pass
    // input iterators are supported. Bit addresses 0 .. bit_length-1 are
    // considered when splitting.
    // Throws std::length_error for an empty or oversized range, and
    // std::invalid_argument if two keys cannot be told apart.
    template <typename InputIterator>
    bit_trie(
        InputIterator begin,
        InputIterator end,
        std::size_t bit_length,
        const BitTester& tester_ = BitTester()
    ) : tester(tester_) {
        constexpr std::size_t invalid_bit_address = static_cast<std::size_t>(-1);
        for (; begin != end; ++begin) {
            const auto& element = *begin;
            nodes.push_back(Node{
                invalid_bit_address,
                {0, true},
                {0, true},
                element.first,
                element.second
            });
        }
        if (nodes.empty())
            throw std::length_error("Empty range.");
        // Branch offsets are stored as int.
        if (nodes.size() > static_cast<std::size_t>(INT_MAX))
            throw std::length_error("Too many elements.");
        build(nodes.begin(), nodes.end(), bit_length);
    }

    // Returns the value stored under `key`.
    // Throws std::out_of_range if `key` is not in the trie.
    Value& at(const Key& key) {
        Node& node = at_node(key);
        if (node.key == key)
            return node.value;
        throw std::out_of_range("Element not found.");
    }

    // Returns the value of the element selected by the tested bits of
    // `key`, without comparing keys. For a key that is not in the trie this
    // silently returns some other element's value.
    Value& at_no_verify(const Key& key) {
        return at_node(key).value;
    }
};
///////////////////////////////////////////////////////////////////////////////
/// TEST

// Fixture: English element names paired with their Chinese characters.
// Entries are grouped by period for readability; the listing order is the
// order in which main() prints them.
const std::pair<std::string, std::string> example_map[] {
    // Period 1
    {"Hydrogen", "氢"}, {"Helium", "氦"},
    // Period 2
    {"Lithium", "锂"}, {"Beryllium", "铍"}, {"Boron", "硼"}, {"Carbon", "碳"},
    {"Nitrogen", "氮"}, {"Oxygen", "氧"}, {"Fluorine", "氟"}, {"Neon", "氖"},
    // Period 3
    {"Sodium", "钠"}, {"Magnesium", "镁"}, {"Aluminium", "铝"}, {"Silicon", "硅"},
    {"Phosphorus", "磷"}, {"Sulfur", "硫"}, {"Chlorine", "氯"}, {"Argon", "氩"},
    // Period 4
    {"Potassium", "钾"}, {"Calcium", "钙"}, {"Scandium", "钪"}, {"Titanium", "钛"},
    {"Vanadium", "钒"}, {"Chromium", "铬"}, {"Manganese", "锰"}, {"Iron", "铁"},
    {"Cobalt", "钴"}, {"Nickel", "镍"}, {"Copper", "铜"}, {"Zinc", "锌"},
    {"Gallium", "镓"}, {"Germanium", "锗"}, {"Arsenic", "砷"}, {"Selenium", "硒"},
    {"Bromine", "溴"}, {"Krypton", "氪"},
    // Period 5
    {"Rubidium", "铷"}, {"Strontium", "锶"}, {"Yttrium", "钇"}, {"Zirconium", "锆"},
    {"Niobium", "铌"}, {"Molybdenum", "钼"}, {"Technetium", "锝"}, {"Ruthenium", "钌"},
    {"Rhodium", "铑"}, {"Palladium", "钯"}, {"Silver", "银"}, {"Cadmium", "镉"},
    {"Indium", "铟"}, {"Tin", "锡"}, {"Antimony", "锑"}, {"Tellurium", "碲"},
    {"Iodine", "碘"}, {"Xenon", "氙"},
    // Period 6 (lanthanides listed separately below)
    {"Caesium", "铯"}, {"Barium", "钡"}, {"Hafnium", "铪"}, {"Tantalum", "钽"},
    {"Tungsten", "钨"}, {"Rhenium", "铼"}, {"Osmium", "锇"}, {"Iridium", "铱"},
    {"Platinum", "铂"}, {"Gold", "金"}, {"Mercury", "汞"}, {"Thallium", "铊"},
    {"Lead", "铅"}, {"Bismuth", "铋"}, {"Polonium", "钋"}, {"Astatine", "砹"},
    {"Radon", "氡"},
    // Period 7 (partial; actinides listed separately below)
    {"Francium", "钫"}, {"Radium", "镭"}, {"Rutherfordium", "鑪"}, {"Meitnerium", "䥑"},
    {"Darmstadtium", "鐽"}, {"Roentgenium", "錀"}, {"Copernicium", "鎶"}, {"Flerovium", "鈇"},
    {"Livermorium", "鉝"},
    // Lanthanides
    {"Lanthanum", "镧"}, {"Cerium", "铈"}, {"Praseodymium", "镨"}, {"Neodymium", "钕"},
    {"Promethium", "钷"}, {"Samarium", "钐"}, {"Europium", "铕"}, {"Gadolinium", "钆"},
    {"Terbium", "铽"}, {"Dysprosium", "镝"}, {"Holmium", "钬"}, {"Erbium", "铒"},
    {"Thulium", "铥"}, {"Ytterbium", "镱"}, {"Lutetium", "镏"},
    // Actinides
    {"Actinium", "锕"}, {"Thorium", "钍"}, {"Protactinium", "镤"}, {"Uranium", "铀"},
    {"Neptunium", "镎"}, {"Plutonium", "钚"}, {"Americium", "銤"}, {"Curium", "锔"},
    {"Berkelium", "锫"}, {"Californium", "锎"}, {"Einsteinium", "锿"}, {"Fermium", "镄"},
    {"Mendelevium", "钔"}, {"Nobelium", "锘"}, {"Lawrencium", "铹"},
};
// BitTester for std::string keys, used with bit_trie.
// Bit address `position` selects byte `position / 8` of the key and, within
// that byte, bit `position % 8` counted from the least significant bit.
// Addresses past the end of the key read as 0, so keys of different lengths
// can share one bit length (keys differing only by trailing '\0' bytes are
// therefore indistinguishable).
bool string_tester(const std::string& key, std::size_t position) {
    const std::size_t byte = position >> 3;
    if (byte >= key.size())
        return false;
    // Go through unsigned char: plain char may be signed, and right-shifting
    // a negative value (any non-ASCII byte, e.g. UTF-8) is
    // implementation-defined before C++20.
    const unsigned char octet = static_cast<unsigned char>(key[byte]);
    return ((octet >> (position & 7)) & 1u) != 0;
}
int main() { | |
std::size_t max_bit_length = 0; | |
for (const auto& pair : example_map) { | |
if (pair.first.size() > max_bit_length) | |
max_bit_length = pair.first.size(); | |
} | |
max_bit_length *= 8; | |
bit_trie<std::string, std::string, decltype(&string_tester)> | |
example(std::begin(example_map), std::end(example_map), max_bit_length, &string_tester); | |
for (const auto& pair : example_map) { | |
auto value = example.at_no_verify(pair.first); | |
std::cout << "(" << (value == pair.second ? "✓" : "✗") << ")" | |
<< pair.first << " : " << value << std::endl; | |
} | |
std::string bad_key = "abcdefg"; | |
try { | |
auto value = example.at(bad_key); | |
std::cout << bad_key << " : " << value << std::endl; | |
} catch (std::out_of_range e) { | |
std::cout << bad_key << " : " << e.what() << std::endl; | |
} | |
auto value = example.at_no_verify(bad_key); | |
std::cout << bad_key << " : " << value << std::endl; | |
return 0; | |
} |
@yuriks This trie doesn't match key prefixes. Each level tests a bit taken from anywhere in the key, not necessarily the next prefix bit, and it picks the bit that balances the trie best. So it doesn't matter if I reverse the keys.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Random comment about the test data: you'd probably get much better utilization of the trie if you stored the strings in reverse (Uranium -> muinarU), because while there aren't many common prefixes, a lot of words end with "-ium".