Last active
January 24, 2023 16:58
-
-
Save jumbojets/856d68bb6f4379fe5a0d09fbbcbde27f to your computer and use it in GitHub Desktop.
tagless final, const generics
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// A tagless-final interpreter for a small matrix-expression language.
///
/// Every method is one constructor of the language, and an implementor decides
/// what an expression *is* through [`Interp::Repr`]. Matrix shapes are carried
/// as const generics (`M` rows by `N` columns), so operands whose declared
/// dimensions do not line up are rejected at compile time.
trait Interp {
    /// The interpretation of an expression denoting an `M x N` matrix.
    type Repr<const M: usize, const N: usize>;

    /// Lifts a literal matrix, given as `M` rows of `N` values each.
    fn lit<const M: usize, const N: usize>(mat: [[f32; N]; M]) -> Self::Repr<M, N>;

    /// Negation of `x`; the shape is unchanged.
    fn neg<const M: usize, const N: usize>(x: Self::Repr<M, N>) -> Self::Repr<M, N>;

    /// Reduces an `M x N` expression to a `1 x 1` average.
    fn avg<const M: usize, const N: usize>(x: Self::Repr<M, N>) -> Self::Repr<1, 1>;

    /// Sum of two expressions of identical shape.
    fn add<const M: usize, const N: usize>(
        l: Self::Repr<M, N>,
        r: Self::Repr<M, N>,
    ) -> Self::Repr<M, N>;

    /// Matrix product of an `M x N` and an `N x P` expression, giving `M x P`.
    fn matmul<const M: usize, const N: usize, const P: usize>(
        l: Self::Repr<M, N>,
        r: Self::Repr<N, P>,
    ) -> Self::Repr<M, P>;

    /// Restricts `x` to the range bounded by `min` and `max`; the shape is unchanged.
    fn clamp<const M: usize, const N: usize>(
        x: Self::Repr<M, N>,
        min: f32,
        max: f32,
    ) -> Self::Repr<M, N>;

    /// Raises `x` to the power `n`; the shape is unchanged.
    fn pow<const M: usize, const N: usize>(x: Self::Repr<M, N>, n: f32) -> Self::Repr<M, N>;
}
// Programs written in the abstract: each is generic over the interpreter `I`.
fn relu<I: Interp, const M: usize, const N: usize>(x: I::Repr<M, N>) -> I::Repr<M, N> { | |
I::clamp(x, 0., f32::INFINITY) | |
} | |
fn layer<I: Interp, const M: usize, const N: usize>( | |
a: I::Repr<M, N>, | |
b: I::Repr<M, 1>, | |
x: I::Repr<N, 1>, | |
) -> I::Repr<M, 1> { | |
relu::<I, M, 1>(I::add(I::matmul(a, x), b)) | |
} | |
fn rmse<I: Interp, const M: usize>(y: I::Repr<M, 1>, x: I::Repr<M, 1>) -> I::Repr<1, 1> { | |
I::pow(I::avg(I::pow(I::add(y, I::neg(x)), 2.)), 0.5) | |
} | |
/// Interpreter tag that evaluates expressions to concrete `f32` matrices.
struct Eval;

impl Eval {
    /// Applies `op` to every element of `x` and returns the transformed matrix.
    ///
    /// Elements are visited in row-major order.
    fn map<const M: usize, const N: usize>(
        x: [[f32; N]; M],
        op: impl Fn(f32) -> f32,
    ) -> [[f32; N]; M] {
        // `&op` is itself callable, so the same closure serves every row.
        x.map(|row| row.map(&op))
    }

    /// Combines `x` and `y` element by element with `op(x_elem, y_elem)`.
    ///
    /// Both operands share the shape `M x N`; elements are visited in
    /// row-major order and the result reuses `x`'s storage.
    fn map2<const M: usize, const N: usize>(
        mut x: [[f32; N]; M],
        y: [[f32; N]; M],
        op: impl Fn(f32, f32) -> f32,
    ) -> [[f32; N]; M] {
        for (x_row, y_row) in x.iter_mut().zip(&y) {
            for (x_elem, y_elem) in x_row.iter_mut().zip(y_row) {
                *x_elem = op(*x_elem, *y_elem);
            }
        }
        x
    }
}
impl Interp for Eval { | |
type Repr<const M: usize, const N: usize> = [[f32; N]; M]; | |
fn lit<const M: usize, const N: usize>(mat: [[f32; N]; M]) -> Self::Repr<M, N> { | |
mat | |
} | |
fn neg<const M: usize, const N: usize>(x: Self::Repr<M, N>) -> Self::Repr<M, N> { | |
Self::map(x, |e| -e) | |
} | |
fn avg<const M: usize, const N: usize>(x: Self::Repr<M, N>) -> Self::Repr<1, 1> { | |
[[x.iter().map(|row| row.iter().sum::<f32>()).sum::<f32>() / (M * N) as f32]] | |
} | |
fn add<const M: usize, const N: usize>( | |
x: Self::Repr<M, N>, | |
y: Self::Repr<M, N>, | |
) -> Self::Repr<M, N> { | |
Self::map2(x, y, |xe, ye| xe + ye) | |
} | |
fn matmul<const M: usize, const N: usize, const P: usize>( | |
x: Self::Repr<M, N>, | |
y: Self::Repr<N, P>, | |
) -> Self::Repr<M, P> { | |
let mut o = [[0.; P]; M]; | |
for i in 0..P { | |
for j in 0..M { | |
for k in 0..N { | |
o[j][i] += x[j][k] * y[k][i]; | |
} | |
} | |
} | |
o | |
} | |
fn clamp<const M: usize, const N: usize>( | |
x: Self::Repr<M, N>, | |
min: f32, | |
max: f32, | |
) -> Self::Repr<M, N> { | |
Self::map(x, |e| e.max(min).min(max)) | |
} | |
fn pow<const M: usize, const N: usize>(x: Self::Repr<M, N>, exp: f32) -> Self::Repr<M, N> { | |
Self::map(x, |e| f32::powf(e, exp)) | |
} | |
} | |
struct Display; | |
impl Interp for Display { | |
type Repr<const M: usize, const N: usize> = String; | |
fn lit<const M: usize, const N: usize>(_: [[f32; N]; M]) -> Self::Repr<M, N> { | |
format!("[{M}x{N}]") | |
} | |
fn neg<const M: usize, const N: usize>(x: Self::Repr<M, N>) -> Self::Repr<M, N> { | |
format!("(-{x})") | |
} | |
fn avg<const M: usize, const N: usize>(x: Self::Repr<M, N>) -> Self::Repr<1, 1> { | |
format!("avg({x})") | |
} | |
fn add<const M: usize, const N: usize>( | |
l: Self::Repr<M, N>, | |
r: Self::Repr<M, N>, | |
) -> Self::Repr<M, N> { | |
format!("({l} + {r})") | |
} | |
fn matmul<const M: usize, const N: usize, const P: usize>( | |
l: Self::Repr<M, N>, | |
r: Self::Repr<N, P>, | |
) -> Self::Repr<M, P> { | |
format!("({l} @ {r})") | |
} | |
fn clamp<const M: usize, const N: usize>( | |
x: Self::Repr<M, N>, | |
min: f32, | |
max: f32, | |
) -> Self::Repr<M, N> { | |
format!("clamp({x}, {min}, {max})") | |
} | |
fn pow<const M: usize, const N: usize>(x: Self::Repr<M, N>, n: f32) -> Self::Repr<M, N> { | |
format!("({x} ^ {n})") | |
} | |
} | |
fn expr<I: Interp>() -> I::Repr<1, 1> { | |
let x = I::lit([[0.9], [0.3], [0.7]]); | |
let a = I::lit([[0.9, 0.5, 0.4], [0.3, 0.5, 0.01], [0.7, 0.9, 0.2]]); | |
let b = I::lit([[0.2], [0.4], [0.3]]); | |
let o1 = layer::<I, 3, 3>(a, b, x); | |
let a = I::lit([[0.9, 0.5, 0.4], [0.3, 0.5, 0.01], [0.7, 0.9, 0.2]]); | |
let b = I::lit([[0.2], [0.4], [0.3]]); | |
let o2 = layer::<I, 3, 3>(a, b, o1); | |
let y = I::lit([[0.5], [0.2], [0.9]]); | |
rmse::<I, 3>(o2, y) | |
} | |
fn main() { | |
dbg!(expr::<Eval>()); | |
dbg!(expr::<Display>()); | |
// TODO: some backwards interpretation of the expression?! | |
// dbg!(expr::<Backwards<Eval>>()); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
https://github.com/coreylowman/dfdx