Skip to content

Instantly share code, notes, and snippets.

@rjeli
Last active April 27, 2024 10:43
Show Gist options
  • Save rjeli/77cbf25406e1bf545cd97d0323d37f5c to your computer and use it in GitHub Desktop.
Save rjeli/77cbf25406e1bf545cd97d0323d37f5c to your computer and use it in GitHub Desktop.
towers lol
use std::{
marker::PhantomData,
ops::{Add, Mul},
};
/// Marker trait for the scalar type at the bottom of a tower.
///
/// Any `Copy` type with `Default`, `+` and `*` (both returning `Self`) can opt in
/// with an empty impl. The tower multiplication starts its accumulators from
/// `Default::default()`, so implementors presumably need that to be zero.
trait Field
where
    Self: Copy + Default,
    Self: Add<Output = Self> + Mul<Output = Self>,
{
}
/*
We want a tower-field type (`Ext` below) whose elements can be added to and multiplied by elements of the same field or of any of its subfields.
Furthermore, we would like to avoid macros (for fun) and try to obey orphan rule, so
external crates can define their own towers.
How do we get the ability to multiply by any subfield? Without macros or specialization, the
way to get conditional behavior at the type level in Rust is to select among
non-overlapping impls, e.g.:
struct A {}
struct B {}
trait Foo { type Out; }
impl Foo for A { type Out = u32; }
impl Foo for B { type Out = (); }
So how do we construct our tower? With a list.
struct Nil;
struct Cons<H, T>;
type Base_2_3 = Ext<Cons<Base, Cons<Alg_2, Cons<Alg_2_3, Nil>>>>;
The type goes from the base to the top of the tower. This constructs a
3-over-2-over-1 tower. `Base` is any field, and `Alg_2`, `Alg_2_3` are the tower algebras.
But how can we possibly add an Ext<Cons<Base, Nil>> to an Ext<Cons<Base, Cons<Alg_2, Cons<Alg_2_3, Nil>>>>?
First we check that the towers are compatible. Since both lists run from the bottom up, this is easy:
the supertower list must be at least as long as the subtower list, and the two must agree on every algebra (and degree) up to that point.
compat? super nil = true
compat? nil sub = false -- sub is non-empty here
compat? (h:t1) (h:t2) = compat? t1 t2
compat? _ _ = false -- the heads differ
OK, so compatibility is easy to check. Now let's modify it slightly so that it returns a new list
specifying the relationship at each level:
compat? (h:t) nil = Smaller : (compat? t nil)
compat? nil sub = ! -- no impl exists, so this is a compile error
compat? nil nil = nil
compat? (h:t1) (h:t2) = Same<h> : compat? t1 t2
This new compat? doesn't stop until the supertower is exhausted, and pads the extra levels with `Smaller`.
For example, with supertower A B C and subtower A:
A B C
A
becomes
Same<A> Smaller Smaller
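Now we can implement the ops (Add, Mul) for op lists headed by Same or Smaller. A Same level combines two operands
of equal size; at a Smaller level the right-hand operand is a constant from a subfield.
To call these ops we first reverse the op list. The array representation nests the top of the tower outermost
(see `ExtRepr`), so the ops must be applied from the top level down.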
*/
/// The empty type-level list.
struct Nil;
/// A non-empty type-level list: head `H`, const degree `D`, tail `T`.
///
/// In a tower list, `H` is an algebra marker and `D` is the degree of that
/// extension level. In an op list (built by `ComputeOpList`), `H` is `Same<_>`
/// or `Smaller` instead. The tail holds the next level up or down, depending on
/// which direction the list is being read. Nothing in this file constructs a
/// `Cons` value; the `PhantomData` only records the type parameters.
struct Cons<H, const D: usize, T>(PhantomData<(H, T)>);
/// Type-level list concatenation: `Concat<L, R>` is `L ++ R`.
/// Only `Rev` relies on it.
type Concat<L, R> = <(L, R) as CanConcat>::Output;

/// Implemented for a pair `(Left, Right)` of type-level lists; `Output` is
/// their concatenation.
trait CanConcat {
    type Output;
}

// Base case: [] ++ ys = ys
impl<Ys> CanConcat for (Nil, Ys) {
    type Output = Ys;
}

// Recursive case: (x : xs) ++ ys = x : (xs ++ ys). The degree travels with its head.
impl<X, const N: usize, Xs, Ys> CanConcat for (Cons<X, N, Xs>, Ys)
where
    (Xs, Ys): CanConcat,
{
    type Output = Cons<X, N, <(Xs, Ys) as CanConcat>::Output>;
}
/// Type-level list reversal: `Rev<L>` is `L` in reverse order.
type Rev<L> = <L as CanRev>::Output;

/// Implemented for every type-level list; `Output` is the reversed list.
trait CanRev {
    type Output;
}

// Base case: reverse [] = []
impl CanRev for Nil {
    type Output = Nil;
}

// Recursive case: reverse (x : xs) = reverse xs ++ [x]
impl<X, const N: usize, Xs> CanRev for Cons<X, N, Xs>
where
    Xs: CanRev,
    (<Xs as CanRev>::Output, Cons<X, N, Nil>): CanConcat,
{
    type Output = <(<Xs as CanRev>::Output, Cons<X, N, Nil>) as CanConcat>::Output;
}
/// Maps a type-level list of degrees to nested arrays of `Elem`, with the list
/// head as the outermost dimension.
///
/// e.g. `Cons<_, 2, Cons<_, 3, Nil>>` becomes `[[Elem; 3]; 2]`.
trait ToArr<Elem> {
    type Output;
}

// An empty list adds no dimensions: the element itself.
impl<Elem> ToArr<Elem> for Nil {
    type Output = Elem;
}

// Each `Cons` wraps whatever its tail produced in an array of its own degree.
impl<Elem, X, const N: usize, Xs> ToArr<Elem> for Cons<X, N, Xs>
where
    Xs: ToArr<Elem>,
{
    type Output = [<Xs as ToArr<Elem>>::Output; N];
}
// Since we list the tower from bottom to top, we need to reverse
// before creating the nested array. This is a helper to do that.
trait ExtRepr<Elem> {
type Output;
}
impl<Elem, List> ExtRepr<Elem> for List
where
List: CanRev,
Rev<List>: ToArr<Elem>,
{
type Output = <Rev<List> as ToArr<Elem>>::Output;
}
type Repr<Elem, T> = <T as ExtRepr<Elem>>::Output;
/// An element of the binomial extension tower `Tower` over base field `F`.
///
/// `Tower` lists the algebras from the bottom of the tower to the top. The
/// wrapped value is the nested coefficient array that `ExtRepr` derives from it.
struct Ext<F, Tower>(Repr<F, Tower>)
where
    Tower: ExtRepr<F>;

/// One binomial extension level of degree `D` whose coefficients are `Coeff`.
///
/// Multiplication at this level reduces powers with `x^D = W`, as implemented
/// by the `Same` op.
trait Algebra<Coeff, const D: usize> {
    /// The constant `W` in the reduction rule `x^D = W`.
    const W: Coeff;
}
/// Op-list marker: both operands have algebra `A` at this level.
struct Same<A>(PhantomData<A>);
/// Op-list marker: the right-hand operand does not reach this level, so here
/// it is a constant from a subfield.
struct Smaller;

/// Arithmetic driven by an op list, i.e. a type-level list of `Same` /
/// `Smaller` markers.
///
/// Each implementation handles one level and defers to the rest of the list.
/// `Lhs` is the left operand's representation and `Rhs` the right one's. The
/// result always has the left operand's shape.
trait OpList<Lhs, Rhs> {
    /// Sum of `lhs` and `rhs`, in the shape of `lhs`.
    fn add(lhs: Lhs, rhs: Rhs) -> Lhs;
    /// Product of `lhs` and `rhs`, in the shape of `lhs`.
    fn mul(lhs: Lhs, rhs: Rhs) -> Lhs;
}
// Bottom of the tower.
impl<F: Field> OpList<F, F> for Nil {
fn add(l: F, r: F) -> F {
l + r
}
fn mul(l: F, r: F) -> F {
l * r
}
}
// Both operands have this level: treat them as polynomials of degree < D in x
// over `SubRepr`, with the algebra's rule x^D = A::W.
impl<A, SubRepr, const D: usize, T> OpList<[SubRepr; D], [SubRepr; D]> for Cons<Same<A>, D, T>
where
    SubRepr: Copy,
    A: Algebra<SubRepr, D>,
    T: OpList<SubRepr, SubRepr>,
    [SubRepr; D]: Default,
{
    /// Coefficient-wise sum, delegating each coefficient to the level below.
    fn add(lhs: [SubRepr; D], rhs: [SubRepr; D]) -> [SubRepr; D] {
        let mut sum = lhs;
        for (slot, addend) in sum.iter_mut().zip(rhs.iter().copied()) {
            *slot = T::add(*slot, addend);
        }
        sum
    }

    /// Schoolbook product of the two coefficient arrays (index `k` holds the
    /// coefficient of x^k). Terms of degree >= D fold back via x^D = W.
    fn mul(lhs: [SubRepr; D], rhs: [SubRepr; D]) -> [SubRepr; D] {
        // Accumulators start at `Default::default()`; this assumes that value
        // is the additive zero of `SubRepr`.
        let mut acc: [SubRepr; D] = Default::default();
        for (i, &a) in lhs.iter().enumerate() {
            for (j, &b) in rhs.iter().enumerate() {
                let term = T::mul(a, b);
                let deg = i + j;
                if deg < D {
                    acc[deg] = T::add(acc[deg], term);
                } else {
                    // x^deg = x^(deg - D) * x^D = W * x^(deg - D)
                    let k = deg - D;
                    acc[k] = T::add(acc[k], T::mul(A::W, term));
                }
            }
        }
        acc
    }
}
// The right-hand operand does not reach this level. From here it is a
// constant drawn from a subfield, and the rest of the op list (`T`) knows how to
// combine it with one coefficient of ours.
impl<SubRepr1: Copy, SubRepr2: Copy, const D: usize, T> OpList<[SubRepr1; D], SubRepr2>
    for Cons<Smaller, D, T>
where
    T: OpList<SubRepr1, SubRepr2>,
{
    /// Adds the subfield constant `r` to `l`.
    ///
    /// A constant only contributes to the degree-0 coefficient. That is index 0,
    /// the same convention the `Same` multiplication uses for powers of x. So
    /// only `l[0]` changes, via `T`, which recurses further down when `r` lives
    /// several levels below. With `D == 0` there is no coefficient and `l` is
    /// returned unchanged.
    fn add(mut l: [SubRepr1; D], r: SubRepr2) -> [SubRepr1; D] {
        if let Some(constant_term) = l.first_mut() {
            *constant_term = T::add(*constant_term, r);
        }
        l
    }

    /// Multiplies `l` by the subfield constant `r`. Scaling distributes over
    /// every coefficient.
    fn mul(l: [SubRepr1; D], r: SubRepr2) -> [SubRepr1; D] {
        l.map(|x| T::mul(x, r))
    }
}
/// Pairs a tower list (`Self`) with a subtower list (`Rhs`), both bottom-to-top.
///
/// `Output` tags each level of `Self` with `Same<_>` while `Rhs` has the same
/// level, and with `Smaller` once `Rhs` has run out. There is no impl for `Nil`
/// against a non-empty `Rhs`, nor for differing heads or degrees, so
/// incompatible towers fail to type-check.
trait ComputeOpList<Rhs> {
    type Output;
}

// Same algebra and degree on both sides: ops (h:xs) (h:ys) = Same<h> : ops xs ys
impl<Alg, const N: usize, Xs, Ys> ComputeOpList<Cons<Alg, N, Ys>> for Cons<Alg, N, Xs>
where
    Xs: ComputeOpList<Ys>,
{
    type Output = Cons<Same<Alg>, N, <Xs as ComputeOpList<Ys>>::Output>;
}

// Subtower exhausted: ops (h:xs) [] = Smaller : ops xs []
impl<Alg, const N: usize, Xs> ComputeOpList<Nil> for Cons<Alg, N, Xs>
where
    Xs: ComputeOpList<Nil>,
{
    type Output = Cons<Smaller, N, <Xs as ComputeOpList<Nil>>::Output>;
}

// Both exhausted: ops [] [] = []
impl ComputeOpList<Nil> for Nil {
    type Output = Nil;
}
/// Holds when `Self` is a tower extending `SubT`, both over base `F`.
///
/// `Ops` is the op list from `ComputeOpList`, reversed so that it runs from the
/// top of the tower down. That matches the nesting of `Repr`, where the
/// outermost array is the top level. The bound on `Ops` ensures the list can
/// actually combine the two representations.
trait SupertowerOf<SubT, F>
where
    Self: ExtRepr<F>,
    SubT: ExtRepr<F>,
{
    type Ops: OpList<Repr<F, Self>, Repr<F, SubT>>;
}

impl<F, Tower, SubT> SupertowerOf<SubT, F> for Tower
where
    Tower: ExtRepr<F> + ComputeOpList<SubT>,
    SubT: ExtRepr<F>,
    <Tower as ComputeOpList<SubT>>::Output: CanRev,
    <<Tower as ComputeOpList<SubT>>::Output as CanRev>::Output:
        OpList<Repr<F, Tower>, Repr<F, SubT>>,
{
    type Ops = <<Tower as ComputeOpList<SubT>>::Output as CanRev>::Output;
}
// The payoff!!
impl<F, T: ExtRepr<F>, SubT: ExtRepr<F>> Add<Ext<F, SubT>> for Ext<F, T>
where
T: SupertowerOf<SubT, F>,
{
type Output = Ext<F, T>;
fn add(self, rhs: Ext<F, SubT>) -> Self::Output {
Ext(T::Ops::add(self.0, rhs.0))
}
}
impl<F, T: ExtRepr<F>, SubT: ExtRepr<F>> Mul<Ext<F, SubT>> for Ext<F, T>
where
T: SupertowerOf<SubT, F>,
{
type Output = Ext<F, T>;
fn mul(self, rhs: Ext<F, SubT>) -> Self::Output {
Ext(T::Ops::mul(self.0, rhs.0))
}
}
#[cfg(test)]
#[allow(non_camel_case_types)]
mod tests {
    use super::*;

    /// Integers modulo 127, the Mersenne prime 2^7 - 1 (hence "M7").
    ///
    /// Values built through `From<u32>` are reduced mod 127. The tests below
    /// only use literals under 127, so the `u32` products cannot overflow.
    #[derive(Copy, Clone, Debug, PartialEq, Eq, Default)]
    struct M7(u32);
    impl From<u32> for M7 {
        fn from(value: u32) -> Self {
            Self(value % 127)
        }
    }
    impl Add for M7 {
        type Output = Self;
        fn add(self, rhs: Self) -> Self::Output {
            (self.0 + rhs.0).into()
        }
    }
    impl Mul for M7 {
        type Output = Self;
        fn mul(self, rhs: Self) -> Self::Output {
            (self.0 * rhs.0).into()
        }
    }
    impl Field for M7 {}

    // Degree 2 over M7, with y^2 = 3.
    struct M7_2_Alg;
    type M7_2 = Ext<M7, Cons<M7_2_Alg, 2, Nil>>;
    impl Algebra<M7, 2> for M7_2_Alg {
        const W: M7 = M7(3);
    }

    // Degree 3 over M7, with x^3 = 5. A separate tower, incompatible with M7_2.
    struct M7_3_Alg;
    type M7_3 = Ext<M7, Cons<M7_3_Alg, 3, Nil>>;
    impl Algebra<M7, 3> for M7_3_Alg {
        const W: M7 = M7(5);
    }

    // Degree 3 over M7_2, with z^3 = 2 + 3y.
    struct M7_2_3_Alg;
    type M7_2_3 = Ext<M7, Cons<M7_2_Alg, 2, Cons<M7_2_3_Alg, 3, Nil>>>;
    impl Algebra<[M7; 2], 3> for M7_2_3_Alg {
        const W: [M7; 2] = [M7(2), M7(3)];
    }

    // Degree 4 over M7_2_3.
    struct M7_2_3_4_Alg;
    type M7_2_3_4 = Ext<M7, Cons<M7_2_Alg, 2, Cons<M7_2_3_Alg, 3, Cons<M7_2_3_4_Alg, 4, Nil>>>>;
    impl Algebra<[[M7; 2]; 3], 4> for M7_2_3_4_Alg {
        // Placeholder constant. It was previously `todo!()`, which panics as
        // soon as the constant is evaluated. Irreducibility of x^4 - W has NOT
        // been checked, so this tower only exercises the type-level plumbing.
        const W: [[M7; 2]; 3] = [[M7(1), M7(1)], [M7(0), M7(0)], [M7(0), M7(0)]];
    }

    #[test]
    fn same_size_add_is_coefficientwise() {
        let a: M7_2 = Ext([M7(1), M7(2)]);
        let b: M7_2 = Ext([M7(3), M7(4)]);
        assert_eq!((a + b).0, [M7(4), M7(6)]);

        let c: M7_2_3 = Ext([[M7(1), M7(2)], [M7(3), M7(4)], [M7(5), M7(6)]]);
        let d: M7_2_3 = Ext([[M7(10), M7(20)], [M7(30), M7(40)], [M7(50), M7(60)]]);
        assert_eq!(
            (c + d).0,
            [[M7(11), M7(22)], [M7(33), M7(44)], [M7(55), M7(66)]]
        );
    }

    #[test]
    fn same_size_mul_reduces_by_w() {
        // (1 + 2y)(3 + 4y) = 3 + 10y + 8y^2 = (3 + 8*3) + 10y
        let a: M7_2 = Ext([M7(1), M7(2)]);
        let b: M7_2 = Ext([M7(3), M7(4)]);
        assert_eq!((a * b).0, [M7(27), M7(10)]);

        // x * x^2 = x^3 = 5
        let x: M7_3 = Ext([M7(0), M7(1), M7(0)]);
        let x2: M7_3 = Ext([M7(0), M7(0), M7(1)]);
        assert_eq!((x * x2).0, [M7(5), M7(0), M7(0)]);
    }

    #[test]
    fn two_level_same_mul() {
        // z * z^2 = z^3 = W = 2 + 3y
        let zero = [M7(0); 2];
        let one = [M7(1), M7(0)];
        let z: M7_2_3 = Ext([zero, one, zero]);
        let z2: M7_2_3 = Ext([zero, zero, one]);
        assert_eq!((z * z2).0, [[M7(2), M7(3)], zero, zero]);
    }

    #[test]
    fn subfield_mul_scales_every_coefficient() {
        let a: M7_2_3 = Ext([[M7(1), M7(2)], [M7(3), M7(4)], [M7(5), M7(6)]]);
        let two: M7_2 = Ext([M7(2), M7(0)]);
        assert_eq!(
            (a * two).0,
            [[M7(2), M7(4)], [M7(6), M7(8)], [M7(10), M7(12)]]
        );

        // Multiplying each coefficient (p + q y) by y gives 3q + p y.
        let a: M7_2_3 = Ext([[M7(1), M7(2)], [M7(3), M7(4)], [M7(5), M7(6)]]);
        let y: M7_2 = Ext([M7(0), M7(1)]);
        assert_eq!(
            (a * y).0,
            [[M7(6), M7(1)], [M7(12), M7(3)], [M7(18), M7(5)]]
        );
    }

    #[test]
    fn three_level_tower() {
        let zero2 = [M7(0); 2];
        let zero3 = [zero2; 3];
        let one3 = [[M7(1), M7(0)], zero2, zero2];

        // Same size: 1 * 1 = 1.
        let one: M7_2_3_4 = Ext([one3, zero3, zero3, zero3]);
        let one_again: M7_2_3_4 = Ext([one3, zero3, zero3, zero3]);
        assert_eq!((one * one_again).0, [one3, zero3, zero3, zero3]);

        // Scaling by an M7_2 element reaches every base coefficient.
        let ones: M7_2_3_4 = Ext([[[M7(1), M7(0)]; 3]; 4]);
        let two: M7_2 = Ext([M7(2), M7(0)]);
        assert_eq!((ones * two).0, [[[M7(2), M7(0)]; 3]; 4]);

        // Scaling by the identity of M7_2_3 changes nothing.
        let v: M7_2_3_4 = Ext([[[M7(1), M7(2)]; 3]; 4]);
        let unit: M7_2_3 = Ext(one3);
        assert_eq!((v * unit).0, [[[M7(1), M7(2)]; 3]; 4]);
    }

    // These combinations must NOT compile, because the right-hand tower is not
    // a bottom-up prefix of the left-hand one:
    //
    //     m7_2 * m7_3;
    //     m7_3 * m7_2;
    //     m7_2_3 * m7_3;
    //     m7_2_3_4 * m7_3;
    //     m7_2 * m7_2_3;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment