Created
June 3, 2025 07:01
-
-
Save kayceesrk/cd4d9802265799febc263f9c75e83276 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* Reverse-mode algorithmic differentiation using effect handlers.
   Adapted from https://twitter.com/tiarkrompf/status/963314799521222656.
   See https://openreview.net/forum?id=SJxJtYkPG for more information. *)
module F : sig
  (** A differentiable number: a (possibly staged) float value together
      with a mutable slot that accumulates its derivative. *)
  type t

  (** [mk c] wraps the code value [c] as a number whose derivative slot
      starts at zero. *)
  val mk : float code -> t

  (** Addition on [t]. Implemented by performing the [Add] effect, so it
      must be called under the handler installed by [grad] or [grad2]. *)
  val (+.) : t -> t -> t

  (** Multiplication on [t]. Implemented by performing the [Mult] effect,
      so it must be called under the handler installed by [grad] or
      [grad2]. *)
  val ( *. ) : t -> t -> t

  (** [grad f x] runs [f] on the input built from [x] and returns code
      for the derivative accumulated on that input. *)
  val grad : (t -> t) -> float code -> float code

  (** [grad2 f (x, y)] is the two-input variant of [grad]: it returns code
      for the pair of derivatives accumulated on [x] and [y]. *)
  val grad2 : (t * t -> t) -> float code * float code -> (float * float) code
end = struct
  open Effect
  open Effect.Deep

  (* A value that is either known now ([Sta]) or only available as
     generated code ([Dyn]). *)
  type 'a sd = Sta of 'a | Dyn of 'a code

  (* Turn a static-or-dynamic float into code, lifting static values. *)
  let dyn : float sd -> float code = function
    | Sta x -> .< x >.
    | Dyn x -> x

  (* [v] is the value; [d] is the derivative accumulated by the handler. *)
  type t = { v : float sd; mutable d : float sd }

  (* Addition on [float sd]: drops a static zero operand, folds two static
     operands, and otherwise emits code for the sum. *)
  let ( +@ ) x y = match x, y with
    | Sta 0.0, d
    | d, Sta 0.0 -> d
    | Sta x, Sta y -> Sta (x +. y)
    | x, y -> Dyn .< .~(dyn x) +. .~(dyn y) >.

  (* Multiplication on [float sd]: a static zero annihilates, a static one
     is dropped, two static operands are folded, and otherwise code for
     the product is emitted. *)
  let ( *@ ) x y = match x, y with
    | Sta 0.0, _
    | _, Sta 0.0 -> Sta 0.0
    | Sta 1.0, d
    | d, Sta 1.0 -> d
    (* Fold two constants with a product; this previously added them. *)
    | Sta x, Sta y -> Sta (x *. y)
    | x, y -> Dyn .< .~(dyn x) *. .~(dyn y) >.

  let mk v = {v = Dyn v; d = Sta 0.0}

  type _ Effect.t += Add : t * t -> t Effect.t
                   | Mult : t * t -> t Effect.t

  (* [run f] evaluates [f ()] under a handler for [Add] and [Mult].
     Each effect first computes the result value and resumes the
     continuation with it; only once the continuation has returned are
     the derivatives of the operands updated from the derivative of the
     result. The final result of [f] gets derivative one. *)
  let run f =
    ignore (match f () with
      | r -> r.d <- Sta 1.0; r
      | effect (Add (a, b)), k ->
        let x = {v = a.v +@ b.v; d = Sta 0.0} in
        (* Run the rest of the computation before reading [x.d]. *)
        ignore (continue k x);
        a.d <- a.d +@ x.d;
        b.d <- b.d +@ x.d;
        x
      | effect (Mult (a, b)), k ->
        let x = {v = a.v *@ b.v; d = Sta 0.0} in
        (* Run the rest of the computation before reading [x.d]. *)
        ignore (continue k x);
        a.d <- a.d +@ (b.v *@ x.d);
        b.d <- b.d +@ (a.v *@ x.d);
        x)

  let grad f x =
    let x = mk x in
    run (fun () -> f x);
    dyn x.d

  let grad2 f (x, y) =
    let x, y = mk x, mk y in
    run (fun () -> f (x, y));
    .< .~(dyn x.d), .~(dyn y.d) >.

  (* Defined last so that the code above keeps using the float operators
     from the standard library. *)
  let (+.) a b = perform (Add (a, b))
  let ( *. ) a b = perform (Mult (a, b))
end;;
(* Example: f(x) = x + x^3, so
   df/dx = 1 + 3 * x^2. *)
let () =
  (* Generate the derivative code first, then print it. *)
  let derivative =
    .< fun x -> .~(F.grad (fun v -> F.(v +. v *. v *. v)) .<x>.) >.
  in
  Format.printf "d(x + x^3)/dx is %a\n" Codelib.print_code derivative
;;
(* Example: f(x) = x^2 + x^3, so
   df/dx = 2 * x + 3 * x^2. *)
let () =
  (* Generate the derivative code first, then print it. *)
  let derivative =
    .< fun x -> .~(F.grad (fun v -> F.(v *. v +. v *. v *. v)) .<x>.) >.
  in
  Format.printf "d(x^2 + x^3)/dx is %a\n" Codelib.print_code derivative
;;
(* Example: f(x, y) = x^2 * y^4, so
   df/dx = 2 * x * y^4
   df/dy = 4 * x^2 * y^3 *)
let () =
  (* The label names the same function for both partial derivatives; the
     second one previously read x^2 + y^4. *)
  Format.printf "d(x^2 * y^4)/dx, d(x^2 * y^4)/dy are %a\n"
    Codelib.print_code
    .< fun x y ->
         .~(F.(grad2 (fun (x, y) -> x *. x *. y *. y *. y *. y)
                 (.<x>., .<y>.))) >.
;;
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment