Skip to content

Instantly share code, notes, and snippets.

@jweinst1
Last active December 23, 2025 06:19
Show Gist options
  • Select an option

  • Save jweinst1/5c6fb139ff86a17cecd87504312c9f2b to your computer and use it in GitHub Desktop.

Select an option

Save jweinst1/5c6fb139ff86a17cecd87504312c9f2b to your computer and use it in GitHub Desktop.
Automatic Sorting in Rust with bit RZ and LZ instructions
#include <array>
#include <cstdint>
#include <cstddef>
#include <cmath>
#include <cstdio>
#include <climits>
#include <vector>
#include <cassert>
#include <random>
#include <chrono>
#include <iostream>
#include <limits>
#include <cassert>
#include <memory>
#include <bitset>
#include <optional>
// Finds the lowest set bit of `bits` whose position is >= `target`.
//
// bits:   64-bit word treated as a set of positions 0..63.
// target: lowest position that is acceptable.
// Returns the position of the first set bit at or above `target`, or
// std::nullopt when there is none (including any `target` above 63).
std::optional<uint64_t> get_eq_or_gt(uint64_t bits, uint64_t target) {
    // Shifting a 64-bit value by 64 or more is undefined behaviour, and no
    // position above 63 exists anyway.
    if (target > 63) {
        return std::nullopt;
    }
    // Drop every position below `target`; what remains are the candidates.
    const uint64_t shifted = bits >> target;
    if (!shifted) {
        return std::nullopt;
    }
    // The count of trailing zeros is the distance from `target` to the hit.
    return static_cast<uint64_t>(__builtin_ctzll(shifted)) + target;
}
// Finds the highest set bit of `bits` whose position is <= `target`.
//
// bits:   64-bit word treated as a set of positions 0..63.
// target: highest position that is acceptable; any value >= 63 means
//         "the whole word".
// Returns the position of the highest set bit at or below `target`, or
// std::nullopt when there is none.
std::optional<uint64_t> get_eq_or_lt(uint64_t bits, uint64_t target) {
    // (1 << (target + 1)) is only defined while target + 1 < 64, so every
    // target from 63 upwards uses the all-ones mask instead.
    const uint64_t mask = target >= 63
        ? ~uint64_t{0}
        : (uint64_t{1} << (target + 1)) - 1;
    const uint64_t clamped = bits & mask;
    if (!clamped) {
        return std::nullopt;
    }
    // The count of leading zeros is measured from bit 63 downwards.
    return static_cast<uint64_t>(63 - __builtin_clzll(clamped));
}
// Reports whether bit number `target` is set in the word array `bits`.
// Word index is target / 64, bit index inside the word is target % 64.
// No bounds check is made; the caller must supply enough words.
bool slc_test_bit(uint64_t* bits, size_t target) {
    const uint64_t word = bits[target / 64];
    return ((word >> (target % 64)) & uint64_t{1}) != 0;
}
// Sets bit number `target` in the word array `bits`.
// Word index is target / 64, bit index inside the word is target % 64.
// No bounds check is made; the caller must supply enough words.
void slc_set_bit(uint64_t* bits, size_t target) {
    uint64_t& word = bits[target / 64];
    word |= uint64_t{1} << (target % 64);
}
// Finds the lowest set bit at position >= `target` in a multi-word bitset.
//
// bits:   array of `size` 64-bit words; bit n lives in word n / 64.
// size:   number of words in `bits`.
// target: lowest acceptable bit position.
// Returns the absolute bit position, or std::nullopt when no bit at or above
// `target` is set (which includes every target beyond the last word).
std::optional<uint64_t> slc_get_eq_gt(uint64_t* bits, size_t size, uint64_t target) {
    const uint64_t place = target >> 6;
    // A target past the end has nothing at or above it; indexing the array
    // there would read out of bounds.
    if (place >= size) {
        return std::nullopt;
    }
    const uint64_t offset = target & 63;
    // The word holding `target` is searched from `offset` upwards only.
    if (const std::optional<uint64_t> hit = get_eq_or_gt(bits[place], offset)) {
        return *hit + (place << 6);
    }
    // Every later word is searched in full. The index is 64-bit so that it
    // cannot wrap for large arrays.
    for (uint64_t i = place + 1; i < size; ++i) {
        if (const std::optional<uint64_t> hit = get_eq_or_gt(bits[i], 0)) {
            return *hit + (i << 6);
        }
    }
    return std::nullopt;
}
// Finds the highest set bit at position <= `target` in a multi-word bitset.
//
// bits:   array of `size` 64-bit words; bit n lives in word n / 64.
// size:   number of words in `bits`.
// target: highest acceptable bit position. A target beyond the last word is
//         treated as "the very last bit of the array".
// Returns the absolute bit position, or std::nullopt when no bit at or below
// `target` is set (always the case for an empty array).
std::optional<uint64_t> slc_get_eq_lt(uint64_t* bits, size_t size, uint64_t target) {
    if (size == 0) {
        return std::nullopt;
    }
    uint64_t place = target >> 6;
    uint64_t offset = target & 63;
    // Clamp instead of reading past the end of the array.
    if (place >= size) {
        place = size - 1;
        offset = 63;
    }
    // The word holding `target` is searched from `offset` downwards only.
    if (const std::optional<uint64_t> hit = get_eq_or_lt(bits[place], offset)) {
        return *hit + (place << 6);
    }
    // Every earlier word is searched in full, highest word first. The
    // unsigned countdown avoids narrowing the index to int.
    for (uint64_t i = place; i-- > 0;) {
        if (const std::optional<uint64_t> hit = get_eq_or_lt(bits[i], 63)) {
            return *hit + (i << 6);
        }
    }
    return std::nullopt;
}
// One slot of a node: an optional value plus an untyped pointer.
// Both start out empty (no value, null pointer).
// NOTE(review): `ptr` is never read or written in the visible code; it
// presumably links to a deeper node - confirm before relying on it.
struct ValPtrNode {
    std::optional<size_t> val{};
    void* ptr{nullptr};
};
// A 256-way node: four 64-bit words (256 bits) next to 256 child slots.
// Everything is zero / empty on construction.
// NOTE(review): presumably bit n of `data` marks children[n] as occupied, as
// the commented-out insert below this struct suggests - confirm.
struct Byte8Node {
    uint64_t data[4]{};
    ValPtrNode children[256]{};
};
/*
// todo change to depth approach
void insert_into_tree_u32(Byte8Node* node, uint32_t key, uint32_t depth) {
const uint8_t keyParts[4] = {
key & 0xff, (key >> 8) & 0xff, (key >> 16) & 0xff, (key >> 24) & 0xff
};
if (!slc_test_bit(node->data, keyParts[0])) {
slc_set_bit(node->data, keyParts[0]);
node->children[keyParts[0]].val = std::make_optional<size_t>(key);
return;
}
}*/
// tests start here
// Test helper: checks that `lfs` holds a value equal to `rfs`.
// On failure a FAIL line tagged with `lineno` goes to stderr; on success
// nothing is printed. Execution always continues.
template<class T>
static void checkOptionalEq(const std::optional<T>& lfs, const T& rfs, unsigned lineno) {
    if (lfs) {
        // A value is present: only a mismatch is reported.
        if (*lfs != rfs) {
            fprintf(stderr, "FAIL NEQ line %u lfs %zu rfs %zu\n", lineno, (size_t)*lfs, (size_t)rfs);
        }
        return;
    }
    fprintf(stderr, "FAIL line %u optional has no value\n", lineno);
}
// Test helper: when `cond` is zero, prints the failed expression text and its
// line number to stderr. Prints nothing otherwise; never aborts.
static void checkCondition(int cond, const char* express, unsigned lineno) {
    if (cond) {
        return;
    }
    fprintf(stderr, "FAIL %s line %u\n", express, lineno);
}
// CHECKIT(cond): forwards `cond`, its source text, and the current line to
// checkCondition, which prints a FAIL line on stderr when `cond` is false.
#define CHECKIT(cond) checkCondition(cond, #cond, __LINE__)
// CHECK_OPT(lfs, rfs): forwards to checkOptionalEq<uint64_t> with the current
// line; `lfs` is the optional under test, `rfs` the expected value.
#define CHECK_OPT(lfs, rfs) checkOptionalEq<uint64_t>(lfs, rfs, __LINE__)
static void test_get_eq_or_gt() {
CHECKIT(get_eq_or_gt(0b1010100, 3).value() == 4);
CHECKIT(get_eq_or_gt(0b1011100, 3).value() == 3);
CHECKIT(!get_eq_or_gt(0b1010100, 63).has_value());
CHECKIT(get_eq_or_gt(uint64_t{1} << 63, 63).value() == 63);
CHECKIT(!get_eq_or_gt(uint64_t{1} << 60, 63).has_value());
CHECKIT(get_eq_or_gt(0b1011111, 0).value() == 0);
CHECKIT(get_eq_or_gt(0b1011111, 1).value() == 1);
}
static void test_get_eq_or_lt() {
CHECKIT(get_eq_or_lt(0b1010100, 3).value() == 2);
CHECKIT(get_eq_or_lt(uint64_t{1} << 63, 63).value() == 63);
CHECKIT(get_eq_or_lt(uint64_t{1} << 0, 63).value() == 0);
CHECKIT(get_eq_or_lt(uint64_t{1} << 1, 63).value() == 1);
CHECKIT(get_eq_or_lt(0b1011110, 3).value() == 3);
CHECKIT(get_eq_or_lt(0b1011111, 0).value() == 0);
CHECKIT(get_eq_or_lt(0b1011111, 1).value() == 1);
}
// Multi-word checks for slc_get_eq_gt over a 4-word (256-bit) set.
// Bits set in the fixture: 0 | 64, 66, 67 | 129..134, 137 | 192.
static void test_slc_get_eq_gt() {
uint64_t mySets[4] = {0b1, 0b1101, 0b1001111110, 0b1};
CHECK_OPT(slc_get_eq_gt(mySets, 4, 0), 0);
// Nothing above bit 0 in word 0, so the search moves on to word 1.
CHECK_OPT(slc_get_eq_gt(mySets, 4, 1), 64);
CHECK_OPT(slc_get_eq_gt(mySets, 4, 64), 64);
CHECK_OPT(slc_get_eq_gt(mySets, 4, 65), 66);
// Word 1 has nothing at or above bit 75, so the hit comes from word 2.
CHECK_OPT(slc_get_eq_gt(mySets, 4, 75), 129);
// 230 is in the last word, above its only set bit (192).
CHECKIT(!slc_get_eq_gt(mySets, 4, 230).has_value());
}
// Multi-word checks for slc_get_eq_lt over a 4-word (256-bit) set.
// Bits set in the fixture: 0 | 64, 66, 67 | 129..134, 137 | 192.
static void test_slc_get_eq_lt() {
uint64_t mySets[4] = {0b1, 0b1101, 0b1001111110, 0b1};
CHECK_OPT(slc_get_eq_lt(mySets, 4, 0), 0);
CHECK_OPT(slc_get_eq_lt(mySets, 4, 1), 0);
CHECK_OPT(slc_get_eq_lt(mySets, 4, 64), 64);
// Bit 65 is clear, so the nearest set bit below it is 64.
CHECK_OPT(slc_get_eq_lt(mySets, 4, 65), 64);
// 220 is in the last word; its only set bit at or below that is 192.
CHECK_OPT(slc_get_eq_lt(mySets, 4, 220), 192);
}
// Experiment: inserts the keys 0..1023 into a table of 256 groups, each made
// of four 64-bit words, and prints whether any of the key's bytes was already
// marked in its group.
//
// For every key:
//   * the key is split into its four bytes (lowest byte first);
//   * the top two bits of each byte are packed into an 8-bit group index;
//   * the low six bits of byte i select one bit in word i of that group.
static void test_grouped_insert() {
    uint64_t groups[256][4] = {};
    for (uint32_t key = 0; key < 1024; ++key) {
        printf("key is %u\n", key);

        uint64_t masks[4];
        uint8_t slot = 0;
        for (unsigned i = 0; i < 4; ++i) {
            const uint32_t part = (key >> (8 * i)) & 0xff;
            masks[i] = uint64_t{1} << (part & 63);
            slot |= static_cast<uint8_t>((part >> 6) << (2 * i));
        }

        uint64_t* spot = groups[slot];

        // Bit i of `collisions` records that word i already had the bit set.
        uint8_t collisions = 0;
        for (unsigned i = 0; i < 4; ++i) {
            if (spot[i] & masks[i]) {
                collisions |= static_cast<uint8_t>(1u << i);
            }
        }

        if (collisions == 0) {
            std::cout << "all good for " << key << "\n";
        } else {
            std::cout << std::bitset<8>(collisions) << " coll spot " << (size_t)slot << " \n";
        }

        // Mark the key's bits only after the collision check.
        for (unsigned i = 0; i < 4; ++i) {
            spot[i] |= masks[i];
        }
    }
}
// Runs every test routine in order. Failures are only printed by the checks
// themselves, so the exit status is always 0.
int main(int argc, char const *argv[])
{
    (void)argc;
    (void)argv;
    using TestFn = void (*)();
    const TestFn suite[] = {
        test_get_eq_or_gt,
        test_get_eq_or_lt,
        test_slc_get_eq_gt,
        test_slc_get_eq_lt,
        test_grouped_insert,
    };
    for (const TestFn run : suite) {
        run();
    }
    return 0;
}
/// Returns the position of the lowest set bit of `num` that is at or above
/// `target`, or `None` when no such bit exists.
///
/// Any `target` above 63 yields `None`, since a `u64` has no such position.
fn get_eq_or_gt(num: u64, target: u64) -> Option<u64> {
    if target >= 64 {
        return None;
    }
    // Discard every position below `target`; the lowest survivor is the hit.
    match num >> target {
        0 => None,
        rest => Some(u64::from(rest.trailing_zeros()) + target),
    }
}
/// Returns the position of the highest set bit of `num` that is at or below
/// `target`, or `None` when no such bit exists.
///
/// Any `target` of 63 or more covers the whole word.
fn get_eq_or_lt(num: u64, target: u64) -> Option<u64> {
    // The mask must include bit `target` itself, hence `target + 1`. That
    // shift is only valid below 64, so 63 and above use the all-ones mask.
    let mask = if target >= 63 {
        u64::MAX
    } else {
        (1u64 << (target + 1)) - 1
    };
    let clamped = num & mask;
    if clamped == 0 {
        None
    } else {
        // leading_zeros counts down from bit 63.
        Some(63 - u64::from(clamped.leading_zeros()))
    }
}
/// Sets bit number `target` in the word slice `bits`.
///
/// Bit `n` lives in word `n / 64` at position `n % 64`.
///
/// # Panics
/// Panics when `target` addresses a word beyond the end of `bits`.
fn slc_set(bits: &mut [u64], target: u64) {
    let word = &mut bits[(target / 64) as usize];
    *word |= 1u64 << (target % 64);
}
/// Reports whether bit number `target` is set in the word slice `bits`.
///
/// Bit `n` lives in word `n / 64` at position `n % 64`.
///
/// # Panics
/// Panics when `target` addresses a word beyond the end of `bits`.
fn slc_test(bits: &[u64], target: u64) -> bool {
    let word = bits[(target / 64) as usize];
    ((word >> (target % 64)) & 1) == 1
}
/// Returns the position of the lowest set bit at or above `target` in the
/// multi-word bitset `bits`, where bit `n` lives in word `n / 64`.
///
/// Returns `None` when no such bit is set. That includes every `target`
/// beyond the last word, which no longer panics.
fn slc_get_eq_gt(bits: &[u64], target: u64) -> Option<u64> {
    let place = target >> 6;
    // Nothing can be at or above a target that lies past the end.
    if place >= bits.len() as u64 {
        return None;
    }
    let start = place as usize;

    // The word holding `target` is searched from its bit offset upwards.
    if let Some(found) = get_eq_or_gt(bits[start], target & 63) {
        return Some(found + (place << 6));
    }

    // Every later word is searched in full; the first hit wins.
    bits.iter()
        .enumerate()
        .skip(start + 1)
        .find_map(|(index, &word)| {
            get_eq_or_gt(word, 0).map(|found| found + ((index as u64) << 6))
        })
}
/// Returns the position of the highest set bit at or below `target` in the
/// multi-word bitset `bits`, where bit `n` lives in word `n / 64`.
///
/// A `target` beyond the last word is treated as the last bit of the slice.
/// Returns `None` when no bit at or below `target` is set (always the case
/// for an empty slice).
fn slc_get_eq_lt(bits: &[u64], target: u64) -> Option<u64> {
    if bits.is_empty() {
        return None;
    }
    let last = bits.len() - 1;
    let place = target >> 6;
    // Clamp an out-of-range target rather than indexing past the end.
    let (start, offset) = if place > last as u64 {
        (last, 63)
    } else {
        (place as usize, target & 63)
    };

    // The word holding `target` is searched from its bit offset downwards.
    if let Some(found) = get_eq_or_lt(bits[start], offset) {
        return Some(found + ((start as u64) << 6));
    }

    // Only the words below `start` may hold a smaller position. They are
    // walked from the highest down so that the nearest hit wins.
    bits[..start]
        .iter()
        .enumerate()
        .rev()
        .find_map(|(index, &word)| {
            get_eq_or_lt(word, 63).map(|found| found + ((index as u64) << 6))
        })
}
// todo range / query function
/*fn slc_get_rng<'a>(bits:&'a[u64], lowest:u64, highest:u64) -> &'a[u64] {
// return bits;
}*/
/*
#[derive(Debug)]
struct Val8Node32 {
deeper:Option<Box<Val8Node32>>,
val:Option<u32>
}
// for init
const INIT_VAL8NODE32: Val8Node32 = Val8Node32 {deeper:None, val:None};
impl Val8Node32 {
fn insert(&mut self, num:u32, shift:u32) {
match self.val {
Some(v) => {
if v == num {
return;
}
self.deeper = Some(Box::new(INIT_VAL8NODE32));
},
None => {}
}
}
}*/
/// A 256-way node keyed by one byte of a `u32`.
#[derive(Debug, Clone)]
struct Num8Node32 {
    /// 256-bit bitmap (4 x 64); `insert` sets the bit for each byte slot it fills.
    bits: [u64; 4],
    /// One optional boxed child per possible byte value.
    nodes: [Option<Box<Num8Node32>>; 256],
    /// Value carried by this node, if any.
    val: Option<u32>,
    /// Depth recorded at construction.
    /// NOTE(review): presumably the bit shift used to pick this node's key
    /// byte - confirm.
    depth: u32,
}
impl Num8Node32 {
    /// Creates a node with no children that carries `value` and records
    /// `depth_amnt` as its depth.
    fn new(value: u32, depth_amnt: u32) -> Self {
        Num8Node32 {
            bits: [0; 4],
            nodes: std::array::from_fn(|_| None),
            val: Some(value),
            depth: depth_amnt,
        }
    }

    /// Inserts `key` below this node, using the byte of `key` found at bit
    /// shift `depth` (0, 8, 16 or 24) to choose the child slot.
    ///
    /// * Empty slot: a leaf child holding `key` is created there.
    /// * Slot holds a leaf with the same key: nothing changes.
    /// * Slot holds a leaf with a different key: the leaf is split. Its value
    ///   and `key` are both re-inserted one byte deeper (`depth + 8`).
    /// * Slot holds an inner node (no value): the insert continues in it.
    ///
    /// A `depth` of 32 or more addresses no byte of a `u32`, so the call does
    /// nothing. This node's own `val` is never modified.
    ///
    /// NOTE(review): the previous collision branch built two nodes, dropped
    /// them, and cleared `self.val`, so the colliding key was lost. The
    /// split-on-collision behaviour here is the presumed intent - confirm.
    fn insert(&mut self, key: u32, depth: u32) {
        // checked_shr is None once the shift reaches the width of the key.
        let slot = match key.checked_shr(depth) {
            Some(shifted) => u64::from(shifted & 0xff),
            None => return,
        };

        if !slc_test(&self.bits, slot) {
            slc_set(&mut self.bits, slot);
            self.nodes[slot as usize] = Some(Box::new(Num8Node32::new(key, depth)));
            return;
        }

        let child = match self.nodes[slot as usize].as_deref_mut() {
            Some(child) => child,
            // A set bit without a child means the bitmap and the slots have
            // diverged; there is nothing sensible to descend into.
            None => return,
        };

        let next_depth = depth + 8;
        match child.val {
            Some(existing) if existing == key => {}
            Some(existing) => {
                // Two different keys that agree on every remaining byte
                // cannot be told apart; keep the stored one rather than
                // losing both.
                if next_depth >= u32::BITS {
                    return;
                }
                child.val = None;
                child.insert(existing, next_depth);
                child.insert(key, next_depth);
            }
            None => child.insert(key, next_depth),
        }
    }
}
/// Demo driver: prints the results of the bit-search helpers for a handful
/// of hand-picked inputs, one result per line.
fn main() {
    // Zero counts on a small literal.
    let foo: u32 = 0b001100;
    println!("{}", foo.leading_zeros());
    println!("{}", foo.trailing_zeros());

    let top_bit = 1u64 << 63;

    // Single-word searches.
    println!("{:?}", get_eq_or_gt(0b1010100, 3));
    println!("{:?}", get_eq_or_gt(0b1010100, 63));
    println!("{:?}", get_eq_or_gt(top_bit, 63));
    println!("{:?}", get_eq_or_gt(top_bit, 633));
    println!("{:?}", get_eq_or_lt(0b1010100, 3));
    println!("{:?}", get_eq_or_lt(top_bit, 63));
    println!("{:?}", get_eq_or_lt(0b1010101, 1));
    println!("{:?}", get_eq_or_gt(0b10001010, 56));

    // Multi-word searches over a four-word set.
    let f: [u64; 4] = [0b10100100, 0b10001010, 0b10001111, 0b101111];
    println!("{:?}", slc_get_eq_gt(&f, 120));
    println!("{:?}", slc_get_eq_lt(&f, 176));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment