Created
April 22, 2025 06:47
-
-
Save kayceesrk/0468a072c2e585353da6832d2e2173b3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* Reverse-mode algorithmic differentiation using effect handlers.
   Adapted from https://twitter.com/tiarkrompf/status/963314799521222656.
   See https://openreview.net/forum?id=SJxJtYkPG for more information. *)
(** Staged reverse-mode differentiation.

    Arithmetic on [F.t] is expressed as effects.  The handler in [run]
    computes the forward value when an effect is performed, resumes the rest
    of the computation, and only after that resumption returns does it push
    the adjoint of the result back into the adjoints of the operands.  Values
    and adjoints are partially static: constants known at generation time are
    folded, everything else becomes generated code. *)
module F : sig
  (** A differentiable value: a (possibly staged) float together with its
      accumulated adjoint. *)
  type t

  (** [mk c] wraps the code value [c] as a differentiable value whose adjoint
      starts at zero. *)
  val mk : float code -> t

  (** Addition; performs the [Add] effect, so it must be used under
      [grad] / [grad2]. *)
  val (+.) : t -> t -> t

  (** Multiplication; performs the [Mult] effect, so it must be used under
      [grad] / [grad2]. *)
  val ( *. ) : t -> t -> t

  (** [grad f x] runs [f] on a fresh value built from [x] and returns the code
      of the adjoint accumulated for that input. *)
  val grad : (t -> t) -> float code -> float code

  (** [grad2 f (x, y)] is the two-argument counterpart of [grad]; the result
      is code for the pair of adjoints of [x] and [y]. *)
  val grad2 : (t * t -> t) -> float code * float code -> (float * float) code
end = struct
  (* Partially-static data: known now ([Sta]) or only as generated code
     ([Dyn]). *)
  type 'a sd = Sta of 'a | Dyn of 'a code

  (* Forget static knowledge: turn a partially-static float into code. *)
  let dyn : float sd -> float code = function
    | Sta x -> .< x >.
    | Dyn x -> x

  (* [v] is the forward value, [d] the adjoint accumulated during the
     backward sweep. *)
  type t = { v : float sd; mutable d : float sd }

  (* Partially-static addition: drops a static zero operand, folds two
     static operands, otherwise emits code. *)
  let ( +@ ) x y = match x, y with
    | Sta 0.0, d
    | d, Sta 0.0 -> d
    | Sta x, Sta y -> Sta (x +. y)
    | x, y -> Dyn .< .~(dyn x) +. .~(dyn y) >.

  (* Partially-static multiplication: a static zero annihilates, a static one
     is dropped, two static operands are folded, otherwise emits code. *)
  let ( *@ ) x y = match x, y with
    | Sta 0.0, _
    | _, Sta 0.0 -> Sta 0.0
    | Sta 1.0, d
    | d, Sta 1.0 -> d
      (* Fixed: this case used to add the two constants instead of
         multiplying them. *)
    | Sta x, Sta y -> Sta (x *. y)
    | x, y -> Dyn .< .~(dyn x) *. .~(dyn y) >.

  let mk v = {v = Dyn v; d = Sta 0.0}

  effect Add : t * t -> t
  effect Mult : t * t -> t

  (* [run f] evaluates [f ()] under the differentiation handler.  The value
     case seeds the adjoint of the final result with 1.  Each effect case
     resumes the continuation first, so that [x.d] is complete by the time it
     is distributed to the operands. *)
  let run f =
    ignore (match f () with
      | r -> r.d <- Sta 1.0; r
      | effect (Add(a,b)) k ->
          let x = {v = a.v +@ b.v; d = Sta 0.0} in
          ignore (continue k x);
          (* d(a + b)/da = d(a + b)/db = 1 *)
          a.d <- a.d +@ x.d;
          b.d <- b.d +@ x.d;
          x
      | effect (Mult(a,b)) k ->
          let x = {v = a.v *@ b.v; d = Sta 0.0} in
          ignore (continue k x);
          (* d(a * b)/da = b and d(a * b)/db = a *)
          a.d <- a.d +@ (b.v *@ x.d);
          b.d <- b.d +@ (a.v *@ x.d);
          x)

  let grad f x =
    let x = mk x in
    run (fun () -> f x);
    dyn x.d

  let grad2 f (x, y) =
    let x,y = mk x, mk y in
    run (fun () -> f (x,y));
    .< .~(dyn x.d), .~(dyn y.d) >.

  (* Defined last on purpose: every [+.] and [*.] above refers to the
     ordinary float operators, not to these effect-performing ones. *)
  let (+.) a b = perform (Add(a,b))
  let ( *. ) a b = perform (Mult(a,b))
end;;
(* f(x) = x + x^3, so the expected derivative is
   df/dx = 1 + 3 * x^2. *)
let () =
  let open F in
  (* The function to differentiate, written with F's effectful operators. *)
  let id_plus_cube v = v +. v *. v *. v in
  (* Build the derivative code first, then print it. *)
  let derivative = .< fun x -> .~(grad id_plus_cube .<x>.) >. in
  Format.printf "d(x + x^3)/dx is %a\n" Print_code.print_code derivative
;;
(* f(x) = x^2 + x^3, so the expected derivative is
   df/dx = 2 * x + 3 * x^2. *)
let () =
  let open F in
  (* The function to differentiate, written with F's effectful operators. *)
  let square_plus_cube v = v *. v +. v *. v *. v in
  (* Build the derivative code first, then print it. *)
  let derivative = .< fun x -> .~(grad square_plus_cube .<x>.) >. in
  Format.printf "d(x^2 + x^3)/dx is %a\n" Print_code.print_code derivative
;;
(* f = x^2 * y^4 =>
   df/dx = 2 * x * y^4
   df/dy = 4 * x^2 * y^3 *)
let () =
  (* Both partial derivatives are of the same product x^2 * y^4; the label
     for the second one previously read "x^2 + y^4". *)
  Format.printf "d(x^2 * y^4)/dx, d(x^2 * y^4)/dy are %a\n"
    Print_code.print_code
    .< fun x y ->
       .~(F.(grad2 (fun (x,y) -> x *. x *. y *. y *. y *. y) (.<x>.,.<y>.))) >.
;;
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment