Created
November 29, 2022 11:03
-
-
Save rrbutani/d484f67468c587eaf5c776467e9394aa to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//! https://twitter.com/joseph_h_garvin/status/1597272949098438656
//! https://users.rust-lang.org/t/shrinking-bitset-with-compile-time-known-length/84244
//! https://rust.godbolt.org/z/9sevMvhsa
//! https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=5bb14fbfa4d086945e10107baf111679
#![allow(type_alias_bounds)] | |
// #![recursion_limit = "20"] | |
use core::{ | |
any::type_name, | |
fmt::{self, Debug}, | |
iter::FusedIterator, | |
marker::PhantomData as P, | |
mem::size_of, | |
ops::{Add, Range, Sub}, | |
}; | |
use num_traits::PrimInt; | |
use typenum::{ | |
consts::{U0, U128, U16, U32, U64, U8}, | |
Diff, False, Gr, IsGreater, IsLessOrEqual, LeEq, Sum, True, Unsigned, | |
}; | |
/// A type-level function: each implementor stands for one computation and
/// names its answer through the associated type.
#[doc(hidden)]
pub trait Compute {
    /// The type the computation evaluates to; it is allowed to be unsized.
    type Result: ?Sized;
}
/// Type-level conditional: bundles a condition type with the two candidate
/// result types. The selection itself is done by the `Compute` impls for
/// `Ternary<True, ..>` and `Ternary<False, ..>`.
#[doc(hidden)]
pub struct Ternary<Cond, IfTrue: ?Sized, Else: ?Sized>(
    // Marker only: no value of any parameter type is stored.
    P<(Cond, P<IfTrue>, Else)>,
);
impl<T: ?Sized, F: ?Sized> Compute for Ternary<True, T, F> { | |
type Result = T; | |
} | |
impl<T: ?Sized, F: ?Sized> Compute for Ternary<False, T, F> { | |
type Result = F; | |
} | |
type Cond<C, T, F> = <Ternary<C, T, F> as Compute>::Result; | |
trait GetTypenumType { | |
type Result: Unsigned; | |
} | |
#[doc(hidden)] | |
pub trait GetStorage { | |
type Storage: PrimInt; | |
} | |
#[doc(hidden)] | |
pub struct TypenumTyToStorage<N: Unsigned>(P<N>); | |
type StorageForLen<N: Unsigned> = <TypenumTyToStorage<N> as GetStorage>::Storage; | |
type CmpU64<N> = Cond<LeEq<N, U64>, u64, u128>; | |
type CmpU32<N> = Cond<LeEq<N, U32>, u32, CmpU64<N>>; | |
type CmpU16<N> = Cond<LeEq<N, U16>, u16, CmpU32<N>>; | |
type CmpU8<N> = Cond<LeEq<N, U8>, u8, CmpU16<N>>; | |
impl<N: Unsigned> GetStorage for TypenumTyToStorage<N> | |
where | |
// N: IsLessOrEqual<U128, Output = True>, // it's up to the user of this trait to enforce this; we return 128 if N > 64 | |
N: IsGreater<U0, Output = True>, | |
N: IsLessOrEqual<U64>, | |
Ternary<LeEq<N, U64>, u64, u128>: Compute, | |
CmpU64<N>: PrimInt, | |
N: IsLessOrEqual<U32>, | |
Ternary<LeEq<N, U32>, u32, CmpU64<N>>: Compute, | |
CmpU32<N>: PrimInt, | |
N: IsLessOrEqual<U16>, | |
Ternary<LeEq<N, U16>, u16, CmpU32<N>>: Compute, | |
CmpU16<N>: PrimInt, | |
N: IsLessOrEqual<U8>, | |
Ternary<LeEq<N, U8>, u8, CmpU16<N>>: Compute, | |
CmpU8<N>: PrimInt, | |
{ | |
type Storage = CmpU8<N>; | |
} | |
/// Bit-level read/write access to a bitset (or to one piece of its storage).
///
/// Implementors supply the fallible `try_*` operations; the infallible
/// [`get`](Self::get) and [`set`](Self::set) default to unwrapping them.
pub trait BitSetStorageAccess {
    /// Reads the bit at index `bit`; `Err(())` means the read was not possible.
    fn try_get(&self, bit: usize) -> Result<bool, ()>; // TODO: error type

    /// Writes `val` to the bit at index `bit`; `Err(())` means the write was
    /// not possible.
    fn try_set(&mut self, bit: usize, val: bool) -> Result<(), ()>; // TODO: error type

    /// Reads the bit at index `bit`.
    ///
    /// # Panics
    ///
    /// Panics if [`try_get`](Self::try_get) returns an error.
    #[inline(always)]
    fn get(&self, bit: usize) -> bool {
        let outcome = self.try_get(bit);
        outcome.unwrap()
    }

    /// Writes `val` to the bit at index `bit`.
    ///
    /// # Panics
    ///
    /// Panics if [`try_set`](Self::try_set) returns an error.
    #[inline(always)]
    fn set(&mut self, bit: usize, val: bool) {
        let outcome = self.try_set(bit, val);
        outcome.unwrap()
    }
}
impl<S: BitSetStorageAccess> BitSetStorageAccess for &'_ S { | |
#[inline(always)] | |
fn try_get(&self, bit: usize) -> Result<bool, ()> { | |
S::try_get(self, bit) | |
} | |
#[inline(always)] | |
fn try_set(&mut self, _: usize, _: bool) -> Result<(), ()> { | |
Err(()) // yuck, TODO: split trait into read/write? | |
} | |
#[inline(always)] | |
fn get(&self, bit: usize) -> bool { | |
S::get(self, bit) | |
} | |
#[inline(always)] | |
fn set(&mut self, _: usize, _: bool) { | |
unimplemented!() | |
} | |
} | |
/// Zero-sized placeholder used as the "storage" of the [`Sentinel`] node,
/// which holds no bits.
#[doc(hidden)]
#[derive(Debug, Default, Clone, Copy)]
pub struct Empty;
#[doc(hidden)] | |
pub type Sentinel<Len: Unsigned> = BitSetStorageNode<Empty, Len, U0, ()>; | |
impl BitSetStorageAccess for () { | |
fn try_get(&self, _: usize) -> Result<bool, ()> { | |
unreachable!() | |
} | |
fn try_set(&mut self, _: usize, _: bool) -> Result<(), ()> { | |
unreachable!() | |
} | |
fn get(&self, _: usize) -> bool { | |
unreachable!() | |
} | |
fn set(&mut self, _: usize, _: bool) { | |
unreachable!() | |
} | |
} | |
impl<L: Unsigned> BitSetStorageAccess for Sentinel<L> { | |
#[inline(always)] | |
fn try_get(&self, _: usize) -> Result<bool, ()> { | |
Err(()) | |
} | |
#[inline(always)] | |
fn try_set(&mut self, _: usize, _: bool) -> Result<(), ()> { | |
Err(()) | |
} | |
#[inline(always)] | |
fn get(&self, bit: usize) -> bool { | |
panic!( | |
"out of bounds: attempted to get index {bit} in a {} element bitset", | |
L::USIZE | |
) | |
} | |
#[inline(always)] | |
fn set(&mut self, bit: usize, _: bool) { | |
panic!( | |
"out of bounds: attempted to set index {bit} in a {} element bitset", | |
L::USIZE | |
) | |
} | |
} | |
#[doc(hidden)] | |
pub struct BitSetStorageNode< | |
Storage, | |
WidthOffset: Unsigned, | |
Len: Unsigned, | |
Rest: BitSetStorageAccess, | |
> { | |
inner: Storage, | |
rest: Rest, | |
_width: P<(WidthOffset, Len)>, | |
} | |
impl<S, O: Unsigned, L: Unsigned, Rest: BitSetStorageAccess + Debug> Debug | |
for BitSetStorageNode<S, O, L, Rest> | |
{ | |
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | |
f.debug_struct("BitSetStorageNode") | |
.field("storage_ty", &type_name::<S>()) | |
.field("range", &Self::range()) | |
.field("rest", &self.rest) | |
.finish() | |
} | |
} | |
impl<S: Default, O: Unsigned, L: Unsigned, Rest: BitSetStorageAccess + Default> Default | |
for BitSetStorageNode<S, O, L, Rest> | |
{ | |
fn default() -> Self { | |
Self { | |
inner: Default::default(), | |
rest: Default::default(), | |
_width: P, | |
} | |
} | |
} | |
impl<S, O: Unsigned, L: Unsigned, R: BitSetStorageAccess> BitSetStorageNode<S, O, L, R> { | |
#[inline(always)] | |
const fn range() -> Range<usize> { | |
debug_assert!( | |
(size_of::<S>() * 8) >= L::USIZE, | |
// "expected type {} (size = {} bytes) to have >= {} bits", | |
// type_name::<S>(), | |
// size_of::<S>(), | |
// L::USIZE | |
); | |
O::USIZE..(O::USIZE + L::USIZE) | |
} | |
#[inline(always)] | |
fn relative(bit: usize) -> Option<usize> { | |
if Self::range().contains(&bit) { | |
Some(bit - O::USIZE) | |
} else { | |
None | |
} | |
} | |
} | |
impl<Storage, Offs, Len, Rest> BitSetStorageAccess for BitSetStorageNode<Storage, Offs, Len, Rest> | |
where | |
Storage: PrimInt, | |
Offs: Unsigned, | |
Len: Unsigned, | |
Rest: BitSetStorageAccess, | |
{ | |
#[inline(always)] | |
fn try_get(&self, bit: usize) -> Result<bool, ()> { | |
if let Some(bit_idx) = Self::relative(bit) { | |
let zero = Storage::zero(); | |
let mask = Storage::one() << bit_idx; | |
Ok((self.inner & mask) != zero) | |
} else { | |
self.rest.try_get(bit) | |
} | |
} | |
fn try_set(&mut self, bit: usize, val: bool) -> Result<(), ()> { | |
if let Some(bit_idx) = Self::relative(bit) { | |
let val = if val { Storage::one() } else { Storage::zero() } << bit_idx; | |
self.inner = self.inner | val; | |
Ok(()) | |
} else { | |
self.rest.try_set(bit, val) | |
} | |
} | |
} | |
#[doc(hidden)] | |
pub trait GetStorageNodes { | |
type Top: BitSetStorageAccess + Default; | |
} | |
// We prefer having fewer storage nodes over using as little space as possible | |
// here; i.e. we represent 80 bit bitsets as 1 `u128` instead of as a `u64` and | |
// a `u16`. | |
#[doc(hidden)] | |
pub struct LenToRootStorageNodeFewestNodes<Len: Unsigned, Offset: Unsigned = U0>((Len, Offset)); | |
type FewestNodes<L: Unsigned, Offs = U0> = | |
<LenToRootStorageNodeFewestNodes<L, Offs> as GetStorageNodes>::Top; | |
#[doc(hidden)] | |
pub struct LenToRootStorageNodeFewestNodesRecurse<Len: Unsigned, Offset: Unsigned, GreaterThan128>( | |
(Len, Offset, GreaterThan128), | |
); | |
impl<L: Unsigned, O: Unsigned> GetStorageNodes | |
for LenToRootStorageNodeFewestNodesRecurse<L, O, False> | |
{ | |
type Top = (); | |
} | |
impl<L: Unsigned, O: Unsigned> GetStorageNodes | |
for LenToRootStorageNodeFewestNodesRecurse<L, O, True> | |
where | |
O: Add<U128>, | |
Sum<O, U128>: Unsigned, | |
L: IsGreater<U128, Output = True>, | |
L: Sub<U128>, | |
Diff<L, U128>: Unsigned, | |
LenToRootStorageNodeFewestNodes<Diff<L, U128>, Sum<O, U128>>: GetStorageNodes, | |
{ | |
type Top = FewestNodes<Diff<L, U128>, Sum<O, U128>>; | |
} | |
// if <= 128 bits, this is the final storage node: | |
type LastNode<O, L> = BitSetStorageNode<StorageForLen<L>, O, L, Sentinel<Sum<O, L>>>; | |
// if > 128 bits, add a 128 bit node and then recurse (with 128 subtracted from the length): | |
type RecurseNode<O, L> = | |
<LenToRootStorageNodeFewestNodesRecurse<L, O, Gr<L, U128>> as GetStorageNodes>::Top; | |
type StorageTop<O, L> = | |
Cond<Gr<L, U128>, BitSetStorageNode<u128, O, U128, RecurseNode<O, L>>, LastNode<O, L>>; | |
impl<L: Unsigned, O: Unsigned> GetStorageNodes for LenToRootStorageNodeFewestNodes<L, O> | |
where | |
L: IsGreater<U128>, | |
O: Add<L>, | |
Sum<O, L>: Unsigned, | |
LenToRootStorageNodeFewestNodesRecurse<L, O, Gr<L, U128>>: GetStorageNodes, | |
TypenumTyToStorage<L>: GetStorage, | |
LastNode<O, L>: BitSetStorageAccess, | |
Ternary<Gr<L, U128>, BitSetStorageNode<u128, O, U128, RecurseNode<O, L>>, LastNode<O, L>>: | |
Compute, | |
StorageTop<O, L>: BitSetStorageAccess + Sized + Default, | |
{ | |
type Top = StorageTop<O, L>; | |
} | |
pub struct BitSet<Len: Unsigned, Storage: BitSetStorageAccess = FewestNodes<Len, U0>> { | |
inner: Storage, | |
_len: P<Len>, | |
} | |
impl<L: Unsigned, S: BitSetStorageAccess + Default> Default for BitSet<L, S> { | |
fn default() -> Self { | |
Self { | |
inner: Default::default(), | |
_len: Default::default(), | |
} | |
} | |
} | |
impl<L: Unsigned, S: BitSetStorageAccess> BitSetStorageAccess for BitSet<L, S> { | |
fn try_get(&self, bit: usize) -> Result<bool, ()> { | |
self.inner.try_get(bit) | |
} | |
fn try_set(&mut self, bit: usize, val: bool) -> Result<(), ()> { | |
self.inner.try_set(bit, val) | |
} | |
} | |
impl BitSet<U0, ()> { | |
pub fn new<L: Unsigned>() -> BitSet<L, FewestNodes<L>> | |
where | |
LenToRootStorageNodeFewestNodes<L, U0>: GetStorageNodes, | |
{ | |
BitSet { | |
inner: Default::default(), | |
_len: Default::default(), | |
} | |
} | |
} | |
pub struct BitSetIterator<L: Unsigned, S: BitSetStorageAccess> { | |
inner: S, | |
curr_idx: usize, | |
end: usize, | |
_len: P<L>, | |
} | |
impl<L: Unsigned, S: BitSetStorageAccess> Iterator for BitSetIterator<L, S> { | |
type Item = bool; | |
fn next(&mut self) -> Option<Self::Item> { | |
if self.curr_idx == self.end | |
/* L::USIZE */ | |
{ | |
None | |
} else { | |
let res = Some(self.inner.get(self.curr_idx)); | |
self.curr_idx += 1; | |
res | |
} | |
} | |
fn size_hint(&self) -> (usize, Option<usize>) { | |
let remaining = self.end - self.curr_idx; | |
(remaining, Some(remaining)) | |
} | |
} | |
impl<L: Unsigned, S: BitSetStorageAccess> DoubleEndedIterator for BitSetIterator<L, S> { | |
fn next_back(&mut self) -> Option<Self::Item> { | |
if self.curr_idx == self.end { | |
None | |
} else { | |
self.end -= 1; | |
Some(self.inner.get(self.end)) | |
} | |
} | |
} | |
// unsafe impl<L: Unsigned, S: BitSetStorageAccess> TrustedLen for BitSetIterator<L, S> { } | |
impl<L: Unsigned, S: BitSetStorageAccess> ExactSizeIterator for BitSetIterator<L, S> {} | |
impl<L: Unsigned, S: BitSetStorageAccess> FusedIterator for BitSetIterator<L, S> {} | |
impl<L: Unsigned, S: BitSetStorageAccess> BitSet<L, S> { | |
pub fn iter(&self) -> impl Iterator<Item = bool> + '_ { | |
BitSetIterator::<L, &'_ Self> { | |
inner: self, | |
curr_idx: 0, | |
end: L::USIZE, | |
_len: P, | |
} | |
} | |
} | |
impl<L: Unsigned, S: BitSetStorageAccess> IntoIterator for BitSet<L, S> { | |
type Item = bool; | |
type IntoIter = BitSetIterator<L, S>; | |
fn into_iter(self) -> Self::IntoIter { | |
BitSetIterator { | |
inner: self.inner, | |
curr_idx: 0, | |
end: L::USIZE, | |
_len: P, | |
} | |
} | |
} | |
impl<L: Unsigned, S: BitSetStorageAccess> FromIterator<bool> for BitSet<L, S> | |
where | |
Self: Default, | |
{ | |
fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self { | |
let mut iter = iter.into_iter(); | |
let mut out = Self::default(); | |
for i in 0..L::USIZE { | |
out.set(i, iter.next().unwrap()); | |
} | |
assert!(iter.next().is_none()); | |
out | |
} | |
} | |
// trait ComputeConst<T> { const RES: T; } | |
/* | |
struct BoolToType<const B: bool>; | |
impl Compute for BoolToType<true> { type Result = True; } | |
impl Compute for BoolToType<false> { type Result = False; } | |
struct IsBitSet<const N: usize, const BIT: u32>; | |
impl<const N: usize, const BIT: u32> ComputeConst<bool> for IsBitSet<N, BIT> { | |
const RES: bool = { | |
(N & (1 << (BIT as usize))) != 0 | |
}; | |
} | |
type IsBitSetTy<const N: usize, const BIT: u32> = < | |
BoolToType< | |
{ <IsBitSet<{N}, {BIT}> as ComputeConst<bool>>::RES } | |
> as Compute | |
>::Result; | |
*/ | |
/* struct Smuggle<const N: usize>; | |
impl<const N: usize> ComputeConst<usize> for Smuggle<N> { | |
const RES: usize = N; | |
} | |
struct BoolToType2<const B: bool>; | |
impl Compute for BoolToType2<true> { type Result = True; } | |
impl Compute for BoolToType2<false> { type Result = False; } | |
struct BoolToType<B: ComputeConst<bool>>; | |
impl<B: ComputeConst<bool>> Compute for BoolToType<B> { type Result = BoolToType2<{B::RES}>; } | |
// impl Compute for BoolToType<false> { type Result = False; } | |
struct IsBitSet<N: ComputeConst<usize>, B: ComputeConst<usize>>(P<(N, B)>); | |
impl<N: ComputeConst<usize>, B: ComputeConst<usize>> ComputeConst<bool> for IsBitSet<N, B> { | |
const RES: bool = { | |
(N::RES & (1 << (B::RES))) != 0 | |
}; | |
} | |
type IsBitSetTy<N: ComputeConst<usize>, B: ComputeConst<usize>> = < | |
BoolToType< | |
{ <IsBitSet<N, B> as ComputeConst<bool>>::RES } | |
> as Compute | |
>::Result; | |
*/ | |
// struct IntToType<const N: usize>; | |
// // // this is the naïve implementation: | |
// // impl<const N: usize> IntToType<N> | |
// // where | |
// // IntToType<{ N - 1 }>: Sized, | |
// // { | |
// // } | |
// Prints the fully resolved name of a type to stdout; handy for inspecting
// what a type-level computation evaluated to.
macro_rules! print_ty {
    ($t:ty) => {{
        let resolved = type_name::<$t>();
        println!("{}", resolved)
    }};
}
fn print_storage_nodes<N: Unsigned>() | |
where | |
LenToRootStorageNodeFewestNodes<N, U0>: GetStorageNodes, | |
FewestNodes<N>: Debug, | |
{ | |
eprintln!("\nSize (bytes): {}", size_of::<FewestNodes<N>>()); | |
eprintln!("{:#?}", FewestNodes::<N>::default()) | |
} | |
fn main() { | |
print_ty!(FewestNodes<typenum::U1024>); | |
print_storage_nodes::<typenum::U7>(); | |
print_storage_nodes::<typenum::U15>(); | |
print_storage_nodes::<typenum::U31>(); | |
print_storage_nodes::<typenum::U63>(); | |
print_storage_nodes::<typenum::U64>(); | |
print_storage_nodes::<typenum::U80>(); | |
print_storage_nodes::<typenum::U1024>(); | |
{ | |
let mut b = BitSet::new::<typenum::U15>(); | |
assert_eq!(b.try_get(15), Err(())); | |
assert_eq!(b.try_get(14), Ok(false)); | |
assert_eq!(b.try_set(14, true), Ok(())); | |
assert_eq!(b.try_get(14), Ok(true)); | |
} | |
{ | |
let num: u64 = 0xDEAD_BEEF_0123_4567; | |
let b: BitSet<typenum::U64> = (0..64).map(|i| (num & (1 << i)) != 0).collect(); | |
assert_eq!(b.inner.inner, num); | |
let out = b | |
.into_iter() | |
.rev() | |
.fold(0, |acc, b| (acc << 1) | (b as u64)); | |
assert_eq!(out, num); | |
} | |
} | |
pub fn roundtrip(num: u64) { | |
let b: BitSet::<typenum::U64> = (0..64).map(|i| (num & (1 << i)) != 0).collect(); | |
assert_eq!(b.inner.inner, num); | |
let out = b.into_iter().rev().fold(0, |acc, b| { | |
(acc << 1) | (b as u64) | |
}); | |
assert_eq!(out, num); | |
} | |
pub fn accessor(set: &BitSet::<U64>) -> bool { | |
set.get(34) | |
} | |
pub fn accessor_panic(set: &BitSet::<typenum::U31>) -> bool { | |
set.get(34) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment