Created
April 11, 2022 11:27
-
-
Save monadplus/29110586e79bc75af6cc635539e913c3 to your computer and use it in GitHub Desktop.
Rust: indexed vectors
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use std::{cmp::Ordering, fmt::Debug, marker::PhantomData, ops::Add}; | |
/// A natural number represented by a type.
///
/// The number is read back at runtime through [`Nat::usize`]; [`Nat::new`]
/// produces a value of the implementing type.
trait Nat {
    /// Returns the number this type stands for as a runtime `usize`.
    fn usize() -> usize;

    /// Builds a value of the implementing type.
    fn new() -> Self;
}
| #[derive(Copy, Clone, PartialEq, Eq)] | |
| struct Z; | |
| impl Nat for Z { | |
| fn new() -> Self { | |
| Z | |
| } | |
| fn usize() -> usize { | |
| 0 | |
| } | |
| } | |
| #[derive(Copy, Clone, PartialEq, Eq)] | |
| struct S<N>(PhantomData<N>); | |
| impl<N: Nat> Nat for S<N> { | |
| fn new() -> Self { | |
| S(PhantomData) | |
| } | |
| fn usize() -> usize { | |
| 1 + N::usize() | |
| } | |
| } | |
| impl std::fmt::Debug for Z { | |
| fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | |
| write!(f, "0") | |
| } | |
| } | |
| impl<N: Nat> std::fmt::Debug for S<N> { | |
| fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | |
| write!(f, "{}", Self::usize()) | |
| } | |
| } | |
| // TODO macro | |
| type N0 = Z; | |
| type N1 = S<Z>; | |
| type N2 = S<N1>; | |
| type N3 = S<N2>; | |
| type N4 = S<N3>; | |
| type N5 = S<N4>; | |
| type N6 = S<N5>; | |
| type N7 = S<N6>; | |
| type N8 = S<N7>; | |
| type N9 = S<N8>; | |
/// The result type of adding `Rhs` to `Lhs`, i.e. the `Output` of `Lhs: Add<Rhs>`.
type Sum<Lhs, Rhs> = <Lhs as Add<Rhs>>::Output;
| impl<N: Nat> Add<N> for Z { | |
| type Output = N; | |
| fn add(self, rhs: N) -> Self::Output { | |
| rhs | |
| } | |
| } | |
| impl<N, M> Add<N> for S<M> | |
| where | |
| N: Nat, | |
| M: Nat + Add<N>, | |
| { | |
| type Output = S<Sum<M, N>>; | |
| fn add(self, _: N) -> Self::Output { | |
| S(PhantomData) | |
| } | |
| } | |
/// A boolean represented by a type, readable at runtime through [`Bool::bool`].
trait Bool {
    /// Returns the runtime `bool` this type stands for.
    fn bool() -> bool;
}

/// Type-level `true`.
#[derive(Copy, Clone, PartialEq, Eq)]
struct True;

impl Bool for True {
    fn bool() -> bool {
        true
    }
}

/// Type-level `false`.
#[derive(Copy, Clone, PartialEq, Eq)]
struct False;

impl Bool for False {
    fn bool() -> bool {
        false
    }
}
| trait IsTrue: Bool {} | |
| impl IsTrue for True {} | |
| trait IsFalse: Bool {} | |
| impl IsFalse for False {} | |
/// A comparison outcome represented by a type.
trait Order {
    /// Returns the runtime [`Ordering`] this type stands for.
    fn ordering() -> Ordering;
}

/// Type-level comparison of `Self` against `Rhs`.
trait Cmp<Rhs> {
    /// The outcome of the comparison, as a type implementing [`Order`].
    type Output: Order;
}

/// Shorthand for the outcome of comparing `L` with `R`.
type Compare<L, R> = <L as Cmp<R>>::Output;
| #[derive(Copy, Clone, PartialEq, Eq)] | |
| struct Less; | |
| #[derive(Copy, Clone, PartialEq, Eq)] | |
| struct Equal; | |
| #[derive(Copy, Clone, PartialEq, Eq)] | |
| struct Greater; | |
| impl Order for Less { | |
| fn ordering() -> Ordering { | |
| Ordering::Less | |
| } | |
| } | |
| impl Order for Equal { | |
| fn ordering() -> Ordering { | |
| Ordering::Equal | |
| } | |
| } | |
| impl Order for Greater { | |
| fn ordering() -> Ordering { | |
| Ordering::Greater | |
| } | |
| } | |
| impl Cmp<Z> for Z { | |
| type Output = Equal; | |
| } | |
| impl<N: Nat> Cmp<S<N>> for Z { | |
| type Output = Less; | |
| } | |
| impl<N: Nat> Cmp<Z> for S<N> { | |
| type Output = Greater; | |
| } | |
| impl<M, N> Cmp<S<M>> for S<N> | |
| where | |
| M: Nat, | |
| N: Nat + Cmp<M>, | |
| { | |
| type Output = <N as Cmp<M>>::Output; | |
| } | |
| impl<R: Nat> PartialOrd<R> for Z | |
| where | |
| Z: Cmp<R> + PartialEq<R>, | |
| { | |
| fn partial_cmp(&self, other: &R) -> Option<Ordering> { | |
| Some(<Z as Cmp<R>>::Output::ordering()) | |
| } | |
| } | |
| impl<L: Nat, R: Nat> PartialOrd<R> for S<L> | |
| where | |
| S<L>: PartialEq<R>, | |
| L: Cmp<R>, | |
| { | |
| fn partial_cmp(&self, other: &R) -> Option<Ordering> { | |
| // This is not correct | |
| Some(<L as Cmp<R>>::Output::ordering()) | |
| } | |
| } | |
| trait IsLess<Other> { | |
| type Output: Bool; | |
| } | |
| impl<L> IsLess<Less> for L { | |
| type Output = True; | |
| } | |
| impl<L> IsLess<Equal> for L { | |
| type Output = False; | |
| } | |
| impl<L> IsLess<Greater> for L { | |
| type Output = False; | |
| } | |
| trait LT<R> { | |
| type Output: Bool; | |
| } | |
| impl<L, R> LT<R> for L | |
| where | |
| L: Cmp<R> + IsLess<Compare<L, R>>, | |
| { | |
| type Output = <L as IsLess<Compare<L, R>>>::Output; | |
| } | |
| type Lt<L, R> = <L as LT<R>>::Output; | |
| fn if_less_than<N, M>() | |
| where | |
| N: Nat, | |
| M: Nat, | |
| N: Cmp<M> + IsLess<Compare<N, M>>, | |
| Lt<N, M>: IsTrue, | |
| { | |
| } | |
/// A vector whose length is tracked in its type.
///
/// `N` is the type-level length and `A` the element type.
#[derive(PartialEq, Eq)]
struct SVec<N, A> {
    /// Type-level length; carries no runtime data.
    len: PhantomData<N>,
    /// The elements themselves.
    vec: Vec<A>,
}
| impl<N, A: Debug> std::fmt::Debug for SVec<N, A> { | |
| fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | |
| self.vec.fmt(f) | |
| } | |
| } | |
| impl<A> SVec<Z, A> { | |
| fn new() -> Self { | |
| SVec { | |
| len: PhantomData, | |
| vec: Vec::new(), | |
| } | |
| } | |
| } | |
| impl<N: Nat, A> SVec<N, A> { | |
| fn len(self) -> usize { | |
| N::usize() as usize | |
| } | |
| fn push(self, a: A) -> SVec<S<N>, A> { | |
| let mut v = self.vec; | |
| v.push(a); | |
| SVec { | |
| len: PhantomData, | |
| vec: v, | |
| } | |
| } | |
| fn append<M: Nat>(mut self, other: SVec<M, A>) -> SVec<Sum<N, M>, A> | |
| where | |
| M: Nat, | |
| N: Add<M>, | |
| { | |
| self.vec.extend(other.vec); | |
| SVec { | |
| len: PhantomData, | |
| vec: self.vec, | |
| } | |
| } | |
| fn get_type_annotated<I>(&self) -> &A | |
| where | |
| I: Nat + Cmp<N> + IsLess<Compare<I, N>>, | |
| Lt<I, N>: IsTrue, | |
| { | |
| &self.vec[I::usize()] | |
| } | |
| fn get<I>(&self, _: I) -> &A | |
| where | |
| I: Nat + Cmp<N> + IsLess<Compare<I, N>>, | |
| Lt<I, N>: IsTrue, | |
| { | |
| &self.vec[I::usize()] | |
| } | |
| } | |
/// Builds an `SVec` from a comma-separated list of expressions, pushing them
/// in order onto an empty vector.
///
/// `svec!()` yields an empty vector. A trailing comma is accepted, as with
/// `vec!`, so multi-line invocations can end every line with `,`.
#[macro_export]
macro_rules! svec {
    () => { $crate::svec::SVec::new() };
    ( $($x:expr),+ $(,)? ) => {{
        $crate::svec::SVec::new()$(.push($x))+
    }};
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Type-level numbers lower to the matching runtime `usize`.
    #[test]
    fn test_as_u64() {
        let cases = [
            (N0::usize(), 0),
            (N1::usize(), 1),
            (N2::usize(), 2),
            (N3::usize(), 3),
        ];
        for (actual, expected) in cases.iter() {
            assert_eq!(*actual, *expected);
        }
    }

    /// Type-level addition agrees with runtime addition.
    #[test]
    fn test_sum() {
        type Sum0 = Sum<N1, N0>;
        type Sum1 = Sum<N0, N1>;
        type Sum2 = Sum<N2, N2>;
        type Sum3 = Sum<N3, N2>;

        let cases = [
            (Sum0::usize(), 1),
            (Sum1::usize(), 1),
            (Sum2::usize(), 4),
            (Sum3::usize(), 5),
        ];
        for (actual, expected) in cases.iter() {
            assert_eq!(*actual, *expected);
        }
    }

    /// Type-level comparison yields the expected `Ordering`.
    #[test]
    fn test_compare() {
        let cases = [
            (Compare::<N0, N1>::ordering(), Ordering::Less),
            (Compare::<N1, N0>::ordering(), Ordering::Greater),
            (Compare::<N2, N2>::ordering(), Ordering::Equal),
            (Compare::<N2, N3>::ordering(), Ordering::Less),
        ];
        for (actual, expected) in cases.iter() {
            assert_eq!(*actual, *expected);
        }
    }

    /// `Lt` is true only for a strictly smaller left operand.
    #[test]
    fn test_less_than() {
        assert!(!Lt::<N0, N0>::bool());
        assert!(Lt::<N0, N1>::bool());
        assert!(Lt::<N0, N9>::bool());
        assert!(!Lt::<N1, N0>::bool());
        assert!(!Lt::<N2, N0>::bool());
        assert!(!Lt::<N2, N1>::bool());

        if_less_than::<N1, N4>();
        // `if_less_than::<N3, N1>()` is rejected by the compiler, as intended.
    }

    /// Appending two vectors and indexing with a type-level index.
    #[test]
    fn test_svec() {
        let expected = svec!(0, 1, 2, 3);
        let svec = svec!(0, 1).append(svec!(2, 3));
        assert_eq!(svec, expected);

        let by_type = svec.get_type_annotated::<N2>();
        assert_eq!(*by_type, 2);

        let by_value = svec.get(N2::new());
        assert_eq!(*by_value, 2);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment