Skip to content

Instantly share code, notes, and snippets.

@jumbojets
Last active January 24, 2023 16:58
Show Gist options
  • Save jumbojets/856d68bb6f4379fe5a0d09fbbcbde27f to your computer and use it in GitHub Desktop.
Save jumbojets/856d68bb6f4379fe5a0d09fbbcbde27f to your computer and use it in GitHub Desktop.
tagless final, const generics
/// A tagless-final interpreter for a tiny language of `f32` matrix expressions.
///
/// Each backend picks its own representation `Repr<M, N>` of an `M`×`N` matrix term.
/// Because the dimensions are const generics, shape mismatches are rejected at compile time
/// (e.g. adding a 3×1 to a 1×3, or a `matmul` whose inner dimensions disagree).
trait Interp {
    /// This interpreter's representation of an `M`×`N` matrix expression.
    type Repr<const M: usize, const N: usize>;

    /// Embeds a literal matrix given as `M` rows of `N` values each.
    fn lit<const M: usize, const N: usize>(mat: [[f32; N]; M]) -> Self::Repr<M, N>;

    // ---- element-wise unary operations (shape-preserving) ----

    /// Element-wise negation.
    fn neg<const M: usize, const N: usize>(x: Self::Repr<M, N>) -> Self::Repr<M, N>;

    /// Element-wise clamp of every entry into `[min, max]`.
    fn clamp<const M: usize, const N: usize>(
        x: Self::Repr<M, N>,
        min: f32,
        max: f32,
    ) -> Self::Repr<M, N>;

    /// Element-wise power: raises every entry to `exp`.
    fn pow<const M: usize, const N: usize>(x: Self::Repr<M, N>, exp: f32) -> Self::Repr<M, N>;

    // ---- reduction ----

    /// Averages all entries into a 1×1 result.
    fn avg<const M: usize, const N: usize>(x: Self::Repr<M, N>) -> Self::Repr<1, 1>;

    // ---- binary operations ----

    /// Element-wise sum of two matrices of the same shape.
    fn add<const M: usize, const N: usize>(
        lhs: Self::Repr<M, N>,
        rhs: Self::Repr<M, N>,
    ) -> Self::Repr<M, N>;

    /// Matrix product of an `M`×`N` and an `N`×`P` matrix, giving an `M`×`P` result.
    fn matmul<const M: usize, const N: usize, const P: usize>(
        lhs: Self::Repr<M, N>,
        rhs: Self::Repr<N, P>,
    ) -> Self::Repr<M, P>;
}
// Model code, written once against `Interp` so it runs under any interpreter.
/// Rectified linear unit, expressed as `clamp(x, 0, +inf)` so that any
/// interpreter providing `clamp` supports it.
fn relu<I: Interp, const M: usize, const N: usize>(x: I::Repr<M, N>) -> I::Repr<M, N> {
    let (floor, ceiling) = (0.0_f32, f32::INFINITY);
    I::clamp::<M, N>(x, floor, ceiling)
}
/// One dense layer, `relu(a @ x + b)`.
///
/// Maps an `N`×1 input `x` to an `M`×1 output using weights `a` (`M`×`N`)
/// and bias `b` (`M`×1).
fn layer<I: Interp, const M: usize, const N: usize>(
    a: I::Repr<M, N>,
    b: I::Repr<M, 1>,
    x: I::Repr<N, 1>,
) -> I::Repr<M, 1> {
    let weighted = I::matmul::<M, N, 1>(a, x);
    let pre_activation = I::add::<M, 1>(weighted, b);
    relu::<I, M, 1>(pre_activation)
}
/// Root-mean-square error between two `M`×1 column vectors.
///
/// Computes `(avg((y - x)^2))^0.5` from the `add`/`neg`/`pow`/`avg` primitives
/// and returns it as a 1×1 result.
fn rmse<I: Interp, const M: usize>(y: I::Repr<M, 1>, x: I::Repr<M, 1>) -> I::Repr<1, 1> {
    // Subtraction is not a primitive, so y - x is built as y + (-x).
    let residual = I::add::<M, 1>(y, I::neg::<M, 1>(x));
    let squared = I::pow::<M, 1>(residual, 2.0);
    let mean = I::avg::<M, 1>(squared);
    I::pow::<1, 1>(mean, 0.5)
}
/// Reference interpreter that evaluates expressions eagerly on fixed-size `f32` arrays.
#[derive(Debug, Clone, Copy, Default)]
struct Eval;

impl Eval {
    /// Applies `op` to every element of `x` in row-major order.
    ///
    /// `x` is taken by value, updated in place and returned.
    fn map<const M: usize, const N: usize>(
        mut x: [[f32; N]; M],
        op: impl Fn(f32) -> f32,
    ) -> [[f32; N]; M] {
        // Walking rows and then their elements via iterators avoids the per-access
        // bounds checks of manual `x[i][j]` indexing (clippy::needless_range_loop).
        for elem in x.iter_mut().flatten() {
            *elem = op(*elem);
        }
        x
    }

    /// Combines `x` and `y` element-wise as `op(x_ij, y_ij)`, in row-major order.
    ///
    /// The results are written into the by-value `x`, which is returned.
    fn map2<const M: usize, const N: usize>(
        mut x: [[f32; N]; M],
        y: [[f32; N]; M],
        op: impl Fn(f32, f32) -> f32,
    ) -> [[f32; N]; M] {
        // Both arrays have the same const shape, so the flattened iterators have
        // equal length and `zip` pairs every element exactly once.
        for (xe, ye) in x.iter_mut().flatten().zip(y.iter().flatten()) {
            *xe = op(*xe, *ye);
        }
        x
    }
}
/// Evaluates every operation immediately on `[[f32; N]; M]` arrays (`M` rows of `N`).
impl Interp for Eval {
    type Repr<const M: usize, const N: usize> = [[f32; N]; M];

    /// Literals already are arrays, so they pass through unchanged.
    fn lit<const M: usize, const N: usize>(mat: [[f32; N]; M]) -> Self::Repr<M, N> {
        mat
    }

    fn neg<const M: usize, const N: usize>(x: Self::Repr<M, N>) -> Self::Repr<M, N> {
        Self::map(x, |value| -value)
    }

    /// Mean of all `M * N` entries: each row is summed first, then the row totals.
    /// An empty matrix divides by zero and yields NaN.
    fn avg<const M: usize, const N: usize>(x: Self::Repr<M, N>) -> Self::Repr<1, 1> {
        let total: f32 = x.iter().map(|row| row.iter().sum::<f32>()).sum();
        let count = (M * N) as f32;
        [[total / count]]
    }

    fn add<const M: usize, const N: usize>(
        x: Self::Repr<M, N>,
        y: Self::Repr<M, N>,
    ) -> Self::Repr<M, N> {
        Self::map2(x, y, |lhs, rhs| lhs + rhs)
    }

    /// Naive triple-loop product.
    ///
    /// Every output cell starts at 0.0 and accumulates `x[row][k] * y[k][col]` for
    /// `k = 0..N` in ascending order.
    fn matmul<const M: usize, const N: usize, const P: usize>(
        x: Self::Repr<M, N>,
        y: Self::Repr<N, P>,
    ) -> Self::Repr<M, P> {
        let mut product = [[0.0; P]; M];
        for (out_row, x_row) in product.iter_mut().zip(x.iter()) {
            // Pair x[row][k] with row k of y, then spread it across the output columns.
            for (&x_elem, y_row) in x_row.iter().zip(y.iter()) {
                for (acc, &y_elem) in out_row.iter_mut().zip(y_row.iter()) {
                    *acc += x_elem * y_elem;
                }
            }
        }
        product
    }

    /// Element-wise clamp written as `max(min).min(max)` rather than `f32::clamp`.
    ///
    /// This form never panics. If `min > max`, every element becomes `max`. With
    /// non-NaN bounds and `min <= max`, a NaN element comes out as `min` rather than NaN.
    fn clamp<const M: usize, const N: usize>(
        x: Self::Repr<M, N>,
        min: f32,
        max: f32,
    ) -> Self::Repr<M, N> {
        Self::map(x, |value| value.max(min).min(max))
    }

    fn pow<const M: usize, const N: usize>(x: Self::Repr<M, N>, exp: f32) -> Self::Repr<M, N> {
        Self::map(x, |base| base.powf(exp))
    }
}
/// Pretty-printing interpreter: renders an expression as a fully parenthesised string.
///
/// Literals print only their shape (e.g. `[3x1]`), not their values.
struct Display;

impl Display {
    /// Renders a binary operation as `(lhs op rhs)`.
    fn infix(lhs: &str, op: char, rhs: &str) -> String {
        format!("({lhs} {op} {rhs})")
    }
}

impl Interp for Display {
    type Repr<const M: usize, const N: usize> = String;

    fn lit<const M: usize, const N: usize>(_values: [[f32; N]; M]) -> Self::Repr<M, N> {
        format!("[{M}x{N}]")
    }

    fn neg<const M: usize, const N: usize>(x: Self::Repr<M, N>) -> Self::Repr<M, N> {
        format!("(-{x})")
    }

    fn avg<const M: usize, const N: usize>(x: Self::Repr<M, N>) -> Self::Repr<1, 1> {
        format!("avg({x})")
    }

    fn add<const M: usize, const N: usize>(
        l: Self::Repr<M, N>,
        r: Self::Repr<M, N>,
    ) -> Self::Repr<M, N> {
        Self::infix(&l, '+', &r)
    }

    fn matmul<const M: usize, const N: usize, const P: usize>(
        l: Self::Repr<M, N>,
        r: Self::Repr<N, P>,
    ) -> Self::Repr<M, P> {
        Self::infix(&l, '@', &r)
    }

    fn clamp<const M: usize, const N: usize>(
        x: Self::Repr<M, N>,
        min: f32,
        max: f32,
    ) -> Self::Repr<M, N> {
        format!("clamp({x}, {min}, {max})")
    }

    fn pow<const M: usize, const N: usize>(x: Self::Repr<M, N>, n: f32) -> Self::Repr<M, N> {
        // `to_string` uses the same `Display` formatting as `{n}` in a format string.
        Self::infix(&x, '^', &n.to_string())
    }
}
/// Example program: two dense layers sharing one set of weights and bias, scored
/// with `rmse` against a fixed target. It is written once and runs under any `Interp`.
fn expr<I: Interp>() -> I::Repr<1, 1> {
    const INPUT: [[f32; 1]; 3] = [[0.9], [0.3], [0.7]];
    const WEIGHTS: [[f32; 3]; 3] = [[0.9, 0.5, 0.4], [0.3, 0.5, 0.01], [0.7, 0.9, 0.2]];
    const BIAS: [[f32; 1]; 3] = [[0.2], [0.4], [0.3]];
    const TARGET: [[f32; 1]; 3] = [[0.5], [0.2], [0.9]];

    // `Repr` is not required to be `Clone`, so each layer gets its own literal
    // of the shared weights and bias.
    let input = I::lit(INPUT);
    let hidden = layer::<I, 3, 3>(I::lit(WEIGHTS), I::lit(BIAS), input);
    let output = layer::<I, 3, 3>(I::lit(WEIGHTS), I::lit(BIAS), hidden);
    rmse::<I, 3>(output, I::lit(TARGET))
}
/// Runs the same `expr` program under both interpreters.
///
/// `Eval` produces the numeric loss as a 1×1 array; `Display` produces the
/// parenthesised expression string.
fn main() {
    dbg!(expr::<Eval>());
    dbg!(expr::<Display>());
    // TODO: add a "backwards" interpreter (presumably reverse-mode differentiation,
    // to get gradients of the same program), used e.g. as:
    // dbg!(expr::<Backwards<Eval>>());
}
@jumbojets
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment