Created November 7, 2018 15:35
-
-
Save mbillingr/f8405a224e3df21d4738016917d460fd to your computer and use it in GitHub Desktop.
Symbolic Math with Rust's type system
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::fmt::{self, Debug};
use std::ops::{Add, Mul};
trait Expression: Sized + Clone { | |
fn add<Other: Expression>(self, rhs: Other) -> Addition<Self, Other> { | |
Addition { | |
lhs: self, | |
rhs, | |
} | |
} | |
fn mul<Other: Expression>(self, rhs: Other) -> Multiplication<Self, Other> { | |
Multiplication { | |
lhs: self, | |
rhs, | |
} | |
} | |
} | |
trait Differentiable: Expression { | |
type Output: Expression; | |
fn diff(&self, x: Variable) -> Self::Output; | |
} | |
/// A numeric literal leaf inside an expression tree.
#[derive(Clone, Copy)]
struct Constant(f64);

/// Prints the wrapped number with `f64`'s own `Debug` output (e.g. `5.0`).
impl fmt::Debug for Constant {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Constant(value) = *self;
        write!(f, "{:?}", value)
    }
}
impl Expression for Constant { | |
} | |
impl Differentiable for Constant { | |
type Output = Constant; | |
fn diff(&self, x: Variable) -> Self::Output { | |
Constant(0.0) | |
} | |
} | |
/// A named symbolic variable, identified by its name string.
#[derive(Clone, Copy)]
struct Variable(&'static str);

/// Prints the bare variable name, without quotes.
impl fmt::Debug for Variable {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(self.0)
    }
}
impl Expression for Variable { | |
} | |
impl Differentiable for Variable { | |
type Output = Constant; | |
fn diff(&self, x: Variable) -> Self::Output { | |
if self.0 == x.0 { | |
Constant(1.0) | |
} else { | |
Constant(0.0) | |
} | |
} | |
} | |
/// Sum node `lhs + rhs`; the operand types are kept as type parameters.
#[derive(Clone)]
struct Addition<L, R> {
    lhs: L,
    rhs: R,
}

/// Renders the sum fully parenthesised, e.g. `(a + b)`.
impl<L, R> fmt::Debug for Addition<L, R>
where
    L: Debug,
    R: Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Addition { lhs, rhs } = self;
        write!(f, "({:?} + {:?})", lhs, rhs)
    }
}
impl<L: Expression, R: Expression> Expression for Addition<L, R> { | |
} | |
impl<L: Differentiable, R: Differentiable> Differentiable for Addition<L, R> { | |
type Output = Addition<L::Output, R::Output>; | |
fn diff(&self, x: Variable) -> Self::Output { | |
Addition { | |
lhs: self.lhs.diff(x), | |
rhs: self.rhs.diff(x), | |
} | |
} | |
} | |
/// Product node `lhs * rhs`; the operand types are kept as type parameters.
#[derive(Clone)]
struct Multiplication<L, R> {
    lhs: L,
    rhs: R,
}

/// Renders the product fully parenthesised, e.g. `(a * b)`.
impl<L, R> fmt::Debug for Multiplication<L, R>
where
    L: Debug,
    R: Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Multiplication { lhs, rhs } = self;
        write!(f, "({:?} * {:?})", lhs, rhs)
    }
}
impl<L: Expression, R: Expression> Expression for Multiplication<L, R> { | |
} | |
impl<L: Differentiable, R: Differentiable> Differentiable for Multiplication<L, R> { | |
type Output = Addition<Multiplication<L::Output, R>, Multiplication<L, R::Output>>; | |
fn diff(&self, x: Variable) -> Self::Output { | |
let lhs = self.lhs.diff(x).mul(self.rhs.clone()); | |
let rhs = self.lhs.clone().mul(self.rhs.diff(x)); | |
Addition { | |
lhs, rhs | |
} | |
} | |
} | |
impl<E: Expression> Add<E> for Variable { | |
type Output = Addition<Variable, E>; | |
fn add(self, rhs: E) -> Self::Output { | |
Expression::add(self, rhs) | |
} | |
} | |
impl<E: Expression> Mul<E> for Variable { | |
type Output = Multiplication<Variable, E>; | |
fn mul(self, rhs: E) -> Self::Output { | |
Expression::mul(self, rhs) | |
} | |
} | |
impl<E: Expression, A: Expression, B: Expression> Add<E> for Addition<A, B> { | |
type Output = Addition<Addition<A, B>, E>; | |
fn add(self, rhs: E) -> Self::Output { | |
Expression::add(self, rhs) | |
} | |
} | |
impl<E: Expression, A: Expression, B: Expression> Add<E> for Multiplication<A, B> { | |
type Output = Addition<Multiplication<A, B>, E>; | |
fn add(self, rhs: E) -> Self::Output { | |
Expression::add(self, rhs) | |
} | |
} | |
fn main() { | |
let x = Variable("x"); | |
let f = x * x + x * Constant(5.0); | |
println!("{:?}", f); | |
println!("{:?}", f.diff(x)); | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.