Skip to content

Instantly share code, notes, and snippets.

@mbillingr
Created November 7, 2018 15:35
Show Gist options
  • Save mbillingr/f8405a224e3df21d4738016917d460fd to your computer and use it in GitHub Desktop.
Save mbillingr/f8405a224e3df21d4738016917d460fd to your computer and use it in GitHub Desktop.
Symbolic Math with Rust's type system
use std::ops::{Add, Mul};
use std::fmt::{self, Debug};
/// A node of a symbolic expression tree.
///
/// Nodes are plain owned values (`Sized + Clone`), so whole trees can be
/// moved, cloned and nested inside one another. The provided methods only
/// build new tree nodes; nothing is evaluated.
///
/// Because these methods share their names with `Add::add` / `Mul::mul`,
/// the operator impls in this file call them in the fully qualified form
/// `Expression::add(a, b)`.
trait Expression: Sized + Clone {
    /// Wraps `self` and `rhs` into a sum node `(self + rhs)`.
    fn add<Rhs: Expression>(self, rhs: Rhs) -> Addition<Self, Rhs> {
        let lhs = self;
        Addition { lhs, rhs }
    }

    /// Wraps `self` and `rhs` into a product node `(self * rhs)`.
    fn mul<Rhs: Expression>(self, rhs: Rhs) -> Multiplication<Self, Rhs> {
        let lhs = self;
        Multiplication { lhs, rhs }
    }
}
/// An expression that can be differentiated symbolically.
///
/// The derivative's tree shape is the associated type `Output`, so the
/// structure of every derivative is fixed at compile time.
trait Differentiable: Expression {
    /// Static type of the expression tree returned by `diff`.
    type Output: Expression;

    /// Returns the derivative of `self` with respect to the variable `var`.
    fn diff(&self, var: Variable) -> <Self as Differentiable>::Output;
}
/// A numeric literal leaf in an expression tree.
#[derive(Copy, Clone)]
struct Constant(f64);

impl fmt::Debug for Constant {
    /// Prints only the wrapped number (e.g. `5.0`) rather than `Constant(5.0)`,
    /// so printed trees read like ordinary arithmetic.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{:?}", self.0)
    }
}

impl Expression for Constant {}

impl Differentiable for Constant {
    type Output = Constant;

    /// The derivative of a constant is zero with respect to any variable.
    // The variable is irrelevant for a constant; the leading underscore marks
    // it as intentionally unused and avoids the `unused_variables` warning.
    fn diff(&self, _x: Variable) -> Self::Output {
        Constant(0.0)
    }
}
/// A named symbolic variable leaf, e.g. `Variable("x")`.
///
/// Two variables are treated as the same variable exactly when their names
/// compare equal.
#[derive(Copy, Clone)]
struct Variable(&'static str);

impl fmt::Debug for Variable {
    /// Prints the bare name (`x`) rather than `Variable("x")`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(self.0)
    }
}

impl Expression for Variable {}

impl Differentiable for Variable {
    type Output = Constant;

    /// d(v)/d(wrt) is `1.0` when `v` and `wrt` carry the same name, else `0.0`.
    fn diff(&self, wrt: Variable) -> Self::Output {
        let same_variable = self.0 == wrt.0;
        Constant(if same_variable { 1.0 } else { 0.0 })
    }
}
/// Sum node `lhs + rhs` of an expression tree.
#[derive(Clone)]
struct Addition<L, R> {
    lhs: L,
    rhs: R,
}

impl<L: Debug, R: Debug> fmt::Debug for Addition<L, R> {
    /// Prints the sum fully parenthesised, e.g. `(x + 5.0)`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "({left:?} + {right:?})", left = self.lhs, right = self.rhs)
    }
}

impl<L: Expression, R: Expression> Expression for Addition<L, R> {}

impl<L: Differentiable, R: Differentiable> Differentiable for Addition<L, R> {
    type Output = Addition<L::Output, R::Output>;

    /// Sum rule: (f + g)' = f' + g'. No simplification is performed.
    fn diff(&self, var: Variable) -> Self::Output {
        let Addition { lhs: f, rhs: g } = self;
        Addition {
            lhs: f.diff(var),
            rhs: g.diff(var),
        }
    }
}
/// Product node `lhs * rhs` of an expression tree.
#[derive(Clone)]
struct Multiplication<L, R> {
    lhs: L,
    rhs: R,
}

impl<L: Debug, R: Debug> fmt::Debug for Multiplication<L, R> {
    /// Prints the product fully parenthesised, e.g. `(x * 5.0)`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "({left:?} * {right:?})", left = self.lhs, right = self.rhs)
    }
}

impl<L: Expression, R: Expression> Expression for Multiplication<L, R> {}

impl<L: Differentiable, R: Differentiable> Differentiable for Multiplication<L, R> {
    type Output = Addition<Multiplication<L::Output, R>, Multiplication<L, R::Output>>;

    /// Product rule: (f * g)' = f' * g + f * g'. No simplification is performed;
    /// both factors are cloned because each appears in two terms of the result.
    fn diff(&self, var: Variable) -> Self::Output {
        let Multiplication { lhs: f, rhs: g } = self;
        let df_times_g = Multiplication { lhs: f.diff(var), rhs: g.clone() };
        let f_times_dg = Multiplication { lhs: f.clone(), rhs: g.diff(var) };
        Addition { lhs: df_times_g, rhs: f_times_dg }
    }
}
impl<E: Expression> Add<E> for Variable {
type Output = Addition<Variable, E>;
fn add(self, rhs: E) -> Self::Output {
Expression::add(self, rhs)
}
}
impl<E: Expression> Mul<E> for Variable {
type Output = Multiplication<Variable, E>;
fn mul(self, rhs: E) -> Self::Output {
Expression::mul(self, rhs)
}
}
impl<E: Expression, A: Expression, B: Expression> Add<E> for Addition<A, B> {
type Output = Addition<Addition<A, B>, E>;
fn add(self, rhs: E) -> Self::Output {
Expression::add(self, rhs)
}
}
impl<E: Expression, A: Expression, B: Expression> Add<E> for Multiplication<A, B> {
type Output = Addition<Multiplication<A, B>, E>;
fn add(self, rhs: E) -> Self::Output {
Expression::add(self, rhs)
}
}
/// Builds f(x) = x·x + x·5 as a statically typed expression tree, then prints
/// it and its (unsimplified) derivative with respect to `x`.
fn main() {
    let x = Variable("x");
    let f = x * x + x * Constant(5.0);
    // Prints: ((x * x) + (x * 5.0))
    println!("{:?}", f);
    let df_dx = f.diff(x);
    // Prints: (((1.0 * x) + (x * 1.0)) + ((1.0 * 5.0) + (x * 0.0)))
    println!("{:?}", df_dx);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment