Skip to content

Instantly share code, notes, and snippets.

@rrbutani
Created November 29, 2022 11:03
Show Gist options
  • Save rrbutani/d484f67468c587eaf5c776467e9394aa to your computer and use it in GitHub Desktop.
Save rrbutani/d484f67468c587eaf5c776467e9394aa to your computer and use it in GitHub Desktop.
//! https://twitter.com/joseph_h_garvin/status/1597272949098438656
//! https://users.rust-lang.org/t/shrinking-bitset-with-compile-time-known-length/84244
//! https://rust.godbolt.org/z/9sevMvhsa
//! https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=5bb14fbfa4d086945e10107baf111679
#![allow(type_alias_bounds)]
// #![recursion_limit = "20"]
use core::{
any::type_name,
fmt::{self, Debug},
iter::FusedIterator,
marker::PhantomData as P,
mem::size_of,
ops::{Add, Range, Sub},
};
use num_traits::PrimInt;
use typenum::{
consts::{U0, U128, U16, U32, U64, U8},
Diff, False, Gr, IsGreater, IsLessOrEqual, LeEq, Sum, True, Unsigned,
};
/// A type-level function: an implementor exposes its "return value" as `Result`.
#[doc(hidden)]
pub trait Compute {
    type Result: ?Sized;
}

/// Type-level if/else: `Ternary<Flag, Then, Otherwise>` computes `Then` when
/// `Flag` is `typenum::True` and `Otherwise` when it is `typenum::False`.
///
/// No value of this type is constructed in this file; `Then` is wrapped in its
/// own `PhantomData` so that both branches are allowed to be unsized.
#[doc(hidden)]
pub struct Ternary<Flag, Then: ?Sized, Otherwise: ?Sized>(P<(Flag, P<Then>, Otherwise)>);

impl<Then: ?Sized, Otherwise: ?Sized> Compute for Ternary<False, Then, Otherwise> {
    type Result = Otherwise;
}

impl<Then: ?Sized, Otherwise: ?Sized> Compute for Ternary<True, Then, Otherwise> {
    type Result = Then;
}

/// Shorthand for the branch selected by `Ternary<Flag, Then, Otherwise>`.
type Cond<Flag, Then, Otherwise> = <Ternary<Flag, Then, Otherwise> as Compute>::Result;
// NOTE(review): this private trait is not referenced anywhere else in this
// file; it looks like a leftover from an earlier design — confirm before removing.
trait GetTypenumType {
/// An associated `typenum` unsigned type (bounded by `Unsigned`).
type Result: Unsigned;
}
/// Maps a type-level bit length to a primitive integer type wide enough to
/// hold that many bits.
#[doc(hidden)]
pub trait GetStorage {
type Storage: PrimInt;
}
/// Carrier type: `TypenumTyToStorage<N>` implements `GetStorage` for the
/// `typenum` length `N`.
#[doc(hidden)]
pub struct TypenumTyToStorage<N: Unsigned>(P<N>);
/// Shorthand for the storage integer selected for an `N`-bit chunk.
type StorageForLen<N: Unsigned> = <TypenumTyToStorage<N> as GetStorage>::Storage;
// Selection cascade, read bottom-up: `u8` if N <= 8, else `u16` if N <= 16,
// else `u32` if N <= 32, else `u64` if N <= 64, else `u128`.
type CmpU64<N> = Cond<LeEq<N, U64>, u64, u128>;
type CmpU32<N> = Cond<LeEq<N, U32>, u32, CmpU64<N>>;
type CmpU16<N> = Cond<LeEq<N, U16>, u16, CmpU32<N>>;
type CmpU8<N> = Cond<LeEq<N, U8>, u8, CmpU16<N>>;
// The where-clause restates every intermediate comparison and `Ternary` step of
// the cascade above so that `CmpU8<N>` can be resolved for a generic `N`.
impl<N: Unsigned> GetStorage for TypenumTyToStorage<N>
where
// N: IsLessOrEqual<U128, Output = True>, // it's up to the user of this trait to enforce this; we return `u128` if N > 64
// Zero-length storage is rejected outright.
N: IsGreater<U0, Output = True>,
N: IsLessOrEqual<U64>,
Ternary<LeEq<N, U64>, u64, u128>: Compute,
CmpU64<N>: PrimInt,
N: IsLessOrEqual<U32>,
Ternary<LeEq<N, U32>, u32, CmpU64<N>>: Compute,
CmpU32<N>: PrimInt,
N: IsLessOrEqual<U16>,
Ternary<LeEq<N, U16>, u16, CmpU32<N>>: Compute,
CmpU16<N>: PrimInt,
N: IsLessOrEqual<U8>,
Ternary<LeEq<N, U8>, u8, CmpU16<N>>: Compute,
CmpU8<N>: PrimInt,
{
type Storage = CmpU8<N>;
}
/// Bit-level read/write access shared by every storage layer of a `BitSet`.
///
/// `try_get`/`try_set` report failure with `Err(())`; `get`/`set` are
/// panicking conveniences built on top of them.
pub trait BitSetStorageAccess {
    /// Reads bit `bit`, or returns `Err(())` if this storage cannot access it.
    fn try_get(&self, bit: usize) -> Result<bool, ()>; // TODO: error type

    /// Writes `val` to bit `bit`, or returns `Err(())` if this storage cannot
    /// write it.
    fn try_set(&mut self, bit: usize, val: bool) -> Result<(), ()>; // TODO: error type

    /// Reads bit `bit`.
    ///
    /// # Panics
    /// Panics if [`try_get`](Self::try_get) fails (e.g. `bit` is out of bounds).
    #[inline(always)]
    fn get(&self, bit: usize) -> bool {
        // Name the offending bit instead of the opaque "unwrap on Err(())".
        match self.try_get(bit) {
            Ok(val) => val,
            Err(()) => panic!("attempted to get bit {bit}, which this storage cannot access"),
        }
    }

    /// Writes `val` to bit `bit`.
    ///
    /// # Panics
    /// Panics if [`try_set`](Self::try_set) fails (e.g. `bit` is out of bounds,
    /// or the storage is read-only).
    #[inline(always)]
    fn set(&mut self, bit: usize, val: bool) {
        if let Err(()) = self.try_set(bit, val) {
            panic!("attempted to set bit {bit}, which this storage cannot write");
        }
    }
}
/// Read-only view: a shared reference forwards reads to the referent, while
/// writes are rejected (`try_set` returns `Err(())`, `set` is unimplemented).
impl<S: BitSetStorageAccess> BitSetStorageAccess for &'_ S {
    #[inline(always)]
    fn try_get(&self, bit: usize) -> Result<bool, ()> {
        let target: &S = *self;
        target.try_get(bit)
    }

    #[inline(always)]
    fn try_set(&mut self, _bit: usize, _val: bool) -> Result<(), ()> {
        // yuck, TODO: split trait into read/write?
        Err(())
    }

    #[inline(always)]
    fn get(&self, bit: usize) -> bool {
        let target: &S = *self;
        target.get(bit)
    }

    #[inline(always)]
    fn set(&mut self, _bit: usize, _val: bool) {
        unimplemented!()
    }
}
/// Zero-sized placeholder storage carried by `Sentinel` nodes.
#[doc(hidden)]
#[derive(Debug, Default, Clone, Copy)]
pub struct Empty;
/// Terminal node of a storage chain: it holds no bits (`Empty` storage, width
/// `U0`) and sits at offset `Len`, i.e. just past the last real bit of the set.
#[doc(hidden)]
pub type Sentinel<Len: Unsigned> = BitSetStorageNode<Empty, Len, U0, ()>;
/// `()` is the tail stored below a `Sentinel` (and a placeholder in unselected
/// type-level branches). The sentinel's methods never consult its tail, so
/// none of these are expected to run.
impl BitSetStorageAccess for () {
    fn get(&self, _bit: usize) -> bool {
        unreachable!()
    }

    fn set(&mut self, _bit: usize, _val: bool) {
        unreachable!()
    }

    fn try_get(&self, _bit: usize) -> Result<bool, ()> {
        unreachable!()
    }

    fn try_set(&mut self, _bit: usize, _val: bool) -> Result<(), ()> {
        unreachable!()
    }
}
/// End of a storage chain: any index that reaches the sentinel lies outside
/// the bitset, whose total length is the sentinel's offset `L`.
impl<L: Unsigned> BitSetStorageAccess for Sentinel<L> {
    #[inline(always)]
    fn try_get(&self, _bit: usize) -> Result<bool, ()> {
        Err(())
    }

    #[inline(always)]
    fn try_set(&mut self, _bit: usize, _val: bool) -> Result<(), ()> {
        Err(())
    }

    #[inline(always)]
    fn get(&self, bit: usize) -> bool {
        let len = L::USIZE;
        panic!("out of bounds: attempted to get index {bit} in a {len} element bitset")
    }

    #[inline(always)]
    fn set(&mut self, bit: usize, _val: bool) {
        let len = L::USIZE;
        panic!("out of bounds: attempted to set index {bit} in a {len} element bitset")
    }
}
/// One link in the storage chain behind a `BitSet`.
///
/// The node owns the global bit indices `WidthOffset .. WidthOffset + Len`,
/// stored in the low `Len` bits of `Storage`; every other index is delegated
/// to `Rest`.
#[doc(hidden)]
pub struct BitSetStorageNode<
Storage,
WidthOffset: Unsigned,
Len: Unsigned,
Rest: BitSetStorageAccess,
> {
// Backing integer: global bit `i` in this node's range is bit `i - WidthOffset` here.
inner: Storage,
// Storage for the indices beyond this node's range.
rest: Rest,
// Type-level offset and width; carries no data.
_width: P<(WidthOffset, Len)>,
}
/// Shows the storage type, the covered index range, and the rest of the chain.
/// The stored bit values themselves are not printed.
impl<S, O, L, Rest> Debug for BitSetStorageNode<S, O, L, Rest>
where
    O: Unsigned,
    L: Unsigned,
    Rest: BitSetStorageAccess + Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("BitSetStorageNode");
        builder.field("storage_ty", &type_name::<S>());
        builder.field("range", &Self::range());
        builder.field("rest", &self.rest);
        builder.finish()
    }
}
/// Builds a node from a default backing value and a default tail. For integer
/// storage this means every bit starts cleared.
impl<S, O, L, Rest> Default for BitSetStorageNode<S, O, L, Rest>
where
    S: Default,
    O: Unsigned,
    L: Unsigned,
    Rest: BitSetStorageAccess + Default,
{
    fn default() -> Self {
        let inner = S::default();
        let rest = Rest::default();
        Self {
            inner,
            rest,
            _width: P,
        }
    }
}
impl<S, O: Unsigned, L: Unsigned, Tail: BitSetStorageAccess> BitSetStorageNode<S, O, L, Tail> {
    /// Global bit indices owned by this node: `O .. O + L`.
    ///
    /// In debug builds, asserts that `S` is at least `L` bits wide.
    #[inline(always)]
    const fn range() -> Range<usize> {
        // No custom message: presumably because formatted panic messages are
        // not usable in this `const fn` — TODO confirm.
        debug_assert!((size_of::<S>() * 8) >= L::USIZE);
        let start = O::USIZE;
        let end = start + L::USIZE;
        start..end
    }

    /// Converts a global index into an index within this node's storage, or
    /// returns `None` if `bit` belongs to a different node.
    #[inline(always)]
    fn relative(bit: usize) -> Option<usize> {
        let owned = Self::range();
        owned.contains(&bit).then(|| bit - owned.start)
    }
}
impl<Storage, Offs, Len, Rest> BitSetStorageNode<Storage, Offs, Len, Rest>
where
    Storage: PrimInt,
    Offs: Unsigned,
    Len: Unsigned,
    Rest: BitSetStorageAccess,
{
    /// Reads bit `idx` of this node's own integer (`idx` is node-relative).
    #[inline(always)]
    fn read_local(&self, idx: usize) -> bool {
        (self.inner & (Storage::one() << idx)) != Storage::zero()
    }

    /// Sets (`val == true`) or clears (`val == false`) bit `idx` of this node's
    /// own integer (`idx` is node-relative).
    #[inline(always)]
    fn write_local(&mut self, idx: usize, val: bool) {
        let mask = Storage::one() << idx;
        // The bit must be cleared explicitly: OR-ing alone can never turn a 1 back into a 0.
        self.inner = if val { self.inner | mask } else { self.inner & !mask };
    }
}

/// A node serves the indices in its own range and forwards everything else to
/// `rest`.
impl<Storage, Offs, Len, Rest> BitSetStorageAccess for BitSetStorageNode<Storage, Offs, Len, Rest>
where
    Storage: PrimInt,
    Offs: Unsigned,
    Len: Unsigned,
    Rest: BitSetStorageAccess,
{
    #[inline(always)]
    fn try_get(&self, bit: usize) -> Result<bool, ()> {
        match Self::relative(bit) {
            Some(idx) => Ok(self.read_local(idx)),
            None => self.rest.try_get(bit),
        }
    }

    #[inline(always)]
    fn try_set(&mut self, bit: usize, val: bool) -> Result<(), ()> {
        match Self::relative(bit) {
            Some(idx) => {
                self.write_local(idx, val);
                Ok(())
            }
            None => self.rest.try_set(bit, val),
        }
    }

    /// Overridden so that an index outside this node reaches `rest.get`. A
    /// terminal node can then raise its own panic, instead of the generic
    /// failure of the `try_get`-based default.
    #[inline(always)]
    fn get(&self, bit: usize) -> bool {
        match Self::relative(bit) {
            Some(idx) => self.read_local(idx),
            None => self.rest.get(bit),
        }
    }

    /// Overridden for the same reason as `get`: out-of-range writes are
    /// delegated to `rest.set`.
    #[inline(always)]
    fn set(&mut self, bit: usize, val: bool) {
        match Self::relative(bit) {
            Some(idx) => self.write_local(idx, val),
            None => self.rest.set(bit, val),
        }
    }
}
/// Type-level function that yields the head node (`Top`) of a storage chain.
#[doc(hidden)]
pub trait GetStorageNodes {
type Top: BitSetStorageAccess + Default;
}
// We prefer having fewer storage nodes over using as little space as possible
// here; e.g. we represent 80 bit bitsets as 1 `u128` instead of as a `u64` and
// a `u16`.
/// Carrier type: the head storage node for `Len` bits, starting at global bit
/// index `Offset`.
#[doc(hidden)]
pub struct LenToRootStorageNodeFewestNodes<Len: Unsigned, Offset: Unsigned = U0>((Len, Offset));
/// Head storage node for `L` bits starting at global index `Offs`.
type FewestNodes<L: Unsigned, Offs = U0> =
<LenToRootStorageNodeFewestNodes<L, Offs> as GetStorageNodes>::Top;
/// Helper dispatched on `GreaterThan128` (`typenum::True`/`False`): it builds the
/// chain that follows a leading 128-bit node.
#[doc(hidden)]
pub struct LenToRootStorageNodeFewestNodesRecurse<Len: Unsigned, Offset: Unsigned, GreaterThan128>(
(Len, Offset, GreaterThan128),
);
// `Len <= 128`: `()` is only a placeholder here. In that case the final node
// (not the 128-bit-node-plus-recursion branch) is the one that gets selected.
impl<L: Unsigned, O: Unsigned> GetStorageNodes
for LenToRootStorageNodeFewestNodesRecurse<L, O, False>
{
type Top = ();
}
// `Len > 128`: the remaining `L - 128` bits are laid out as their own chain,
// starting at global index `O + 128`.
impl<L: Unsigned, O: Unsigned> GetStorageNodes
for LenToRootStorageNodeFewestNodesRecurse<L, O, True>
where
O: Add<U128>,
Sum<O, U128>: Unsigned,
L: IsGreater<U128, Output = True>,
L: Sub<U128>,
Diff<L, U128>: Unsigned,
LenToRootStorageNodeFewestNodes<Diff<L, U128>, Sum<O, U128>>: GetStorageNodes,
{
type Top = FewestNodes<Diff<L, U128>, Sum<O, U128>>;
}
// if <= 128 bits, this is the final storage node: the smallest integer that fits
// `L` bits, followed by a `Sentinel` at offset `O + L`.
type LastNode<O, L> = BitSetStorageNode<StorageForLen<L>, O, L, Sentinel<Sum<O, L>>>;
// The chain that follows a leading 128-bit node. For L > 128 this is the chain for
// the remaining `L - 128` bits at offset `O + 128`; for L <= 128 it is the unused
// `()` placeholder.
type RecurseNode<O, L> =
<LenToRootStorageNodeFewestNodesRecurse<L, O, Gr<L, U128>> as GetStorageNodes>::Top;
// Head node for `L` bits at offset `O`. If L > 128, a `u128` node covering
// `O .. O + 128` followed by `RecurseNode`; otherwise a single `LastNode`.
type StorageTop<O, L> =
Cond<Gr<L, U128>, BitSetStorageNode<u128, O, U128, RecurseNode<O, L>>, LastNode<O, L>>;
// Both branches of the `Cond` in `StorageTop` appear in these bounds, so both
// must be well-formed for any `L`, even though only one of them is selected.
impl<L: Unsigned, O: Unsigned> GetStorageNodes for LenToRootStorageNodeFewestNodes<L, O>
where
L: IsGreater<U128>,
O: Add<L>,
Sum<O, L>: Unsigned,
LenToRootStorageNodeFewestNodesRecurse<L, O, Gr<L, U128>>: GetStorageNodes,
TypenumTyToStorage<L>: GetStorage,
LastNode<O, L>: BitSetStorageAccess,
Ternary<Gr<L, U128>, BitSetStorageNode<u128, O, U128, RecurseNode<O, L>>, LastNode<O, L>>:
Compute,
StorageTop<O, L>: BitSetStorageAccess + Sized + Default,
{
type Top = StorageTop<O, L>;
}
/// A fixed-length bitset whose length `Len` is a `typenum` type-level integer.
///
/// By default the bits are stored in a chain of primitive-integer nodes chosen
/// by `FewestNodes` (one `u128` node per full 128-bit chunk, then one node for
/// the remainder).
pub struct BitSet<Len: Unsigned, Storage: BitSetStorageAccess = FewestNodes<Len, U0>> {
// Head of the storage chain; every bit access goes through it.
inner: Storage,
// Ties the type-level length to the value; carries no data.
_len: P<Len>,
}
impl<L: Unsigned, S: BitSetStorageAccess + Default> Default for BitSet<L, S> {
fn default() -> Self {
Self {
inner: Default::default(),
_len: Default::default(),
}
}
}
/// Every access is forwarded to the head of the storage chain.
///
/// `get`/`set` are forwarded explicitly, instead of relying on the trait's
/// `try_*`-based defaults, so that the storage's own `get`/`set` (and any more
/// specific panic behavior they implement) are the ones that run.
impl<L: Unsigned, S: BitSetStorageAccess> BitSetStorageAccess for BitSet<L, S> {
    #[inline(always)]
    fn try_get(&self, bit: usize) -> Result<bool, ()> {
        self.inner.try_get(bit)
    }

    #[inline(always)]
    fn try_set(&mut self, bit: usize, val: bool) -> Result<(), ()> {
        self.inner.try_set(bit, val)
    }

    #[inline(always)]
    fn get(&self, bit: usize) -> bool {
        self.inner.get(bit)
    }

    #[inline(always)]
    fn set(&mut self, bit: usize, val: bool) {
        self.inner.set(bit, val)
    }
}
impl BitSet<U0, ()> {
pub fn new<L: Unsigned>() -> BitSet<L, FewestNodes<L>>
where
LenToRootStorageNodeFewestNodes<L, U0>: GetStorageNodes,
{
BitSet {
inner: Default::default(),
_len: Default::default(),
}
}
}
/// Iterator over the bits of a bitset, yielding each bit as a `bool`. It runs
/// from index 0 upward via `next`, or from the end downward via `next_back`.
pub struct BitSetIterator<L: Unsigned, S: BitSetStorageAccess> {
// Storage the bits are read from.
inner: S,
// Next index `next` will yield (inclusive front cursor).
curr_idx: usize,
// One past the next index `next_back` will yield (exclusive back cursor).
end: usize,
// Type-level length of the originating bitset; carries no data.
_len: P<L>,
}
/// Front-to-back traversal of the indices `curr_idx .. end`.
impl<L: Unsigned, S: BitSetStorageAccess> Iterator for BitSetIterator<L, S> {
    type Item = bool;

    fn next(&mut self) -> Option<Self::Item> {
        // Exhausted once the front cursor meets the back cursor.
        if self.curr_idx == self.end {
            return None;
        }
        let bit = self.inner.get(self.curr_idx);
        self.curr_idx += 1;
        Some(bit)
    }

    /// Exact: both bounds equal the number of indices not yet yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.end - self.curr_idx;
        (left, Some(left))
    }
}
/// Back-to-front traversal: pulls `end` down toward `curr_idx`.
impl<L: Unsigned, S: BitSetStorageAccess> DoubleEndedIterator for BitSetIterator<L, S> {
    fn next_back(&mut self) -> Option<Self::Item> {
        let has_more = self.curr_idx != self.end;
        has_more.then(|| {
            self.end -= 1;
            self.inner.get(self.end)
        })
    }
}
// `TrustedLen` is an unstable (nightly-only) unsafe trait, presumably why this stays disabled:
// unsafe impl<L: Unsigned, S: BitSetStorageAccess> TrustedLen for BitSetIterator<L, S> { }
// `size_hint` reports `end - curr_idx` as both bounds, so the default `len` is exact.
impl<L: Unsigned, S: BitSetStorageAccess> ExactSizeIterator for BitSetIterator<L, S> {}
// Once `curr_idx == end`, `next` and `next_back` keep returning `None` without mutating state.
impl<L: Unsigned, S: BitSetStorageAccess> FusedIterator for BitSetIterator<L, S> {}
impl<L: Unsigned, S: BitSetStorageAccess> BitSet<L, S> {
pub fn iter(&self) -> impl Iterator<Item = bool> + '_ {
BitSetIterator::<L, &'_ Self> {
inner: self,
curr_idx: 0,
end: L::USIZE,
_len: P,
}
}
}
/// Consumes the bitset, yielding its `L` bits from index 0 upward.
impl<L: Unsigned, S: BitSetStorageAccess> IntoIterator for BitSet<L, S> {
    type Item = bool;
    type IntoIter = BitSetIterator<L, S>;

    fn into_iter(self) -> Self::IntoIter {
        let Self { inner, .. } = self;
        BitSetIterator {
            inner,
            curr_idx: 0,
            end: L::USIZE,
            _len: P,
        }
    }
}
impl<L: Unsigned, S: BitSetStorageAccess> FromIterator<bool> for BitSet<L, S>
where
    Self: Default,
{
    /// Builds a bitset from exactly `L` booleans; item `i` becomes bit `i`.
    ///
    /// # Panics
    /// Panics if the iterator yields fewer or more than `L` items.
    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
        let len = L::USIZE;
        let mut iter = iter.into_iter();
        let mut out = Self::default();
        for i in 0..len {
            let bit = iter.next().unwrap_or_else(|| {
                panic!("iterator yielded only {i} items, but the bitset has {len} bits")
            });
            out.set(i, bit);
        }
        assert!(
            iter.next().is_none(),
            "iterator yielded more than {len} items, but the bitset has {len} bits"
        );
        out
    }
}
// trait ComputeConst<T> { const RES: T; }
/*
struct BoolToType<const B: bool>;
impl Compute for BoolToType<true> { type Result = True; }
impl Compute for BoolToType<false> { type Result = False; }
struct IsBitSet<const N: usize, const BIT: u32>;
impl<const N: usize, const BIT: u32> ComputeConst<bool> for IsBitSet<N, BIT> {
const RES: bool = {
(N & (1 << (BIT as usize))) != 0
};
}
type IsBitSetTy<const N: usize, const BIT: u32> = <
BoolToType<
{ <IsBitSet<{N}, {BIT}> as ComputeConst<bool>>::RES }
> as Compute
>::Result;
*/
/* struct Smuggle<const N: usize>;
impl<const N: usize> ComputeConst<usize> for Smuggle<N> {
const RES: usize = N;
}
struct BoolToType2<const B: bool>;
impl Compute for BoolToType2<true> { type Result = True; }
impl Compute for BoolToType2<false> { type Result = False; }
struct BoolToType<B: ComputeConst<bool>>;
impl<B: ComputeConst<bool>> Compute for BoolToType<B> { type Result = BoolToType2<{B::RES}>; }
// impl Compute for BoolToType<false> { type Result = False; }
struct IsBitSet<N: ComputeConst<usize>, B: ComputeConst<usize>>(P<(N, B)>);
impl<N: ComputeConst<usize>, B: ComputeConst<usize>> ComputeConst<bool> for IsBitSet<N, B> {
const RES: bool = {
(N::RES & (1 << (B::RES))) != 0
};
}
type IsBitSetTy<N: ComputeConst<usize>, B: ComputeConst<usize>> = <
BoolToType<
{ <IsBitSet<N, B> as ComputeConst<bool>>::RES }
> as Compute
>::Result;
*/
// struct IntToType<const N: usize>;
// // // this is the naïve implementation:
// // impl<const N: usize> IntToType<N>
// // where
// // IntToType<{ N - 1 }>: Sized,
// // {
// // }
/// Prints the fully-qualified name of the given type to stdout.
macro_rules! print_ty {
    ($t:ty) => {
        println!("{}", ::core::any::type_name::<$t>())
    };
}
/// Prints to stderr the size in bytes and the `Debug` layout of the default
/// fewest-nodes storage chain for an `N`-bit bitset.
fn print_storage_nodes<N: Unsigned>()
where
    LenToRootStorageNodeFewestNodes<N, U0>: GetStorageNodes,
    FewestNodes<N>: Debug,
{
    let bytes = size_of::<FewestNodes<N>>();
    eprintln!("\nSize (bytes): {bytes}");
    let chain = FewestNodes::<N>::default();
    eprintln!("{chain:#?}")
}
fn main() {
print_ty!(FewestNodes<typenum::U1024>);
print_storage_nodes::<typenum::U7>();
print_storage_nodes::<typenum::U15>();
print_storage_nodes::<typenum::U31>();
print_storage_nodes::<typenum::U63>();
print_storage_nodes::<typenum::U64>();
print_storage_nodes::<typenum::U80>();
print_storage_nodes::<typenum::U1024>();
{
let mut b = BitSet::new::<typenum::U15>();
assert_eq!(b.try_get(15), Err(()));
assert_eq!(b.try_get(14), Ok(false));
assert_eq!(b.try_set(14, true), Ok(()));
assert_eq!(b.try_get(14), Ok(true));
}
{
let num: u64 = 0xDEAD_BEEF_0123_4567;
let b: BitSet<typenum::U64> = (0..64).map(|i| (num & (1 << i)) != 0).collect();
assert_eq!(b.inner.inner, num);
let out = b
.into_iter()
.rev()
.fold(0, |acc, b| (acc << 1) | (b as u64));
assert_eq!(out, num);
}
}
/// Round-trips `num` through a 64-bit `BitSet`.
///
/// The bits of `num` are collected LSB-first, the raw storage word is checked,
/// and the value is rebuilt by folding the bits back in from the top.
///
/// # Panics
/// Panics if the storage word or the rebuilt value differs from `num`.
pub fn roundtrip(num: u64) {
    let bits: BitSet<typenum::U64> = (0..64).map(|i| ((num >> i) & 1) == 1).collect();
    assert_eq!(bits.inner.inner, num);
    let rebuilt = bits
        .into_iter()
        .rev()
        .fold(0u64, |acc, bit| (acc << 1) | u64::from(bit));
    assert_eq!(rebuilt, num);
}
pub fn accessor(set: &BitSet::<U64>) -> bool {
set.get(34)
}
pub fn accessor_panic(set: &BitSet::<typenum::U31>) -> bool {
set.get(34)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment