Skip to content

Instantly share code, notes, and snippets.

@monadplus
Created April 11, 2022 11:27
Show Gist options
  • Save monadplus/29110586e79bc75af6cc635539e913c3 to your computer and use it in GitHub Desktop.
Rust: indexed vectors
use std::{cmp::Ordering, fmt::Debug, marker::PhantomData, ops::Add};
/// A type-level natural number in Peano encoding.
///
/// Implementors are zero-sized marker types; the number lives in the type
/// itself and is recovered at runtime with [`Nat::usize`].
trait Nat {
    /// Builds the (zero-sized) value of this type.
    fn new() -> Self;
    /// Returns the number this type encodes.
    fn usize() -> usize;
}

/// Type-level zero.
#[derive(Copy, Clone, PartialEq, Eq)]
struct Z;

/// Type-level successor: `S<N>` encodes `N + 1`.
#[derive(Copy, Clone, PartialEq, Eq)]
struct S<N>(PhantomData<N>);

impl Nat for Z {
    fn new() -> Self {
        Self
    }

    fn usize() -> usize {
        0
    }
}

impl<N: Nat> Nat for S<N> {
    fn new() -> Self {
        Self(PhantomData)
    }

    /// One for this successor layer, plus whatever `N` encodes.
    fn usize() -> usize {
        N::usize() + 1
    }
}

impl std::fmt::Debug for Z {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("0")
    }
}

impl<N: Nat> std::fmt::Debug for S<N> {
    /// Prints the decimal value, e.g. `S<S<Z>>` prints as `2`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let value = <Self as Nat>::usize();
        write!(f, "{}", value)
    }
}
// TODO macro
type N0 = Z;
type N1 = S<Z>;
type N2 = S<N1>;
type N3 = S<N2>;
type N4 = S<N3>;
type N5 = S<N4>;
type N6 = S<N5>;
type N7 = S<N6>;
type N8 = S<N7>;
type N9 = S<N8>;
/// Type-level addition: `Sum<L, R>` is the natural that encodes `L + R`.
type Sum<L, R> = <L as Add<R>>::Output;

/// Base case: `0 + N = N`.
impl<N: Nat> Add<N> for Z {
    type Output = N;

    fn add(self, rhs: N) -> N {
        rhs
    }
}

/// Inductive case: `(M + 1) + N = (M + N) + 1`. One successor is peeled off
/// the left operand per step.
impl<N, M> Add<N> for S<M>
where
    M: Nat + Add<N>,
    N: Nat,
{
    type Output = S<<M as Add<N>>::Output>;

    fn add(self, _rhs: N) -> Self::Output {
        // The result is a zero-sized marker; all information is in its type.
        S(PhantomData)
    }
}
#[derive(Copy, Clone, PartialEq, Eq)]
struct True;
#[derive(Copy, Clone, PartialEq, Eq)]
struct False;
trait Bool {
fn bool() -> bool;
}
impl Bool for True {
fn bool() -> bool {
true
}
}
impl Bool for False {
fn bool() -> bool {
false
}
}
trait IsTrue: Bool {}
impl IsTrue for True {}
trait IsFalse: Bool {}
impl IsFalse for False {}
/// Runtime reflection for the type-level comparison results below.
trait Order {
    /// The `Ordering` this marker type stands for.
    fn ordering() -> Ordering;
}

/// Type-level comparison. `Self: Cmp<Other>` selects one of `Less`, `Equal`
/// or `Greater` as `Output`.
trait Cmp<Other> {
    type Output: Order;
}

/// Shorthand for the result of comparing `L` with `R` at the type level.
type Compare<L, R> = <L as Cmp<R>>::Output;

/// Type-level `Ordering::Less`.
#[derive(Copy, Clone, PartialEq, Eq)]
struct Less;

impl Order for Less {
    fn ordering() -> Ordering {
        Ordering::Less
    }
}

/// Type-level `Ordering::Equal`.
#[derive(Copy, Clone, PartialEq, Eq)]
struct Equal;

impl Order for Equal {
    fn ordering() -> Ordering {
        Ordering::Equal
    }
}

/// Type-level `Ordering::Greater`.
#[derive(Copy, Clone, PartialEq, Eq)]
struct Greater;

impl Order for Greater {
    fn ordering() -> Ordering {
        Ordering::Greater
    }
}
// Type-level comparison of Peano naturals, by structural recursion on both
// operands.

/// `0` compared with `0` is `Equal`.
impl Cmp<Z> for Z {
    type Output = Equal;
}

/// `0` compared with `n + 1` is `Less`.
impl<N: Nat> Cmp<S<N>> for Z {
    type Output = Less;
}

/// `n + 1` compared with `0` is `Greater`.
impl<N: Nat> Cmp<Z> for S<N> {
    type Output = Greater;
}

/// `n + 1` compared with `m + 1` gives the same answer as `n` compared with
/// `m`: one successor is stripped from each side before recursing.
impl<M, N> Cmp<S<M>> for S<N>
where
    N: Nat + Cmp<M>,
    M: Nat,
{
    type Output = <N as Cmp<M>>::Output;
}
// Value-level ordering for naturals, answered from the type-level `Cmp`.
//
// NOTE: `PartialOrd<R>` requires `PartialEq<R>`. With only the derived
// `PartialEq` impls, which relate a type to itself, `R` is effectively `Self`
// here and the result is `Ordering::Equal`.

impl<R: Nat> PartialOrd<R> for Z
where
    Z: Cmp<R> + PartialEq<R>,
{
    /// Answers from the types alone; the zero-sized operands carry no data.
    fn partial_cmp(&self, _other: &R) -> Option<Ordering> {
        Some(<Z as Cmp<R>>::Output::ordering())
    }
}

impl<L: Nat, R: Nat> PartialOrd<R> for S<L>
where
    S<L>: Cmp<R> + PartialEq<R>,
{
    /// Compares `S<L>` itself with `R`.
    ///
    /// Comparing the predecessor `L` with `R` instead would be off by one:
    /// `L` vs. `S<L>` is always `Less`, so `x.partial_cmp(&x)` would report
    /// `Less` rather than `Equal`.
    fn partial_cmp(&self, _other: &R) -> Option<Ordering> {
        Some(<S<L> as Cmp<R>>::Output::ordering())
    }
}
/// Maps a type-level comparison result to "is it `Less`?".
///
/// The trait is implemented for every `Self`. Only the `Other` parameter
/// (one of `Less`, `Equal`, `Greater`) decides the boolean `Output`.
trait IsLess<Other> {
    type Output: Bool;
}

impl<L> IsLess<Greater> for L {
    type Output = False;
}

impl<L> IsLess<Equal> for L {
    type Output = False;
}

impl<L> IsLess<Less> for L {
    type Output = True;
}

/// Type-level `<`. `Output` is `True` when `Self` compares below `R`, and
/// `False` otherwise.
trait LT<R> {
    type Output: Bool;
}

impl<L, R> LT<R> for L
where
    L: Cmp<R> + IsLess<Compare<L, R>>,
{
    type Output = <L as IsLess<Compare<L, R>>>::Output;
}

/// Shorthand for the boolean result of `L < R` at the type level.
type Lt<L, R> = <L as LT<R>>::Output;

/// Compile-time assertion that `N < M`.
///
/// The body is intentionally empty; the check lives entirely in the `where`
/// clause, which requires `Lt<N, M>` to implement `IsTrue`.
fn if_less_than<N, M>()
where
    N: Nat,
    M: Nat,
    N: Cmp<M> + IsLess<Compare<N, M>>,
    Lt<N, M>: IsTrue,
{
}
/// A `Vec<A>` whose length is carried in its type as the Peano natural `N`.
///
/// `len` is a zero-sized marker and all elements live in `vec`. The
/// constructors are responsible for keeping `vec.len()` equal to `N`.
#[derive(PartialEq, Eq)]
struct SVec<N, A> {
    len: PhantomData<N>,
    vec: Vec<A>,
}

/// Formats exactly like the underlying `Vec`; the length marker is not shown.
impl<N, A: Debug> std::fmt::Debug for SVec<N, A> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let SVec { vec, .. } = self;
        Debug::fmt(vec, f)
    }
}
impl<A> SVec<Z, A> {
    /// Creates an empty vector whose type-level length is `Z` (zero).
    fn new() -> Self {
        SVec {
            len: PhantomData,
            vec: Vec::new(),
        }
    }
}

impl<N: Nat, A> SVec<N, A> {
    /// Returns the number of elements, read from the type-level length `N`.
    ///
    /// Takes `&self` so that asking for the length does not consume the vector.
    fn len(&self) -> usize {
        N::usize()
    }

    /// Appends `a` and returns a vector whose type-level length is `N + 1`.
    #[must_use]
    fn push(self, a: A) -> SVec<S<N>, A> {
        let mut vec = self.vec;
        vec.push(a);
        SVec {
            len: PhantomData,
            vec,
        }
    }

    /// Concatenates `other` after `self`. The result's type-level length is
    /// `Sum<N, M>`.
    #[must_use]
    fn append<M>(mut self, other: SVec<M, A>) -> SVec<Sum<N, M>, A>
    where
        M: Nat,
        N: Add<M>,
    {
        self.vec.extend(other.vec);
        SVec {
            len: PhantomData,
            vec: self.vec,
        }
    }

    /// Returns a reference to the element at the type-level index `I`.
    ///
    /// The `where` clause only holds when `I < N`, so an out-of-range index is
    /// rejected at compile time. This relies on `vec.len()` matching `N`, which
    /// `new`, `push` and `append` maintain.
    fn get_type_annotated<I>(&self) -> &A
    where
        I: Nat + Cmp<N> + IsLess<Compare<I, N>>,
        Lt<I, N>: IsTrue,
    {
        &self.vec[I::usize()]
    }

    /// Same as `get_type_annotated`, but `I` is inferred from a value
    /// (e.g. `v.get(N2::new())`) instead of a turbofish.
    fn get<I>(&self, _index: I) -> &A
    where
        I: Nat + Cmp<N> + IsLess<Compare<I, N>>,
        Lt<I, N>: IsTrue,
    {
        self.get_type_annotated::<I>()
    }
}
/// Builds an `SVec` from a list of elements, e.g. `svec!(0, 1, 2)`.
///
/// Elements are pushed in order, so the type-level length of the result equals
/// the number of elements. A trailing comma is accepted; `svec!()` yields an
/// empty vector.
///
/// NOTE(review): the expansion names `$crate::svec::SVec`. That assumes this
/// file is the crate-root module `svec`. Also, `SVec` is not `pub`, so
/// expansions in other crates will not compile despite `#[macro_export]`.
/// Confirm the intended visibility.
#[macro_export]
macro_rules! svec {
    () => {
        $crate::svec::SVec::new()
    };
    ( $($x:expr),+ $(,)? ) => {{
        $crate::svec::SVec::new()$(.push($x))+
    }};
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Nat::usize` recovers the encoded number.
    #[test]
    fn test_usize() {
        assert_eq!(N0::usize(), 0);
        assert_eq!(N1::usize(), 1);
        assert_eq!(N2::usize(), 2);
        assert_eq!(N3::usize(), 3);
    }

    /// Type-level addition, including zero on either side.
    #[test]
    fn test_sum() {
        type Sum0 = Sum<N1, N0>;
        type Sum1 = Sum<N0, N1>;
        type Sum2 = Sum<N2, N2>;
        type Sum3 = Sum<N3, N2>;
        assert_eq!(Sum0::usize(), 1);
        assert_eq!(Sum1::usize(), 1);
        assert_eq!(Sum2::usize(), 4);
        assert_eq!(Sum3::usize(), 5);
    }

    /// Type-level comparison reflected back to `Ordering`.
    #[test]
    fn test_compare() {
        assert_eq!(Compare::<N0, N1>::ordering(), Ordering::Less);
        assert_eq!(Compare::<N1, N0>::ordering(), Ordering::Greater);
        assert_eq!(Compare::<N2, N2>::ordering(), Ordering::Equal);
        assert_eq!(Compare::<N2, N3>::ordering(), Ordering::Less);
    }

    /// Type-level `<`, both reflected at runtime and enforced at compile time.
    #[test]
    fn test_less_than() {
        assert!(!Lt::<N0, N0>::bool());
        assert!(Lt::<N0, N1>::bool());
        assert!(Lt::<N0, N9>::bool());
        assert!(!Lt::<N1, N0>::bool());
        assert!(!Lt::<N2, N0>::bool());
        assert!(!Lt::<N2, N1>::bool());
        // Compile-time check: this call only type-checks because 1 < 4.
        if_less_than::<N1, N4>();
        // By design, `if_less_than::<N3, N1>()` is rejected by the compiler.
    }

    /// Building, appending, indexing and measuring a length-indexed vector.
    #[test]
    fn test_svec() {
        let svec = svec!(0, 1).append(svec!(2, 3));
        assert_eq!(svec, svec!(0, 1, 2, 3));
        assert_eq!(*svec.get_type_annotated::<N2>(), 2);
        assert_eq!(*svec.get(N2::new()), 2);
        // Kept last so it compiles whether `len` borrows or consumes `self`.
        assert_eq!(svec.len(), 4);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment