@hirrolot
Last active October 1, 2024 16:54
A complete implementation of the positive supercompiler from "A Roadmap to Metacomputation by Supercompilation" by Gluck & Sorensen

This is the predecessor of Mazeppa.

Supercompilation is a deep program transformation technique due to V. F. Turchin, a prominent computer scientist, cybernetician, physicist, and Soviet dissident. He described the concept as follows [1]:

A supercompiler is a program transformer of a certain type. The usual way of thinking about program transformation is in terms of some set of rules which preserve the functional meaning of the program, and a step-by-step application of these rules to the initial program. ... The concept of a supercompiler is a product of cybernetic thinking. A program is seen as a machine. To make sense of it, one must observe its operation. So a supercompiler does not transform the program by steps; it controls and observes (SUPERvises) the running of the machine that is represented by the program; let us call this machine M1. In observing the operation of M1, the supercompiler COMPILES a program which describes the activities of M1, but it makes shortcuts and whatever clever tricks it knows in order to produce the same effect as M1, but faster. The goal of the supercompiler is to make the definition of this program (machine) M2 self-sufficient. When this is achieved, it outputs M2 in some intermediate language Lsup and simply throws away the (unchanged) machine M1.

A supercompiler is interesting not only as a program transformer but also as a very general philosophical concept:

The supercompiler concept comes close to the way humans think and make science. We do not think in terms of rules of formal logic. We create mental and linguistic models of the reality we observe. How do we do that? We observe phenomena, generalize observations, and try to construct a self-sufficient model in terms of these generalizations. This is also what the supercompiler does. ... A supercompiler would run M1, in a general form, with unknown values of variables, and create a graph of states and transitions between possible configurations of the computing system. ... To make it finite, the supercompiler performs the operation of generalization on the system configurations in such a manner that it finally comes to a set of generalized configurations, called basic, in terms of which the behavior of the system can be expressed. Thus the new program becomes a self-sufficient model of the old one.

Below is a complete OCaml implementation of the minimalistic supercompiler described by Gluck & Sorensen [2]. The supercompiler deals with a first-order functional language called SLL (Simple Lazy Language).
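
To make the object language concrete, here is a minimal sketch of how an SLL term can be represented and printed. It mirrors the `term` type and the `term_to_string` function from the source code further down, but leaves out the `ppx_deriving` attributes so that it runs on its own. The example term `add(S(Z()), b)` is chosen for illustration.

```ocaml
(* A standalone copy of the AST: a term is either a variable or a call to a
   constructor, an f-function (no pattern matching), or a g-function
   (pattern matching on its first argument, the scrutinee). *)
type call_kind = CtrKind | FKind | GKind
type term = Var of string | Call of call_kind * string * term list

let ctr_call (c, args) = Call (CtrKind, c, args)
let f_call (f, args) = Call (FKind, f, args)
let g_call (g, scrutinee, args) = Call (GKind, g, scrutinee :: args)

let rec term_to_string = function
  | Var x -> x
  | Call (_kind, h, args) ->
      Printf.sprintf "%s(%s)" h
        (String.concat ", " (List.map term_to_string args))

(* add(S(Z()), b): the successor of zero added to an unknown number [b]. *)
let example =
  g_call ("add", ctr_call ("S", [ ctr_call ("Z", []) ]), [ Var "b" ])

let () = print_endline (term_to_string example)
```

Note that the scrutinee of a g-call is simply stored as the first element of the argument list, which is why the printer does not need to treat g-calls specially.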

There are several important differences between our supercompiler and the presentation from the paper:

  • We do not have the if-then-else form. It is not crucial for supercompilation.
  • We do not store terms in the graph. They are unnecessary because we can generate a residual program without them.
  • Instead of replacing an ancestor during generalization, we replace a child and continue building the graph. This simplifies the algorithm by avoiding graph rebuilding. This possibility is discussed in the paper.
  • We have a residual program generator. The paper mentions it but does not formalize it.
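
The generalization mentioned above is triggered by a termination check: before unfolding a term, the supercompiler looks for an ancestor that is homeomorphically embedded in it. The following is a minimal sketch of that relation, simplified from `decide_he` in the source code further down. As assumptions made to keep the example standalone, it does not distinguish call kinds, and it adds an explicit arity check that the original does not have.

```ocaml
(* A simplified term type: call kinds are omitted in this sketch. *)
type term = Var of string | Call of string * term list

(* [embeds s t] holds if [s] is homeomorphically embedded in [t]. *)
let rec embeds s t =
  match (s, t) with
  (* Any variable is embedded in any other variable. *)
  | Var _, Var _ -> true
  (* Diving: [s] is embedded in some argument of [t];
     coupling: same head symbol, arguments embedded pairwise. *)
  | _, Call (_, args) -> List.exists (embeds s) args || couples s t
  (* A call is never embedded in a variable. *)
  | Call _, Var _ -> false

and couples s t =
  match (s, t) with
  | Call (h, args), Call (h', args') ->
      h = h'
      && List.length args = List.length args'
      && List.for_all2 embeds args args'
  | _ -> false

let () =
  let add (x, y) = Call ("add", [ x; y ]) in
  let s = add (Var "a", Var "b") in
  let t = add (Call ("S", [ Var "a" ]), Var "b") in
  Printf.printf "%b %b\n" (embeds s t) (embeds (Call ("c", [])) (Var "x"))
```

With these definitions, `add(a, b)` is embedded in `add(S(a), b)`, so the supercompiler would stop unfolding at that point and fold or generalize instead, whereas a constructor call is never embedded in a variable.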

Overall, the implementation comprises ~300 LoC without the residuator (~220 without comments and blank lines) and ~420 LoC with the residuator (~310 without comments and blank lines).

Possible improvements include:

  • Using integers instead of string identifiers (faster symbol table access).
  • Implementing a user interface.
  • Implementing good error reporting with source code positions (consider using Menhir).
  • Implementing an alpha equivalence checker for graphs.
  • Implementing a compiler from SLL to C (can be useful for benchmarking).
  • Finding a sane way to do logging.
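
As an illustration of the first item, identifiers could be interned into integers once, when the program is constructed, so that later lookups compare integers rather than strings. The sketch below is hypothetical and not part of the implementation: the `interner` type, `create_interner`, `intern`, and `Int_map` are names introduced here for the example.

```ocaml
(* A hypothetical interning table: each distinct identifier gets a dense
   integer, assigned in order of first appearance. *)
type interner = { ids : (string, int) Hashtbl.t; mutable next : int }

let create_interner () = { ids = Hashtbl.create 64; next = 0 }

let intern t name =
  match Hashtbl.find_opt t.ids name with
  | Some id -> id
  | None ->
      let id = t.next in
      Hashtbl.add t.ids name id;
      t.next <- id + 1;
      id

(* Symbol tables can then be keyed by integers instead of strings. *)
module Int_map = Map.Make (Int)

let () =
  let t = create_interner () in
  let add_id = intern t "add" in
  let mul_id = intern t "mul" in
  let arities = Int_map.(empty |> add add_id 2 |> add mul_id 2) in
  (* Interning the same name again returns the same integer. *)
  assert (intern t "add" = add_id);
  Printf.printf "add -> %d, mul -> %d, arity of add = %d\n" add_id mul_id
    (Int_map.find add_id arities)
```

Whether this pays off in practice would need measuring; the sketch only shows the shape of the change.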

The full paper PDF can be accessed here.

TODO: write a newcomer-friendly blog post with complete explanations and link it here.

Tests:
let check ~equal ~show ~expected ~actual =
  if not (equal actual expected) then (
    Printf.eprintf "Expected:\n%s\n\nGot:\n%s\n" (show expected) (show actual);
    (* Prints the code location of the failure. *)
    assert false)

let print_subst fmt subst =
  subst |> Subst.bindings |> [%derive.show: (string * term) list]
  |> Format.pp_print_string fmt

type term_subst = (term Subst.t[@printer print_subst])
[@@deriving eq, show { with_path = false }]

type msg = term * term_subst * term_subst
[@@deriving eq, show { with_path = false }]

let test_subst () =
  let check ~expected (t, map) =
    var_counter := 0;
    check ~equal:equal_term ~show:show_term ~expected
      ~actual:(subst ~map:(Subst.of_list map) t)
  in
  check ~expected:(Var "y") (Var "x", [ ("x", Var "y") ]);
  check ~expected:(Var "x") (Var "x", [ ("z", Var "y") ]);
  check
    ~expected:(f_call ("f", [ Var "x"; Var "b"; Var "y" ]))
    ( f_call ("f", [ Var "a"; Var "b"; Var "c" ]),
      [ ("a", Var "x"); ("c", Var "y") ] )

let test_match_against () =
  let check ~expected (t1, t2) =
    var_counter := 0;
    check ~equal:[%derive.eq: term_subst option]
      ~show:[%derive.show: term_subst option] ~expected
      ~actual:(match_against (t1, t2))
  in
  check ~expected:(Some (Subst.of_list [ ("x", Var "x") ])) (Var "x", Var "x");
  check ~expected:(Some (Subst.of_list [ ("x", Var "y") ])) (Var "x", Var "y");
  (* Same variables in lhs. *)
  check
    ~expected:(Some (Subst.of_list [ ("x", Var "y") ]))
    (f_call ("f", [ Var "x"; Var "x" ]), f_call ("f", [ Var "y"; Var "y" ]));
  (* Unable to match. *)
  check ~expected:None
    (f_call ("f", [ Var "x"; Var "x" ]), f_call ("f", [ Var "y"; Var "z" ]));
  (* Match respective arguments. *)
  check
    ~expected:(Some (Subst.of_list [ ("x", Var "y"); ("a", Var "b") ]))
    (f_call ("f", [ Var "x"; Var "a" ]), f_call ("f", [ Var "y"; Var "b" ]));
  (* Different function names. *)
  check ~expected:None (f_call ("f", []), f_call ("f'", []));
  (* Different function kinds. *)
  check ~expected:None (f_call ("f", []), ctr_call ("f", []))

let test_reduce () =
  let check ?f_rules ?g_rules ~expected t =
    var_counter := 0;
    let program = create_program ?f_rules ?g_rules () in
    check ~equal:(equal_step equal_term) ~show:(show_step pp_term) ~expected
      ~actual:(reduce ~program t)
  in
  check ~expected:(Stop "x") (Var "x");
  check
    ~expected:(Decompose ("c", [ Var "x"; Var "y" ]))
    (ctr_call ("c", [ Var "x"; Var "y" ]));
  check
    ~f_rules:[ ("f", [ "x"; "y" ], ctr_call ("c", [ Var "x"; Var "y" ])) ]
    ~expected:(Transient (ctr_call ("c", [ Var "a"; Var "b" ])))
    (f_call ("f", [ Var "a"; Var "b" ]));
  let g_rules =
    [
      ( "g",
        [
          ( ("c", [ "p"; "q" ]),
            [ "x"; "y" ],
            ctr_call ("c", [ Var "p"; Var "q"; Var "x"; Var "y" ]) );
        ] );
    ]
  in
  check ~g_rules
    ~expected:
      (Transient (ctr_call ("c", [ Var "m"; Var "n"; Var "a"; Var "b" ])))
    (g_call ("g", ctr_call ("c", [ Var "m"; Var "n" ]), [ Var "a"; Var "b" ]));
  let g_rules =
    [
      ( "g",
        [
          ( ("c", [ "m"; "n" ]),
            [ "a"; "x" ],
            ctr_call ("p", [ Var "m"; Var "n"; Var "a"; Var "x" ]) );
          ( ("b", [ "m" ]),
            [ "a"; "x" ],
            ctr_call ("q", [ Var "m"; Var "a"; Var "x" ]) );
        ] );
    ]
  in
  check ~g_rules
    ~expected:
      (Variants
         ( "x",
           [
             ( { c = "b"; fresh_vars = [ "v0" ] },
               ctr_call
                 ("q", [ Var "v0"; Var "a"; ctr_call ("b", [ Var "v0" ]) ]) );
             ( { c = "c"; fresh_vars = [ "v1"; "v2" ] },
               ctr_call
                 ( "p",
                   [
                     Var "v1";
                     Var "v2";
                     Var "a";
                     ctr_call ("c", [ Var "v1"; Var "v2" ]);
                   ] ) );
           ] ))
    (g_call ("g", Var "x", [ Var "a"; Var "x" ]));
  (* Reduce an inner term (transient). *)
  check
    ~f_rules:[ ("f", [], ctr_call ("c", [])) ]
    ~expected:
      (Transient (g_call ("g", ctr_call ("c", []), [ Var "a"; Var "x" ])))
    (g_call ("g", f_call ("f", []), [ Var "a"; Var "x" ]));
  let g_rules =
    [
      ( "g",
        [
          (("c", [ "m"; "n" ]), [], ctr_call ("p", []));
          (("b", [ "m" ]), [], ctr_call ("q", []));
        ] );
    ]
  in
  (* Reduce an inner term (variants). *)
  check ~g_rules
    ~expected:
      (Variants
         ( "x",
           [
             ( { c = "b"; fresh_vars = [ "v0" ] },
               g_call
                 ( "g'",
                   ctr_call ("q", []),
                   [ Var "a"; ctr_call ("b", [ Var "v0" ]) ] ) );
             ( { c = "c"; fresh_vars = [ "v1"; "v2" ] },
               g_call
                 ( "g'",
                   ctr_call ("p", []),
                   [ Var "a"; ctr_call ("c", [ Var "v1"; Var "v2" ]) ] ) );
           ] ))
    (g_call
       (* The variable [x] must be substituted branch-wise. *)
       ("g'", g_call ("g", Var "x", []), [ Var "a"; Var "x" ]))

let test_he () =
  assert (decide_he (Var "x", Var "x"));
  assert (decide_he (Var "x", Var "y"));
  (* Coupling of respective arguments. *)
  assert (
    decide_he
      (f_call ("f", [ Var "x"; Var "y" ]), f_call ("f", [ Var "x"; Var "y" ])));
  (* Different function names. *)
  assert (not (decide_he (f_call ("f", []), f_call ("f'", []))));
  (* Different function kinds. *)
  assert (not (decide_he (f_call ("f", []), ctr_call ("f", []))));
  (* Incompatible arguments. *)
  assert (
    not
      (decide_he
         ( f_call ("f", [ f_call ("f", []) ]),
           f_call ("f", [ f_call ("f'", []) ]) )));
  (* Diving into at least one argument. *)
  assert (decide_he (Var "x", f_call ("f", [ Var "y"; ctr_call ("c", []) ])));
  (* Incompatible with the argument. *)
  assert (
    not (decide_he (ctr_call ("c", []), f_call ("f", [ ctr_call ("c'", []) ]))));
  (* A constructor cannot be embedded into a variable. *)
  assert (not (decide_he (ctr_call ("c", []), Var "x")))

let test_msg () =
  let check ~expected:(g, subst_1, subst_2) (t1, t2) =
    var_counter := 0;
    let expected = (g, Subst.of_list subst_1, Subst.of_list subst_2) in
    let actual = compute_msg (t1, t2) in
    check ~equal:equal_msg ~show:show_msg ~expected ~actual
  in
  (* A common functor. *)
  check
    ~expected:
      ( ctr_call ("c", [ Var "v1"; Var "v2" ]),
        [ ("v1", Var "x"); ("v2", Var "y") ],
        [ ("v1", Var "a"); ("v2", Var "b") ] )
    (ctr_call ("c", [ Var "x"; Var "y" ]), ctr_call ("c", [ Var "a"; Var "b" ]));
  (* A common substitution (after a common functor). *)
  check
    ~expected:
      ( ctr_call ("c", [ Var "v2"; Var "v2" ]),
        [ ("v2", Var "x") ],
        [ ("v2", Var "y") ] )
    (ctr_call ("c", [ Var "x"; Var "x" ]), ctr_call ("c", [ Var "y"; Var "y" ]))

let test_optimize () =
  let check ~expected ?f_rules ?g_rules t =
    check ~equal:[%derive.eq: string * string list]
      ~show:[%derive.show: string * string list] ~expected
      ~actual:(optimize ?f_rules ?g_rules t)
  in
  let rec num ?(hole = ctr_call ("Z", [])) n =
    if n = 0 then hole else ctr_call ("S", [ num ~hole (n - 1) ])
  in
  let add_rule =
    ( "add",
      [
        (("Z", []), [ "y" ], Var "y");
        ( ("S", [ "x" ]),
          [ "y" ],
          ctr_call ("S", [ g_call ("add", Var "x", [ Var "y" ]) ]) );
      ] )
  in
  (* Standard addition of Peano numerals. *)
  let g_rules = [ add_rule ] in
  (* add(S(Z()), S(S(Z()))) -> S(S(S(Z()))) *)
  check ~expected:("S(S(S(Z())))", []) ~g_rules
    (g_call ("add", num 1, [ num 2 ]));
  (* add(S(S(Z())), b) -> S(S(b)) *)
  check ~expected:("S(S(b))", []) ~g_rules (g_call ("add", num 2, [ Var "b" ]));
  (* add(a, b) *)
  check
    ~expected:("g0(a, b)", [ "g0(S(v0), b) = S(g0(v0, b))"; "g0(Z(), b) = b" ])
    ~g_rules
    (g_call ("add", Var "a", [ Var "b" ]));
  (* Addition with an accumulator. *)
  let g_rules =
    [
      ( "addAcc",
        [
          (("Z", []), [ "y" ], Var "y");
          ( ("S", [ "x" ]),
            [ "y" ],
            g_call ("addAcc", Var "x", [ ctr_call ("S", [ Var "y" ]) ]) );
        ] );
    ]
  in
  (* addAcc(S(S(a)), b) *)
  check
    ~expected:
      ("g0(a, b)", [ "g0(S(v0), b) = g0(v0, S(b))"; "g0(Z(), b) = S(S(b))" ])
    ~g_rules
    (g_call ("addAcc", num ~hole:(Var "a") 2, [ Var "b" ]));
  (* addAcc(a, b) *)
  check
    ~expected:("g0(a, b)", [ "g0(S(v0), b) = g0(v0, S(b))"; "g0(Z(), b) = b" ])
    ~g_rules
    (g_call ("addAcc", Var "a", [ Var "b" ]));
  (* Multiplication via addition. *)
  let g_rules =
    [
      ( "mul",
        [
          (("Z", []), [ "y" ], ctr_call ("Z", []));
          ( ("S", [ "x" ]),
            [ "y" ],
            g_call ("add", g_call ("mul", Var "x", [ Var "y" ]), [ Var "y" ]) );
        ] );
      add_rule;
    ]
  in
  (* mul(S(S(Z())), S(S(S(Z())))) -> S(S(S(S(S(S(Z())))))) *)
  check
    ~expected:("S(S(S(S(S(S(Z()))))))", [])
    ~g_rules
    (g_call ("mul", num 2, [ num 3 ]));
  (* mul(S(S(a)), b) *)
  (* TODO: understand why this example causes code bloat. *)
  check
    ~expected:
      ( "g0(a, b)",
        [
          "f20(v20, v21) = f20(v20, v21)";
          "f28(v36, v37) = f28(v36, v37)";
          "g0(S(v0), b) = g3(g1(v0, b), b)";
          "g0(Z(), b) = g7(b)";
          "g1(S(v25), b) = g2(g1(b, v25), b)";
          "g1(Z(), b) = b";
          "g2(S(v29), v28) = S(g2(v29, v28))";
          "g2(Z(), v28) = v28";
          "g3(S(v6), v5) = g5(S(g4(v6, v5)), v5)";
          "g3(Z(), v5) = g6(v5)";
          "g4(S(v13), v5) = S(g4(v5, v13))";
          "g4(Z(), v5) = v5";
          "g5(S(v10), v9) = S(g5(v10, v9))";
          "g5(Z(), v9) = v9";
          "g6(S(v16)) = S(f20(v16, S(v16)))";
          "g6(Z()) = Z()";
          "g7(S(v32)) = S(f28(v32, S(v32)))";
          "g7(Z()) = Z()";
        ] )
    ~g_rules
    (g_call ("mul", num ~hole:(Var "a") 2, [ Var "b" ]));
  (* mul(a, b) *)
  check
    ~expected:
      ( "g0(a, b)",
        [
          "g0(S(v0), b) = g1(g0(v0, b), b)";
          "g0(Z(), b) = Z()";
          "g1(S(v4), v3) = S(g1(v4, v3))";
          "g1(Z(), v3) = v3";
        ] )
    ~g_rules
    (g_call ("mul", Var "a", [ Var "b" ]));
  (* The [eq] interpreter (to be specialized!). *)
  let g_rules =
    [
      ( "eq",
        [
          (("Z", []), [ "y" ], g_call ("eqZ", Var "y", []));
          (("S", [ "x" ]), [ "y" ], g_call ("eqS", Var "y", [ Var "x" ]));
        ] );
      ( "eqZ",
        [
          (("Z", []), [], ctr_call ("True", []));
          (("S", [ "x" ]), [], ctr_call ("False", []));
        ] );
      ( "eqS",
        [
          (("Z", []), [ "x" ], ctr_call ("False", []));
          (("S", [ "y" ]), [ "x" ], g_call ("eq", Var "x", [ Var "y" ]));
        ] );
    ]
  in
  (* The first Futamura projection: eq(S(S(Z())), x) *)
  check
    ~expected:
      ( "g0(x)",
        [
          "g0(S(v0)) = g1(v0)";
          "g0(Z()) = False()";
          "g1(S(v1)) = g2(v1)";
          "g1(Z()) = False()";
          "g2(S(v2)) = False()";
          "g2(Z()) = True()";
        ] )
    ~g_rules
    (g_call ("eq", num 2, [ Var "x" ]));
  let g_rules = add_rule :: g_rules in
  (* Interpretive inversion: eq(add(x, S(S(Z()))), S(S(S(Z())))) *)
  check
    ~expected:
      ( "g0(x)",
        [
          "g0(S(v0)) = g2(S(g1(v0)))";
          "g0(Z()) = False()";
          "g1(S(v11)) = S(g1(v11))";
          "g1(Z()) = S(S(Z()))";
          "g2(S(v7)) = g3(v7)";
          "g2(Z()) = False()";
          "g3(S(v8)) = g4(v8)";
          "g3(Z()) = False()";
          "g4(S(v9)) = g5(v9)";
          "g4(Z()) = False()";
          "g5(S(v10)) = False()";
          "g5(Z()) = True()";
        ] )
    ~g_rules
    (g_call ("eq", g_call ("add", Var "x", [ num 2 ]), [ num 3 ]));
  (* Theorem proving: eq(add(x, Z()), x) *)
  (* It is hard to call this "theorem proving" because the residual program is
     far from an obvious proof. A smarter supercompiler would be apt here. *)
  check
    ~expected:
      ( "g0(x)",
        [
          "g0(S(v0)) = g2(S(g1(v0)), S(v0))";
          "g0(Z()) = True()";
          "g1(S(v7)) = S(g1(v7))";
          "g1(Z()) = Z()";
          "g2(S(v4), v3) = g3(v3, v4)";
          "g2(Z(), v3) = g4(v3)";
          "g3(S(v5), v4) = g2(v4, v5)";
          "g3(Z(), v4) = False()";
          "g4(S(v6)) = False()";
          "g4(Z()) = True()";
        ] )
    ~g_rules
    (g_call ("eq", g_call ("add", Var "x", [ num 0 ]), [ Var "x" ]));
  (* List mapping. *)
  let f_rules = [ ("f", [ "x" ], ctr_call ("Blah", [ Var "x" ])) ] in
  let g_rules =
    [
      ( "map",
        [
          (("Nil", []), [], ctr_call ("Nil", []));
          ( ("Cons", [ "x"; "xs" ]),
            [],
            ctr_call
              ( "Cons",
                [ f_call ("f", [ Var "x" ]); g_call ("map", Var "xs", []) ] ) );
        ] );
    ]
  in
  (* Not very interesting: map(map(xs)) *)
  (* Our implementation does not perform list fusion because the standard
     homeomorphic embedding is too restrictive. The SPSC supercompiler [1]
     implements a more flexible approach to HE, thereby allowing it to transform
     this example into a one-pass algorithm.

     [1]: https://github.com/sergei-romanenko/spsc
  *)
  check
    ~expected:
      ( "g0(xs)",
        [
          "g0(Cons(v0, v1)) = g2(Cons(Blah(v0), g1(v1)))";
          "g0(Nil()) = Nil()";
          "g1(Cons(v9, v10)) = Cons(Blah(v9), g1(v10))";
          "g1(Nil()) = Nil()";
          "g2(Cons(v4, v5)) = Cons(Blah(v4), g2(v5))";
          "g2(Nil()) = Nil()";
        ] )
    ~f_rules ~g_rules
    (g_call ("map", g_call ("map", Var "xs", []), []));
  (* An infinite loop of transitions. *)
  let f_rules =
    [
      ("f", [ "x" ], f_call ("f'", [ Var "x" ]));
      ("f'", [ "x" ], f_call ("f''", [ Var "x" ]));
      ("f''", [ "x" ], f_call ("f", [ Var "x" ]));
    ]
  in
  (* f(x) *)
  check
    ~expected:("f0(x)", [ "f0(x) = f0(x)" ])
    ~f_rules
    (f_call ("f", [ Var "x" ]))

let () =
  test_subst ();
  test_match_against ();
  test_reduce ();
  test_he ();
  test_msg ();
  test_optimize ()

let _ = (equal_graph, pp_graph)
The `dune` file:
(executable
 (public_name supercomp)
 (name main)
 (libraries supercomp)
 (preprocess
  (pps ppx_deriving.show ppx_deriving.eq))
 (instrumentation
  (backend bisect_ppx)))

(env
 (release
  (ocamlopt_flags
   (:standard -O3))))

Footnotes

  1. Valentin F. Turchin. 1986. The concept of a supercompiler. ACM Trans. Program. Lang. Syst. 8, 3 (July 1986), 292–325. https://doi.org/10.1145/5956.5957

  2. Robert Glück and Morten Heine Sørensen. 1996. A Roadmap to Metacomputation by Supercompilation. In Selected Papers from the International Seminar on Partial Evaluation. Springer-Verlag, Berlin, Heidelberg, 137–160.

(******************************************************************************)
(* AST definitions. *)

type call_kind = CtrKind | FKind | GKind
and term = Var of string | Call of call_kind * string * term list
[@@deriving eq, show { with_path = false }]

let var x = Var x
let ctr_call (c, args) = Call (CtrKind, c, args)
let f_call (f, args) = Call (FKind, f, args)
let g_call (g, scrutinee, args) = Call (GKind, g, scrutinee :: args)

(******************************************************************************)
(* Search tables for f- and g-functions. *)

module String_map = Map.Make (String)
module F_rules = String_map
module G_rules_by_name = String_map
module G_rules_by_pattern = String_map

type param_list = string list

type program = {
  f_rules : (param_list * term) F_rules.t;
  g_rules :
    (param_list * param_list * term) G_rules_by_pattern.t G_rules_by_name.t;
}

let find_exn ~key ~on_none map =
  match String_map.find_opt key map with
  | Some value -> value
  | None -> on_none ()

let panic fmt = Format.ksprintf failwith fmt

let find_f_rule ~program f =
  find_exn ~key:f
    ~on_none:(fun () -> panic "find_f_rule: no such function `%s`" f)
    program.f_rules

let find_g_rule_list ~program g =
  find_exn ~key:g
    ~on_none:(fun () -> panic "find_g_rule_list: no such function `%s`" g)
    program.g_rules

let find_g_rule ~program (g, c) =
  find_g_rule_list ~program g
  |> find_exn ~key:c ~on_none:(fun () ->
         panic "find_g_rule: no such pattern `%s` for `%s`" c g)

(******************************************************************************)
(* Fresh identifiers generation. *)

let var_counter, g_counter, f_counter = (ref 0, ref 0, ref 0)

let fresh_id ?(counter = var_counter) ?(prefix = "v") () =
  let x = !counter in
  counter := !counter + 1;
  prefix ^ string_of_int x

let fresh_var_list length_list = List.map (fun _ -> fresh_id ()) length_list
let fresh_ctr (c, fresh_vars) = ctr_call (c, List.map var fresh_vars)

(******************************************************************************)
(* Substitution on terms. *)

module Subst = String_map

let scan_subst ~f subst = subst |> Subst.to_seq |> Seq.find_map f

let rec subst ~map = function
  | Var x as default -> Subst.find_opt x map |> Option.value ~default
  | Call (kind, h, args) -> Call (kind, h, List.map (subst ~map) args)

let subst_params (params, args) t =
  subst ~map:(Subst.of_list (List.combine params args)) t

type contraction = { c : string; fresh_vars : string list }
[@@deriving eq, show { with_path = false }]

let unify ~x ~contraction:{ c; fresh_vars } args =
  List.map (subst_params ([ x ], [ fresh_ctr (c, fresh_vars) ])) args

(******************************************************************************)
(* Term matching. *)

(* Tests whether [t2] is an instance of [t1]. If it is, returns a substitution
   [map] such that [equal_term (subst ~map t1) t2]. *)
let match_against (t1, t2) =
  let exception Fail in
  let subst = ref Subst.empty in
  let rec go (t1, t2) =
    match t1 with
    | Var x -> (
        match Subst.find_opt x !subst with
        | Some x_subst -> if not (equal_term x_subst t2) then raise Fail
        | None -> subst := Subst.add x t2 !subst)
    | Call (kind, h, args) -> (
        match t2 with
        | Call (kind', h', args') when equal_call_kind kind kind' && h = h' ->
            List.iter2 (fun arg arg' -> go (arg, arg')) args args'
        | _ -> raise Fail)
  in
  try
    go (t1, t2);
    Some !subst
  with Fail -> None

(******************************************************************************)
(* Normal-order reduction ("driving"). *)

type 'a step =
  (* Stop on a free variable. *)
  | Stop of string
  (* Decompose a constructor. *)
  | Decompose of string * 'a list
  (* Unfold a call. *)
  | Transient of 'a
  (* Case-analysis of a g-call. *)
  | Variants of string * (contraction * 'a) list
[@@deriving eq, show { with_path = false }]

(* Implements "Fig. 6. Normal-order reduction step". *)
let rec reduce ~program = function
  | Var x -> Stop x
  | Call (CtrKind, c, args) -> Decompose (c, args)
  | Call (FKind, f, args) ->
      let params, s = find_f_rule ~program f in
      Transient (subst_params (params, args) s)
  | Call (GKind, g, Call (CtrKind, c, args) :: args') ->
      let params, params', s = find_g_rule ~program (g, c) in
      Transient (subst_params (params @ params', args @ args') s)
  | Call (GKind, g, Var x :: args) ->
      Variants (x, reduce_variants ~program (x, g, args))
  | Call (GKind, g, (Call _ as t) :: args) -> reduce_inner ~program (t, g, args)
  | Call (GKind, g, []) -> panic "reduce: no scrutinee for `%s`" g

and reduce_variants ~program (x, g, args) =
  let unfold_rule (c, (params, params', s)) =
    let fresh_vars = fresh_var_list params in
    let contraction = { c; fresh_vars } in
    let args = List.map var fresh_vars @ unify ~x ~contraction args in
    (contraction, subst_params (params @ params', args) s)
  in
  find_g_rule_list ~program g
  |> G_rules_by_pattern.bindings |> List.map unfold_rule

and reduce_inner ~program (t, g, args) =
  match reduce ~program t with
  | Transient t -> Transient (g_call (g, t, args))
  | Variants (x, variants) ->
      let propagate_contraction (contraction, t) =
        (contraction, g_call (g, t, unify ~x ~contraction args))
      in
      Variants (x, List.map propagate_contraction variants)
  | _ -> panic "reduce_inner: impossible"

(******************************************************************************)
(* The homeomorphic embedding relation. *)

(* Implements "Fig. 7. Homeomorphic embedding". *)
let rec decide_he = function
  | Var _x, Var _y -> true
  | s, (Call _ as t) -> he_by_diving (s, t) || he_by_coupling (s, t)
  | Call _, Var _ -> false

and he_by_diving = function
  | s, Call (_kind, _h, args) -> List.exists (fun t -> decide_he (s, t)) args
  | _, _ -> false

and he_by_coupling = function
  | Call (kind, h, args), Call (kind', h', args')
    when equal_call_kind kind kind' && h = h' ->
      List.for_all2 (fun s t -> decide_he (s, t)) args args'
  | _, _ -> false

(******************************************************************************)
(* Computing MSG, a "Most Specific Generalization". *)

let common_functor (g, subst_1, subst_2) =
  subst_1
  |> scan_subst ~f:(function
       | x, Call (kind, h, args) ->
           subst_2
           |> scan_subst ~f:(function
                | y, Call (kind', h', args')
                  when x = y && equal_call_kind kind kind' && h = h' ->
                    Some (x, kind, h, args, args')
                | _ -> None)
       | _ -> None)
  |> Option.map (fun (x, kind, h, args, args') ->
         let fresh_vars = fresh_var_list args in
         let common_call = Call (kind, h, List.map var fresh_vars) in
         let new_subst ~args old_subst =
           List.fold_left
             (fun acc (fresh_var, arg) -> Subst.add fresh_var arg acc)
             (Subst.remove x old_subst)
             (List.combine fresh_vars args)
         in
         ( subst ~map:(Subst.singleton x common_call) g,
           new_subst ~args subst_1,
           new_subst ~args:args' subst_2 ))

let common_subst (g, subst_1, subst_2) =
  subst_1
  |> scan_subst ~f:(fun (x, s) ->
         subst_1
         |> scan_subst ~f:(fun (y, s') ->
                if
                  x <> y && equal_term s s'
                  && Option.bind (Subst.find_opt x subst_2) (fun t ->
                         Subst.find_opt y subst_2
                         |> Option.map (fun t' -> equal_term t t'))
                     |> Option.value ~default:false
                then Some (x, y)
                else None))
  |> Option.map (fun (x, y) ->
         let new_subst old_subst = Subst.remove x old_subst in
         ( subst ~map:(Subst.singleton x (Var y)) g,
           new_subst subst_1,
           new_subst subst_2 ))

let step ~rules triple = List.find_map (fun f -> f triple) rules

(* Exhaustively applies the two rewrite rules. *)
let rec loop triple =
  match step ~rules:[ common_functor; common_subst ] triple with
  | Some triple' -> loop triple'
  | None -> triple

(* Implements "Fig. 10. Computing most specific generalizations". *)
let compute_msg (t1, t2) =
  let seed = fresh_id () in
  let g, subst_1, subst_2 =
    (Var seed, Subst.singleton seed t1, Subst.singleton seed t2)
  in
  loop (g, subst_1, subst_2)

(******************************************************************************)
(* The supercompilation algorithm itself! *)

type graph = Step of graph step | Bind of (string * graph) list * bind_kind
and bind_kind = Let of graph | Fold of de_bruijn_idx
and de_bruijn_idx = int [@@deriving eq, show { with_path = false }]

(* Implements "Algorithm 7 (positive supercompilation.)". *)
let supercompile ~program t =
  let rec go ~history n =
    match
      List.find_mapi
        (fun i m -> if decide_he (m, n) then Some (i, m) else None)
        history
    with
    (* We must stop direct unfolding to avoid non-termination. *)
    | Some (i, m) -> (
        match match_against (m, n) with
        (* [n] is an instance of [m]. *)
        | Some subst -> fold ~history:(n :: history) ~subst i
        | None -> (
            let g, _subst_1, subst_2 = compute_msg (m, n) in
            match g with
            (* [m] and [n] are disjoint. *)
            | Var _x -> split ~history:(n :: history) n
            (* [m] and [n] have a "meaningful" generalization. *)
            | _ -> generalize ~history:(n :: history) ~subst_2 g))
    (* Everything is good, continue unfolding. *)
    | None -> Step (step ~history:(n :: history) n)
  and fold ~history ~subst i =
    Bind
      ( subst |> Subst.bindings
        |> List.map (fun (x, child) -> (x, go ~history child)),
        Fold i )
  and split ~history = function
    | Call (kind, h, args) ->
        let fresh_vars = fresh_var_list args in
        let t_g = Call (kind, h, List.map var fresh_vars) in
        Bind
          ( List.map2 (fun x t -> (x, go ~history t)) fresh_vars args,
            Let (go ~history t_g) )
    | _ -> panic "split: impossible"
  and generalize ~history ~subst_2 g =
    Bind
      ( subst_2 |> Subst.bindings
        |> List.map (fun (x, child) -> (x, go ~history child)),
        Let (go ~history g) )
  and step ~history n =
    match reduce ~program n with
    | Stop x -> Stop x
    | Decompose (c, args) -> Decompose (c, List.map (go ~history) args)
    | Transient t -> Transient (go ~history t)
    | Variants (x, variants) ->
        Variants
          ( x,
            List.map
              (fun (contraction, t) -> (contraction, go ~history t))
              variants )
  in
  go ~history:[] t

(******************************************************************************)
(* Residual program generation. *)

module Free_vars = Set.Make (String)

let union_free_vars list =
  let out_list, free_vars_list = List.split list in
  let free_vars =
    List.fold_left Free_vars.union Free_vars.empty free_vars_list
  in
  (out_list, free_vars)

(* Invariants:
   - We always arrange function parameters lexicographically. This ensures that
     recursive functions are called with proper argument positions.
   - We generate only recursive f-functions. This means that useless functions
     are "inlined" automatically.
   - When processing a graph node (either [Step] or [Bind]), the history length
     must be the same as in [supercompile].
*)
let residuate graph =
  let f_rules, g_rules = (ref F_rules.empty, ref G_rules_by_name.empty) in
  let gen_f ~def_meta:(fresh_f, is_called) ~free_vars t_out =
    if !is_called then (
      (* If the function is called recursively, generate a new definition. *)
      let free_vars_list = Free_vars.elements free_vars in
      f_rules := F_rules.add fresh_f (free_vars_list, t_out) !f_rules;
      (f_call (fresh_f, List.map var free_vars_list), free_vars))
    else
      (* Otherwise, just return the term that would have resulted in the
         function's body. In essence, this performs definition inlining, but
         without generating a definition! *)
      (t_out, free_vars)
  in
  let f_meta () =
    let fresh_f = fresh_id ~counter:f_counter ~prefix:"f" () in
    let is_called = ref false in
    (fresh_f, is_called)
  in
  let rec go ~history = function
    | Step (Stop x) -> (Var x, Free_vars.singleton x)
    | Step (Decompose (c, args)) -> go_decompose ~history (c, args)
    | Step (Transient t) -> go_transient ~history t
    | Step (Variants (x, variants)) -> go_variants ~history (x, variants)
    | Bind (bindings, Let t) -> go_bind_let ~history (bindings, t)
    | Bind (bindings, Fold i) -> go_bind_fold ~history (bindings, i)
  (* Generate an f-function definition that calls a constructor. *)
  and go_decompose ~history (c, args) =
    let def_meta = f_meta () in
    let args_out, free_vars =
      args |> List.map (go ~history:(`F def_meta :: history)) |> union_free_vars
    in
    gen_f ~def_meta ~free_vars (ctr_call (c, args_out))
  (* Generate a transient f-function definition. *)
  and go_transient ~history t =
    let def_meta = f_meta () in
    let t_out, free_vars = go ~history:(`F def_meta :: history) t in
    gen_f ~def_meta ~free_vars t_out
  (* Generate a g-function definition that describes the variants. *)
  and go_variants ~history (x, variants) =
    let fresh_g = fresh_id ~counter:g_counter ~prefix:"g" () in
    let variants_out, free_vars =
      variants
      |> List.map (fun (({ fresh_vars; _ } as contraction), t) ->
             let t_out, free_vars = go ~history:(`G fresh_g :: history) t in
             let free_vars = Free_vars.(diff free_vars (of_list fresh_vars)) in
             ((contraction, t_out), free_vars))
      |> union_free_vars
    in
    let free_vars_list = Free_vars.elements free_vars in
    let rules =
      variants_out
      |> List.map (fun ({ c; fresh_vars }, t_out) ->
             (c, (fresh_vars, free_vars_list, t_out)))
      |> G_rules_by_pattern.of_list
    in
    g_rules := G_rules_by_name.add fresh_g rules !g_rules;
    ( g_call (fresh_g, Var x, List.map var free_vars_list),
      Free_vars.add x free_vars )
  (* Residuate the bindings and substitute them into the body. *)
  and go_bind_let ~history (bindings, t) =
    let def_meta = f_meta () in
    let bindings_out, free_vars =
      go_bindings ~history:(`F def_meta :: history) bindings
    in
    (* Due to the way an MSG is computed, [_free_vars'] is the set of fresh
       variables from [bindings_out]. *)
    let t_out, _free_vars' = go ~history:(`F def_meta :: history) t in
    gen_f ~def_meta ~free_vars (subst ~map:(Subst.of_list bindings_out) t_out)
  (* Residuate the bindings and make a recursive function call. *)
  and go_bind_fold ~history (bindings, i) =
    let def_meta = f_meta () in
    let bindings_out, free_vars =
      go_bindings ~history:(`F def_meta :: history) bindings
    in
    let args = List.map (fun (_x, t) -> t) bindings_out in
    let t_out =
      match List.nth history i with
      | `F (f, is_called) ->
          (* We will generate a definition for this recursive function [f]. *)
          is_called := true;
          f_call (f, args)
      | `G g -> Call (GKind, g, args)
    in
    gen_f ~def_meta ~free_vars t_out
  and go_bindings ~history bindings =
    bindings
    |> List.map (fun (x, t) ->
           let t_out, free_vars = go ~history t in
           ((x, t_out), free_vars))
    |> union_free_vars
  in
  let t_out, _free_vars = go ~history:[] graph in
  ({ f_rules = !f_rules; g_rules = !g_rules }, t_out)


(******************************************************************************)
(* Initial program construction. *)

let create_program ?(f_rules = []) ?(g_rules = []) () =
  let prepare_f_rule (f, params, t) = (f, (params, t)) in
  let prepare_g_rule_by_name (g, rules) =
    let prepare_g_rule_by_pattern ((c, params), params', t) =
      (c, (params, params', t))
    in
    ( g,
      rules |> List.map prepare_g_rule_by_pattern |> G_rules_by_pattern.of_list
    )
  in
  {
    f_rules = f_rules |> List.map prepare_f_rule |> F_rules.of_list;
    g_rules =
      g_rules |> List.map prepare_g_rule_by_name |> G_rules_by_name.of_list;
  }

(******************************************************************************)
(* Pretty-printing. *)

let rec term_to_string = function
  | Var x -> x
  | Call (_kind, h, args) ->
      Printf.sprintf "%s(%s)" h
        (String.concat ", " (List.map term_to_string args))

let print_program { f_rules; g_rules } =
  let go_f_rule (f, params, t) =
    Printf.sprintf "%s(%s) = %s" f
      (String.concat ", " params)
      (term_to_string t)
  in
  let go_g_rule (g, (c, params), params', t) =
    Printf.sprintf "%s(%s(%s)%s%s) = %s" g c
      (String.concat ", " params)
      (if List.is_empty params' then "" else ", ")
      (String.concat ", " params')
      (term_to_string t)
  in
  let f_rules =
    f_rules |> F_rules.bindings
    |> List.map (fun (f, (params, t)) -> go_f_rule (f, params, t))
  in
  let g_rules =
    g_rules |> G_rules_by_name.bindings
    |> List.map (fun (g, rules) ->
           rules |> G_rules_by_pattern.bindings
           |> List.map (fun (c, (params, params', t)) ->
                  go_g_rule (g, (c, params), params', t)))
    |> List.flatten
  in
  f_rules @ g_rules

(******************************************************************************)
(* Program optimization. *)

let optimize ?f_rules ?g_rules t =
  var_counter := 0;
  f_counter := 0;
  g_counter := 0;
  let program = create_program ?f_rules ?g_rules () in
  let program_out, t_out = supercompile ~program t |> residuate in
  (term_to_string t_out, print_program program_out)

@hirrolot
Author

I've found a bug in the implementation. To reproduce it, add the following lines to test_optimize:

  let f_rules = [ ("f", [ "x"; "y" ], f_call ("f", [ Var "x"; Var "x" ])) ] in
  (* f(x, y) *)
  check
    ~expected:("f0(x)", [ "f0(x) = f0(x, x)" ])
    ~f_rules
    (f_call ("f", [ Var "x"; Var "y" ]))

The trouble is that f0 is defined as a single-parameter function but is called with two arguments inside its own body. This happens because the residualizer derives a function's parameters from the distinct variables of the body it is about to generate, while f is originally defined with both x and y. To fix the problem, there needs to be a pass before residualization that inspects every node that will be residualized into a function and associates the node's identifier with a signature containing all of its "real" parameter names 1.
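
The arity mismatch can be illustrated in isolation. The following is a minimal sketch, not the gist's code: the `term` type is a simplified stand-in, and `params_from_body`, `params_from_bindings`, and the binding list `x := x, y := x` are hypothetical names and values chosen for the example. It contrasts deriving parameters from the variables of a body with deriving them from the names bound at a node, which is roughly what the proposed pre-pass would record.

```ocaml
(* Simplified stand-in for the gist's term type; all names here are
   hypothetical and exist only for this illustration. *)
type term = Var of string | Call of string * term list

module S = Set.Make (String)

(* Collect the distinct variables occurring in a term. *)
let rec free_vars = function
  | Var x -> S.singleton x
  | Call (_, args) ->
      List.fold_left (fun acc t -> S.union acc (free_vars t)) S.empty args

(* Parameter derivation as described in the comment: one parameter per
   distinct variable of the body. *)
let params_from_body body = S.elements (free_vars body)

(* Alternative: take the parameters from the names bound at the node, so a
   parameter that the body never mentions still belongs to the signature. *)
let params_from_bindings bindings = List.map fst bindings

let () =
  (* Body of the rule f(x, y) = f(x, x). *)
  let body = Call ("f", [ Var "x"; Var "x" ]) in
  (* Assumed binding list for the folded call: x := x, y := x. *)
  let bindings = [ ("x", Var "x"); ("y", Var "x") ] in
  (* The recursive call passes one argument per binding. *)
  let call_arity = List.length bindings in
  Printf.printf "arity by body: %d, by bindings: %d, call site: %d\n"
    (List.length (params_from_body body))
    (List.length (params_from_bindings bindings))
    call_arity
```

Under these assumptions the body-based derivation yields one parameter while the call site passes two arguments, and the binding-based signature agrees with the call site.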

I won't be fixing it in this implementation, but if you're interested, my supercompiler Mazeppa doesn't have this problem.

Footnotes

  1. The parameter names can be derived from the binding list of Bind.
