Last active
February 4, 2020 00:26
-
-
Save NicolasT/5623368 to your computer and use it in GitHub Desktop.
RWST for state machines in OCaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
open Ocamlbuild_plugin;; | |
open Command;; | |
(* [ocamlfind_query pkg] shells out to [ocamlfind query <pkg>] (the package
   name is shell-quoted) and returns the first line the command prints. *)
let ocamlfind_query pkg =
  Ocamlbuild_pack.My_unix.run_and_open
    ("ocamlfind query " ^ Filename.quote pkg)
    input_line;;
(* After ocamlbuild's default rules are installed, make the [use_monad] tag
   add the pa_monad preprocessor object (found in the directory reported by
   ocamlfind for "monad-custom") to the preprocessing command.  Every other
   hook is ignored. *)
let () =
  dispatch (function
    | After_rules ->
      let pa_monad = ocamlfind_query "monad-custom" ^ "/pa_monad.cmo" in
      flag ["ocaml"; "pp"; "use_monad"] (S [A pa_monad])
    | _ -> ())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* Monoids, for the Writer part *)

(** A monoid: a binary operation [mappend] together with an element
    [mempty].  Implementations are expected to make [mappend] associative
    with [mempty] as its identity (the signature cannot enforce this). *)
module type MONOID = sig
  type t
  val mappend : t -> t -> t
  val mempty : t
end
(** Lists of [A.t] under concatenation, with the empty list as identity.
    Used as the Writer output type (an ordered log of emitted items). *)
module ListM (A : sig type t end) : MONOID with type t = A.t list = struct
  type t = A.t list
  let mappend = List.append
  let mempty = []
end
(* Some useful signatures *)

(** Minimal monad interface: [return] injects a value, [bind m k] sequences
    [m] with the continuation [k]. *)
module type MONAD = sig
  type 'a t
  val return : 'a -> 'a t
  val bind : 'a t -> ('a -> 'b t) -> 'b t
end
(** A monad offering Reader ([ask] for an environment of type [r]),
    Writer ([tell] to emit output of type [w]) and State ([get]/[put] on a
    state of type [s]) operations. *)
module type RWS = sig
  include MONAD
  type r
  type w
  type s
  val ask : r t
  val tell : w -> unit t
  val get : s t
  val put : s -> unit t
end
(* Implementation of RWST, the Reader/Writer/State Monad Transformer *)
module RWST =
  functor (R : sig type t end) ->
  functor (W : MONOID) ->
  functor (S : sig type t end) ->
  functor (M : MONAD) ->
  (struct
    (* A computation maps the environment and the current state to an
       [M]-action producing the result, the new state and the output. *)
    type 'a t = RWST of (R.t -> S.t -> ('a * S.t * W.t) M.t)

    let unRWST = function RWST f -> f

    (* Monad *)
    let return x = RWST (fun _ st -> M.return (x, st, W.mempty))

    (* Run [m], hand its result to [k], thread the state from the first
       step into the second and append the two outputs in order. *)
    let bind m k =
      RWST (fun env st ->
        M.bind (unRWST m env st) (fun (x, st1, out1) ->
          M.bind (unRWST (k x) env st1) (fun (y, st2, out2) ->
            M.return (y, st2, W.mappend out1 out2))))

    (* MonadReader *)
    type r = R.t
    let ask = RWST (fun env st -> M.return (env, st, W.mempty))

    (* MonadWriter *)
    type w = W.t
    let tell out = RWST (fun _ st -> M.return ((), st, out))

    (* MonadState *)
    type s = S.t
    let get = RWST (fun _ st -> M.return (st, st, W.mempty))
    let put st = RWST (fun _ _ -> M.return ((), st, W.mempty))

    (* MonadTrans: run an [M]-action, leaving state untouched and
       emitting nothing. *)
    let lift action =
      RWST (fun _ st -> M.bind action (fun x -> M.return (x, st, W.mempty)))

    (* Execute a computation against an environment and an initial state. *)
    let runRWST (comp : 'a t) (env : R.t) (st : S.t) = unRWST comp env st
  end : sig
    include RWS
    val lift : 'a M.t -> 'a t
    val runRWST : 'a t -> R.t -> S.t -> ('a * S.t * W.t) M.t
  end with type r = R.t and type w = W.t and type s = S.t)
(* Lenses! *)

(** A lens from a whole ['a] onto a part ['b]: a getter paired with a setter
    that returns an updated copy of the whole. *)
type ('a, 'b) lens = ('a -> 'b) * ('a -> 'b -> 'a)
(** Lens helpers for any RWS monad:
    - [view l] reads a part of the Reader environment,
    - [use l] reads a part of the state,
    - [l @= v] overwrites a part of the state with [v]. *)
module LensUtils = functor (M : RWS) ->
  (struct
    (* Apply [getter] to the value produced by [source]. *)
    let focus source getter = M.bind source (fun whole -> M.return (getter whole))

    let view (getter, _) = focus M.ask getter
    let use (getter, _) = focus M.get getter
    let ( @= ) (_, setter) v = M.bind M.get (fun st -> M.put (setter st v))
  end : sig
    val view : (M.r, 'b) lens -> 'b M.t
    val use : (M.s, 'b) lens -> 'b M.t
    val (@=) : (M.s, 'b) lens -> 'b -> unit M.t
  end)
(* Application-specific datastructures *)

(** Static configuration, available as the Reader environment of transitions. *)
type config = {
  _configNodeId : string;          (* this node's identifier *)
  _configNodes : string list;      (* identifiers of the cluster's nodes *)
  _configElectionTimeout : int;    (* units not visible here — TODO confirm *)
}
(* Lenses for config *)

(** Lens onto [config._configNodeId]. *)
let configNodeId : (config, string) lens =
  ((fun cfg -> cfg._configNodeId),
   fun cfg id -> { cfg with _configNodeId = id })
(** Lens onto [config._configNodes]. *)
let configNodes : (config, string list) lens =
  ((fun cfg -> cfg._configNodes),
   fun cfg nodes -> { cfg with _configNodes = nodes })
(** Lens onto [config._configElectionTimeout]. *)
let configElectionTimeout : (config, int) lens =
  ((fun cfg -> cfg._configElectionTimeout),
   fun cfg timeout -> { cfg with _configElectionTimeout = timeout })
(** Messages exchanged between nodes. *)
type message = Accept of int

(** Render a message for display, e.g. ["Accept 3"]. *)
let string_of_message (Accept n) = "Accept " ^ string_of_int n
(** Side effects a transition asks the runtime to perform. *)
type command =
  | Broadcast of message
  | Send of (string * message)
  | ResetElectionTimeout of int

(** Render a command for display; the node name in [Send] is shown as an
    OCaml string literal. *)
let string_of_command cmd =
  match cmd with
  | Broadcast msg -> "Broadcast " ^ string_of_message msg
  | Send (node, msg) ->
    Printf.sprintf "Send (%S, %s)" node (string_of_message msg)
  | ResetElectionTimeout timeout ->
    "ResetElectionTimeout " ^ string_of_int timeout
(** Inputs that drive a state transition: an incoming message or the
    expiry of the election timer. *)
type event =
  | Message of message
  | ElectionTimeout
(** State carried while in the slave role. *)
type slave_state = { _slaveI : int }

(** Render a slave state in record-literal form. *)
let string_of_slave_state st = "{ _slaveI = " ^ string_of_int st._slaveI ^ " }"
(** State carried while in the master role. *)
type master_state = { _masterI : int }

(** Render a master state in record-literal form. *)
let string_of_master_state st = "{ _masterI = " ^ string_of_int st._masterI ^ " }"
(** The node's role together with that role's local state. *)
type state =
  | Slave of slave_state
  | Master of master_state

(** Render a state as its role name followed by the role's record. *)
let string_of_state st =
  match st with
  | Slave inner -> "Slave " ^ string_of_slave_state inner
  | Master inner -> "Master " ^ string_of_master_state inner
(** Transition toolkit for a role whose local state is [S.t], running over
    the underlying monad [M].  A transition reads the [config] (Reader),
    emits a [command list] (Writer) and updates an [S.t] (State). *)
module TransitionUtils = functor (S : sig type t end) -> functor (M : MONAD) ->
  (struct
    module Transition =
      RWST (struct type t = config end)
        (ListM (struct type t = command end))
        (struct type t = S.t end)
        (M)
    include Transition
    module LU = LensUtils (Transition)
    include LU

    let ( >>= ) = bind
    let runTransition = runRWST

    (* Emit a single command. *)
    let broadcast m = tell [Broadcast m]
    let send n m = tell [Send (n, m)]

    (* We can combine things: fetch something from config, and use it to emit a
     * command *)
    let resetElectionTimeout =
      view configElectionTimeout >>= fun t ->
      tell [ResetElectionTimeout t]

    let currentState = get

    (* [isMajority l] is true when the configured nodes that also appear in
       [l] form a strict majority of the configured nodes.  Entries of [l]
       that are not configured nodes, and duplicates in [l], do not count.
       With no configured nodes the answer is [false].  Presumably [l] lists
       the nodes that acknowledged/voted — confirm against callers. *)
    let isMajority l =
      view configNodes >>= fun nodes ->
      let votes = List.length (List.filter (fun n -> List.mem n l) nodes) in
      return (2 * votes > List.length nodes)
  end : sig
    include MONAD
    type r
    type w
    type s
    val (>>=) : 'a t -> ('a -> 'b t) -> 'b t
    val view : (r, 'b) lens -> 'b t
    val use : (s, 'b) lens -> 'b t
    val (@=) : (s, 'b) lens -> 'b -> unit t
    val broadcast : message -> unit t
    val send : string -> message -> unit t
    val resetElectionTimeout : unit t
    val currentState : s t
    val isMajority : string list -> bool t
    val lift : 'a M.t -> 'a t
    val runTransition : 'a t -> r -> s -> ('a * s * w) M.t
  end with type r = config
       and type w = command list
       and type s = S.t)
(** A role-specific event handler.  [s] is the role's local state, ['a t]
    its transition monad and ['a m] the monad that [runTransition] returns
    in. *)
module type HANDLER = sig
  type s
  type 'a t
  type 'a m
  val handle : event -> state t
  val runTransition : 'a t -> config -> s -> ('a * s * command list) m
end
(** Event handler for the slave role, abstracted over the underlying monad. *)
module Slave = functor (M : MONAD) ->
  (struct
    type s = slave_state

    (* Lens onto the slave's counter. *)
    let counter : (s, int) lens =
      ((fun st -> st._slaveI), fun _ v -> { _slaveI = v })

    module TU = TransitionUtils (struct type t = s end) (M)
    open TU

    type 'a t = 'a TU.t
    type 'a m = 'a M.t

    (* On timeout, become master with counter + 1.  On [Accept n]: if the
       local counter does not exceed [n], stay slave with counter 0;
       otherwise reset the election timer, bump the counter, broadcast the
       old counter and become master with the bumped value. *)
    let handle = function
      | ElectionTimeout ->
        use counter >>= fun cur -> return (Master { _masterI = cur + 1 })
      | Message (Accept n) ->
        use counter >>= fun cur ->
        if cur <= n then return (Slave { _slaveI = 0 })
        else begin
          resetElectionTimeout >>= fun () ->
          (counter @= cur + 1) >>= fun () ->
          broadcast (Accept cur) >>= fun () ->
          use counter >>= fun updated ->
          return (Master { _masterI = updated })
        end

    let runTransition = TU.runTransition
  end : HANDLER with type s = slave_state and type 'a m = 'a M.t)
(* For demonstration purposes, this handler is not abstracted over a monad:
 * it is fixed to Lwt, which lets it [lift] Lwt actions directly.
 *)
module Master = (struct
  type s = master_state

  (* Lens onto the master's counter. *)
  let counter : (s, int) lens =
    ((fun st -> st._masterI), fun _ v -> { _masterI = v })

  module TU = TransitionUtils (struct type t = s end) (Lwt)
  open TU

  type 'a t = 'a TU.t
  type 'a m = 'a Lwt.t

  (* On timeout, step down to slave carrying the current counter.  On any
     message: add 5 to the counter, send [Accept 1] to "node0", lift a
     value from Lwt, reset the election timer, broadcast the counter plus
     that value and remain master with the resulting state. *)
  let handle = function
    | ElectionTimeout ->
      use counter >>= fun cur -> return (Slave { _slaveI = cur })
    | Message _ ->
      use counter >>= fun cur ->
      (counter @= cur + 5) >>= fun () ->
      send "node0" (Accept 1) >>= fun () ->
      (* Underlying monad is Lwt, so we can lift actions from it *)
      lift (Lwt.return 4) >>= fun extra ->
      resetElectionTimeout >>= fun () ->
      use counter >>= fun bumped ->
      broadcast (Accept (bumped + extra)) >>= fun () ->
      currentState >>= fun st ->
      return (Master st)

  let runTransition = TU.runTransition
end : HANDLER with type s = master_state and type 'a m = 'a Lwt.t)
(** Dispatches an event to the handler matching the current role and runs
    the resulting transition in Lwt. *)
module Handle = struct
  module S = Slave (Lwt)

  (* Keep the next global state and the emitted commands; the role-local
     state returned by the transition is dropped. *)
  let select (next, _, commands) = Lwt.return (next, commands)

  let ( >>= ) = Lwt.bind

  let handle cfg s evt =
    match s with
    | Slave local -> S.runTransition (S.handle evt) cfg local >>= select
    | Master local -> Master.runTransition (Master.handle evt) cfg local >>= select
end
;;
(* Demo: feed one [Accept 1] message to a slave with counter 10 and print
   the resulting state and the commands it emitted. *)
let () =
  let cfg =
    { _configNodeId = "node0"
    ; _configNodes = ["node0"; "node1"]
    ; _configElectionTimeout = 10
    }
  in
  let state0 = Slave { _slaveI = 10 } in
  let event0 = Message (Accept 1) in
  let (next, commands) = Lwt_main.run (Handle.handle cfg state0 event0) in
  Printf.printf "New state: %s\n" (string_of_state next);
  print_endline "Commands:";
  List.iter (fun c -> Printf.printf " - %s\n" (string_of_command c)) commands
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment