Skip to content

Instantly share code, notes, and snippets.

@jO-Osko
Last active May 5, 2020 10:52
Show Gist options
  • Save jO-Osko/bd51057145c54b1a41fcabfcd9cc8c6b to your computer and use it in GitHub Desktop.
Save jO-Osko/bd51057145c54b1a41fcabfcd9cc8c6b to your computer and use it in GitHub Desktop.
(* Fix the PRNG seed so any draw from [Random] is reproducible across runs.
   NOTE(review): no code visible in this file draws from [Random]. *)
let () = Random.init 2020

(** Numeric type used for probability weights. *)
type num = float

(** Type of the values the distributions range over (a coin-flip outcome). *)
type distr_t = bool

(** Weight constants shared by both representations below. *)
let one: num = 1.0
let half: num = 0.5
let zero: num = 0.0
module DiscreteTerm = struct
  (** A weighted binary decision tree over values of type ['a].
      [ReturnT (w, x)] is a leaf holding value [x] with weight [w].
      [FlipT (k_f, k_t)] branches on a coin flip: [k_f] is the branch for a
      [false] outcome and [k_t] the branch for a [true] outcome (the
      convention [r_hand] relies on). *)
  type 'a term
    = ReturnT of (num * 'a)
    | FlipT of ('a term * 'a term)

  type t = distr_t term

  (** [return x] is the leaf holding [x] with weight [one]. *)
  let return x = ReturnT (one, x)

  (** [flip ()] is a single coin flip whose outcome is returned unchanged.
      The false branch comes first, as required by [FlipT (k_f, k_t)]; the
      previous order [FlipT (return true, return false)] inverted the coin. *)
  let flip () : t = FlipT (return false, return true)

  (** [score r] is a leaf of weight [r]; its value is only a placeholder. *)
  let score r = ReturnT (r, false) (* Should be () *)

  (** [m >>= f] replaces every leaf [ReturnT (r, x)] of [m] by the tree
      [f x] with all of its leaf weights multiplied by [r]; the branching
      structure of [m] is kept. *)
  let rec (>>=) m f =
    (* Multiply every leaf weight of a tree by [s]. *)
    let rec scale s = function
      | ReturnT (r, x) -> ReturnT (s *. r, x)
      | FlipT (k_f, k_t) -> FlipT (scale s k_f, scale s k_t)
    in
    match m with
    | ReturnT (r, x) -> scale r (f x)
    | FlipT (k_f, k_t) -> FlipT (k_f >>= f, k_t >>= f)

  (* Effects through which a computation requests [score] / [flip] terms
     from an enclosing handler. *)
  effect Score: num -> t
  effect Flip: unit -> t
end
module DiscreteEnum = struct
  (** One weighted outcome: [Enum (w, x)] is the value [x] with weight [w]. *)
  type 'a enum = Enum of (num * 'a)

  (** A distribution written out as an explicit list of weighted outcomes. *)
  type t = (distr_t enum) list

  (** [return x] is the single outcome [x] carrying weight [one]. *)
  let return (x: distr_t) : t =
    [ Enum (one, x) ]

  (** A fair coin: [false] then [true], each with weight [half]. *)
  let flip () : t =
    [ Enum (half, false)
    ; Enum (half, true) ]

  (** A degenerate coin: all weight on [false], none on [true]. *)
  let flip_biased () : t =
    [ Enum (one, false)
    ; Enum (zero, true) ]

  (** [score r] is a single outcome of weight [r]; the value is a placeholder. *)
  let score r = [Enum (r, false)] (* should be () *)

  (** [xs >>= f] replaces every outcome [Enum (r, x)] of [xs] by the outcomes
      of [f x], each with its weight multiplied by [r], and concatenates the
      pieces in the order of [xs]. *)
  let (>>=) (xs: t) (f: distr_t -> t) : t =
    let reweight (w: num) (ys: t) : t =
      List.map (fun (Enum (s, y)) -> Enum (w *. s, y)) ys
    in
    (* Walk [xs] from its last outcome to its first: this builds the list in
       the original order and calls [f] (which may perform effects) in the
       same order a right fold would. *)
    List.fold_left
      (fun acc (Enum (w, v)) -> reweight w (f v) @ acc)
      []
      (List.rev xs)

  (* Effects through which a computation requests [score] / [flip]
     distributions from an enclosing handler. *)
  effect Score: num -> t
  effect Flip: unit -> t
end
(* Inference transformation *)
(*
  As a function:
    DiscreteTerm.t -> DiscreteEnum.t ! { Dt.Score, Dt.Flip, De.Score, De.Flip }
  As a handler:
    DiscreteTerm.t ! { | a } ==> DiscreteEnum.t ! { Dt.Score, Dt.Flip, De.Score, De.Flip | a }
  Here Dt abbreviates DiscreteTerm and De abbreviates DiscreteEnum.
*)
(** [r_hand tree] enumerates a decision tree into an explicit weighted list.
    Each [FlipT] performs [DiscreteEnum.Flip] and each leaf performs
    [DiscreteEnum.Score], so it expects an enclosing handler for both. *)
let rec r_hand : (DiscreteTerm.t -> DiscreteEnum.t) = function
  | DiscreteTerm.ReturnT (r, x) -> DiscreteEnum.(
      (* The effect must actually be performed, and the resulting
         distribution bound rather than discarded: binding is what multiplies
         the leaf weight [r] onto [x]. Merely building [Score r] and dropping
         it silently loses the weight. *)
      perform (Score r) >>= fun _ -> return x
    )
  | DiscreteTerm.FlipT (x_f, x_t) -> DiscreteEnum.(
      let ff = perform (Flip ()) in
      (* First component of [FlipT] is the false branch, second the true one. *)
      let fm (b: bool) : DiscreteEnum.t = if b then r_hand x_t else r_hand x_f in
      ff >>= fm
    )
(** Request one [DiscreteTerm] flip, enumerate it with [r_hand], and return
    the resulting [DiscreteEnum.t]. Every effect raised along the way is
    answered by the same-named function of the module that declares it. *)
let main () =
  match r_hand (perform (DiscreteTerm.Flip ())) with
  | result -> result
  | effect (DiscreteTerm.Flip ()) k -> continue k (DiscreteTerm.flip ())
  | effect (DiscreteTerm.Score r) k -> continue k (DiscreteTerm.score r)
  | effect (DiscreteEnum.Flip ()) k -> continue k (DiscreteEnum.flip ())
  (* | effect (DiscreteEnum.Flip ()) k -> continue k (DiscreteEnum.flip_biased ()) *)
  | effect (DiscreteEnum.Score r) k -> continue k (DiscreteEnum.score r)
;;
main ();;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment