Last active
March 11, 2016 19:11
-
-
Save j0sh/ffcf56e6d45579715703 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(*
  Pub/sub semantics:
  - The last message published on a channel is retained and delivered to
    clients that subscribe later, if one is present.
  - Unpublishing a channel clears its retained message.
  - Unsubscribing pushes None onto the subscriber's stream to wake up any
    read loop waiting on it.
*)
(** A pub/sub hub keyed by channel ['a] carrying messages of type ['b]. *)
type ('a, 'b) t = {
  msg_tbl : ('a, 'b) Hashtbl.t;
  (** channel -> retained (last published) message *)
  sub_tbl : ('a, 'b option -> unit) Hashtbl.t;
  (** channel -> push functions of the subscribers' streams; a channel may
      have several bindings, one per subscriber *)
}
(** [create ()] returns an empty hub with no retained messages and no
    subscribers. Both tables are created with an initial size hint of 1000. *)
let create () : ('a, 'b) t =
  let messages = Hashtbl.create 1000 in
  let subscribers = Hashtbl.create 1000 in
  { msg_tbl = messages; sub_tbl = subscribers }
(** [sub ps chan] subscribes to [chan].

    Returns a fresh Lwt stream paired with its push function. The push
    function is registered under [chan] and is the handle to pass to [unsub].
    If a message is currently retained for [chan], it is pushed onto the new
    stream immediately. *)
let sub ps chan =
  let (stream, push) = Lwt_stream.create () in
  Hashtbl.add ps.sub_tbl chan push;
  (* The head of [find_all] is the most recent binding, i.e. exactly what
     [Hashtbl.find] would return. *)
  (match Hashtbl.find_all ps.msg_tbl chan with
   | [] -> ()
   | latest :: _ -> push (Some latest));
  (stream, push)
(** [pub ps chan msg] retains [msg] as the current message for [chan],
    replacing any previous one, then pushes it to every subscriber currently
    registered on [chan] (most recently subscribed first). *)
let pub ps chan msg =
  Hashtbl.replace ps.msg_tbl chan msg;
  let payload = Some msg in
  Hashtbl.find_all ps.sub_tbl chan
  |> List.iter (fun push -> push payload)
(** [unpub ps chan] drops every message retained for [chan]. Subscribers
    stay registered and are not notified. Does nothing if no message is
    retained. *)
let unpub ps chan =
  (* [Hashtbl.remove] drops one binding at a time; loop until none remain. *)
  while Hashtbl.mem ps.msg_tbl chan do
    Hashtbl.remove ps.msg_tbl chan
  done
(** [unsub ps chan push] unsubscribes [push] (a push function returned by
    [sub]) from [chan].

    [None] is pushed first so that any reader blocked on the stream wakes up
    and observes end-of-stream. If the stream is already closed (e.g. a
    repeated [unsub]), the resulting [Lwt_stream.Closed] is ignored; any
    other exception from [push] propagates. The remaining subscribers on
    [chan] keep their relative order. *)
let unsub ps chan push =
  (try push None with Lwt_stream.Closed -> ());
  (* Hashtbl cannot remove one specific binding, so take every subscriber
     bound to [chan] out of the table, drop [push], and add the rest back.
     [push] is compared physically because structural equality raises on
     closures. [find_all] lists the most recent binding first, so the
     survivors are re-added in reverse to restore their original order. *)
  let subs = Hashtbl.find_all ps.sub_tbl chan in
  List.iter (fun _ -> Hashtbl.remove ps.sub_tbl chan) subs;
  subs
  |> List.filter (fun p -> p != push)
  |> List.rev
  |> List.iter (Hashtbl.add ps.sub_tbl chan)
(** [peek ps chan] returns the message currently retained for [chan].
    @raise Not_found if nothing is retained for [chan] (never published, or
    unpublished since). *)
let peek ps chan = Hashtbl.find ps.msg_tbl chan
module PubSubTests = struct
  open OUnit

  let (>>=) = Lwt.bind

  (* [list_cmp a b] holds when [a] and [b] have the same length and are
     pairwise physically equal. Physical equality is required because the
     elements compared are push closures, on which structural equality
     raises. The length check comes first because [List.for_all2] raises
     [Invalid_argument] on lists of different lengths; a mismatch must show
     up as an ordinary assertion failure, not an exception. *)
  let list_cmp a b =
    List.length a = List.length b && List.for_all2 (==) a b

  (* White-box accessors on the retained-message table. *)
  let find ps = Hashtbl.find ps.msg_tbl
  let find_all ps = Hashtbl.find_all ps.msg_tbl
  let exists ps = Hashtbl.mem ps.msg_tbl

  (* [pub] stores a message and overwrites it; it never accumulates. *)
  let test_pub () =
    let ps = create () in
    pub ps "foo" "bar";                        (* populate key *)
    assert_equal "bar" (find ps "foo");
    pub ps "foo" "baz";                        (* overwrite key *)
    assert_equal "baz" (find ps "foo");
    assert_equal ["baz"] (find_all ps "foo")   (* single copy of value *)

  let test_sub () =
    let ps = create () in
    pub ps "foo" "bar";
    let (st, p1) = sub ps "foo" in
    (* Read one element from [s] and check it equals [v]; end-of-stream is
       a failure. *)
    let sub_expects s v =
      Lwt_stream.get s >>= function
      | Some z -> assert_equal v z; Lwt.return_unit
      | None -> assert_failure "empty publish"
    in
    (* subscribing to a channel with no retained message still registers *)
    let (_, p3) = sub ps "zug" in
    assert_equal ~cmp:list_cmp [p3] (Hashtbl.find_all ps.sub_tbl "zug");
    (* the retained value is delivered on subscribe *)
    let f = sub_expects st "bar" in
    Lwt_main.run f;
    (* a later publish reaches the existing subscriber *)
    let f = sub_expects st "baz" in
    let g () = Lwt.return (pub ps "foo" "baz") in
    Lwt_main.run (Lwt.join [f; g ()]);
    (* multiple subscriptions to one channel *)
    let (st2, p2) = sub ps "foo" in
    let f = sub_expects st2 "baz" in
    Lwt_main.run f;  (* consume the retained value *)
    (* FIXME: limit stream length to 1 and replace the head when full *)
    let f1 = sub_expects st "bam" in
    let f2 = sub_expects st2 "bam" in
    let g () = Lwt.return (pub ps "foo" "bam") in
    Lwt_main.run (Lwt.join [f1; f2; g ()]);
    (* subscribers are listed most recent first *)
    assert_equal ~cmp:list_cmp [p2; p1] (Hashtbl.find_all ps.sub_tbl "foo")

  let test_unsub () =
    let ps = create () in
    let cmp = list_cmp in
    let hfa () = Hashtbl.find_all ps.sub_tbl "foo" in
    let (_, p1) = sub ps "foo" in
    let (_, p2) = sub ps "foo" in
    let (_, p3) = sub ps "foo" in
    assert_equal ~cmp [p3; p2; p1] (hfa ());   (* sanity check *)
    unsub ps "foo" p2;                         (* remove from middle *)
    assert_equal ~cmp [p3; p1] (hfa ());
    unsub ps "foo" p1;                         (* remove from end *)
    assert_equal ~cmp [p3] (hfa ());
    unsub ps "foo" p3;                         (* remove from head *)
    assert_equal ~cmp [] (hfa ())

  let test_unsub_push () =
    let ps = create () in
    (* After [unsub], a read on the stream must see end-of-stream. *)
    let break_loop s =
      Lwt_stream.get s >>= function
      | Some _ -> assert_failure "expected None"
      | None -> Lwt.return_unit
    in
    let (st, p) = sub ps "foo" in
    unsub ps "foo" p;
    Lwt_main.run (break_loop st);
    (* repeated unsubs must not raise from the push function *)
    unsub ps "foo" p

  let test_unpub () =
    let ps = create () in
    (* sanity check *)
    pub ps "foo" "bar";
    assert_equal true (exists ps "foo");
    unpub ps "foo";
    assert_equal false (exists ps "foo");
    (* write after the key was emptied *)
    pub ps "foo" "baz";
    assert_equal ["baz"] (find_all ps "foo");
    (* overwrite *)
    pub ps "foo" "bam";
    assert_equal ["bam"] (find_all ps "foo");
    unpub ps "foo";
    assert_equal false (exists ps "foo");
    (* unpublishing a nonexistent key is a no-op *)
    unpub ps "foo";
    assert_equal false (exists ps "foo")

  let test_peek () =
    let ps = create () in
    (* empty case *)
    assert_raises ~msg:"peek empty" Not_found (fun () -> peek ps "foo");
    (* single element *)
    pub ps "foo" "bar";
    assert_equal "bar" (peek ps "foo");
    (* overwritten element *)
    pub ps "foo" "baz";
    pub ps "foo" "bam";
    assert_equal "bam" (peek ps "foo");
    (* empty again after unpub *)
    unpub ps "foo";
    assert_raises ~msg:"unpub peek" Not_found (fun () -> peek ps "foo")

  let tests = [
    "pub" >:: test_pub;
    "sub" >:: test_sub;
    "unsub" >:: test_unsub;
    "unsub_push" >:: test_unsub_push;
    "unpub" >:: test_unpub;
    "peek" >:: test_peek;
  ]

  let run () = ignore (run_test_tt_main ("PubSubTests" >::: tests))
end
(*let () = PubSubTests.run ()*) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment