Skip to content

Instantly share code, notes, and snippets.

@kayceesrk
Created April 22, 2025 06:47
Show Gist options
  • Save kayceesrk/0468a072c2e585353da6832d2e2173b3 to your computer and use it in GitHub Desktop.
Save kayceesrk/0468a072c2e585353da6832d2e2173b3 to your computer and use it in GitHub Desktop.
(* Reverse-mode algorithmic differentiation using effect handlers, staged with
   MetaOCaml so that the derivative is generated as code.
   Adapted from https://twitter.com/tiarkrompf/status/963314799521222656.
   See https://openreview.net/forum?id=SJxJtYkPG for more information. *)
module F : sig
  (** A node of the differentiated computation: a primal value together with
      its adjoint (the derivative of the final result with respect to it). *)
  type t

  (** [mk v] makes an input variable whose primal value is the code [v] and
      whose adjoint starts at zero. *)
  val mk : float code -> t

  (** Addition of nodes. Performs the [Add] effect, so it must only be used
      inside the function passed to [grad] / [grad2]. *)
  val (+.) : t -> t -> t

  (** Multiplication of nodes. Performs the [Mult] effect, so it must only be
      used inside the function passed to [grad] / [grad2]. *)
  val ( *. ) : t -> t -> t

  (** [grad f x] generates code computing the derivative of [f] at [x]. *)
  val grad : (t -> t) -> float code -> float code

  (** [grad2 f (x, y)] generates code computing the pair of partial
      derivatives [(df/dx, df/dy)] of [f] at [(x, y)]. *)
  val grad2 : (t * t -> t) -> float code * float code -> (float * float) code
end = struct
  (* Partially-static value: [Sta] is known at code-generation time, [Dyn] is
     only known when the generated code runs. *)
  type 'a sd = Sta of 'a | Dyn of 'a code

  (* Residualize a partially-static float into code. *)
  let dyn : float sd -> float code = function
    | Sta x -> .< x >.
    | Dyn x -> x

  (* [v] is the primal value; [d] is the adjoint, accumulated in place during
     the backward pass. *)
  type t = { v : float sd; mutable d : float sd }

  (* Smart addition: adding a static zero is dropped, two static operands are
     folded now, anything else is emitted as code. *)
  let ( +@ ) x y =
    match x, y with
    | Sta 0.0, d | d, Sta 0.0 -> d
    | Sta x, Sta y -> Sta (x +. y)
    | x, y -> Dyn .< .~(dyn x) +. .~(dyn y) >.

  (* Smart multiplication: a static zero annihilates, a static one is the
     identity, two static operands are folded now, anything else is emitted
     as code. *)
  let ( *@ ) x y =
    match x, y with
    | Sta 0.0, _ | _, Sta 0.0 -> Sta 0.0
    | Sta 1.0, d | d, Sta 1.0 -> d
    (* Both operands are known: fold their product (not their sum). *)
    | Sta x, Sta y -> Sta (x *. y)
    | x, y -> Dyn .< .~(dyn x) *. .~(dyn y) >.

  (* Input nodes start with a zero adjoint. *)
  let mk v = { v = Dyn v; d = Sta 0.0 }

  effect Add : t * t -> t
  effect Mult : t * t -> t

  (* Run [f] under a handler that implements reverse mode. Each operation
     builds its result node, resumes the rest of the program (the forward
     pass), and, once the continuation returns, adds the result's adjoint
     contribution into its operands (the backward pass). *)
  let run f =
    ignore
      (match f () with
       | r ->
         (* Seed the backward pass: d(result)/d(result) = 1. *)
         r.d <- Sta 1.0;
         r
       | effect (Add (a, b)) k ->
         let x = { v = a.v +@ b.v; d = Sta 0.0 } in
         ignore (continue k x);
         (* d(a + b)/da = d(a + b)/db = 1 *)
         a.d <- a.d +@ x.d;
         b.d <- b.d +@ x.d;
         x
       | effect (Mult (a, b)) k ->
         let x = { v = a.v *@ b.v; d = Sta 0.0 } in
         ignore (continue k x);
         (* d(a * b)/da = b and d(a * b)/db = a *)
         a.d <- a.d +@ (b.v *@ x.d);
         b.d <- b.d +@ (a.v *@ x.d);
         x)

  let grad f x =
    let x = mk x in
    run (fun () -> f x);
    dyn x.d

  let grad2 f (x, y) =
    let x, y = mk x, mk y in
    run (fun () -> f (x, y));
    .< .~(dyn x.d), .~(dyn y.d) >.

  (* Defined last on purpose: everything above (including code inside
     brackets) must still see the float [+.] and [*.] operators. *)
  let (+.) a b = perform (Add (a, b))
  let ( *. ) a b = perform (Mult (a, b))
end;;
(* Example: f(x) = x + x^3, whose derivative is df/dx = 1 + 3 * x^2.
   Generates the derivative as a function of x and pretty-prints it. *)
let () =
  let generated =
    .< fun x -> .~(F.grad (fun x -> F.(x +. ((x *. x) *. x))) .<x>.) >.
  in
  Format.printf "d(x + x^3)/dx is %a\n" Print_code.print_code generated
;;
(* Example: f(x) = x^2 + x^3, whose derivative is df/dx = 2 * x + 3 * x^2.
   Generates the derivative as a function of x and pretty-prints it. *)
let () =
  let generated =
    .< fun x ->
       .~(F.grad (fun x -> F.((x *. x) +. ((x *. x) *. x))) .<x>.) >.
  in
  Format.printf "d(x^2 + x^3)/dx is %a\n" Print_code.print_code generated
;;
(* Example with two variables: f(x, y) = x^2 * y^4, so
   df/dx = 2 * x * y^4
   df/dy = 4 * x^2 * y^3
   Generates both partial derivatives as a function of x and y and
   pretty-prints the resulting pair. *)
let () =
  (* The label names the same function for both partials (x^2 * y^4). *)
  Format.printf "d(x^2 * y^4)/dx, d(x^2 * y^4)/dy are %a\n"
    Print_code.print_code
    .< fun x y ->
       .~(F.(grad2 (fun (x, y) -> x *. x *. y *. y *. y *. y) (.<x>., .<y>.))) >.
;;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment