Created
January 20, 2015 05:17
-
-
Save akabe/83da3bba75fbf83c56d2 to your computer and use it in GitHub Desktop.
Unfolding Recursive Autoencoder and Online Backpropagation Through Structure (BPTS)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* unfoldingRAE.ml --- Unfolding Recursive Autoencoder and
   Online Backpropagation Through Structure (BPTS)

   [MIT License] Copyright (C) 2015 Akinori ABE

   Compilation:
   $ ocamlfind ocamlopt -linkpkg -package slap unfoldingRAE.ml

   This program requires Sized Linear Algebra Package (SLAP), a linear algebra
   library for OCaml with static size checking for matrix operations (see
   http://akabe.github.io/slap/ for details).
*)
open Format | |
open Slap.Io | |
open Slap.D | |
open Slap.Common | |
module Size = Slap.Size | |
(** A tree whose nodes each carry two ['n]-dimensional vectors and a list of
    subtrees. The propagation helpers in this file treat a node whose list of
    children is empty as a leaf. *)
type 'n tree =
  | Node of
      ('n, Slap.cnt) vec   (* encoder output *)
      * ('n, Slap.cnt) vec (* decoder output *)
      * 'n tree list       (* children *)
(** Parameters of one network: two ['n]-by-['n] weight matrices and an
    ['n]-dimensional bias. Records of this type are also used in this file to
    accumulate gradients with respect to those parameters. *)
type 'n params = {
  wleft : ('n, 'n, Slap.cnt) mat;  (* "left" weight matrix *)
  wright : ('n, 'n, Slap.cnt) mat; (* "right" weight matrix *)
  bias : ('n, Slap.cnt) vec;       (* bias vector *)
}
(** [make_params feature_dim] allocates a new parameter record whose two
    weight matrices are [feature_dim]-by-[feature_dim] and whose bias has
    [feature_dim] elements; all three are filled by SLAP's [random]. *)
let make_params feature_dim =
  let square () = Mat.random feature_dim feature_dim in
  { wleft = square (); wright = square (); bias = Vec.random feature_dim }
(** [fill0_params p] overwrites, in place, every element of [p.wleft],
    [p.wright] and [p.bias] with [0.0]. *)
let fill0_params p =
  List.iter (fun w -> Mat.fill w 0.0) [p.wleft; p.wright];
  Vec.fill p.bias 0.0
(* ========================================================================== *
 * Activation functions
 * ========================================================================== *)
(** Encoder activation: [actv_f ?y x] applies the hyperbolic tangent to every
    element of [x]. The optional [y] is handed to [Vec.map] unchanged; callers
    in this file use it as the destination vector. *)
let actv_f ?y x =
  Vec.map ?y (fun v -> tanh v) x
(** [actv_df y] returns a fresh vector holding [1 - y(i) * y(i)] for each
    element, i.e. the derivative of [tanh] written in terms of its output
    [y = tanh x]. *)
let actv_df y =
  let squared = Vec.mul y y in
  Vec.sub (Vec.make1 (Vec.dim y)) squared
(** Decoder activation (linear): [actv_g ?y x] is the identity map, realised
    as [Vec.copy] of [x] with the optional [y] passed through. *)
let actv_g ?y x =
  (* A linear activation leaves every element as it is. *)
  Vec.copy ?y x
(** [actv_dg y] is the derivative of the linear decoder activation: a fresh
    all-ones vector with the same dimension as [y]. *)
let actv_dg y =
  let n = Vec.dim y in
  Vec.make1 n
(* ========================================================================== *
 * Utility functions
 * ========================================================================== *)
(** [gemv_wmat params cleft cright ~trans ~y x] accumulates into [y]:

    - with [trans] = {!Slap.Common.normal}:
      [y := (cleft * params.wleft + cright * params.wright) * x + y];
    - with [trans] = {!Slap.Common.trans}:
      [y := (cleft * params.wleft + cright * params.wright)^T * x + y].

    The blended matrix is never built; two [gemv] calls with [~beta:1.0] add
    the two contributions to [y] one after the other. Returns [unit]. *)
let gemv_wmat {wleft; wright; _} cleft cright ~trans ~y x =
  let accumulate alpha w = ignore (gemv ~trans ~alpha w x ~beta:1.0 ~y) in
  accumulate cleft wleft;
  accumulate cright wright
(** [map_children f [x(1); ...; x(n)]] is
    [[f cleft(1) cright(1) x(1); ...; f cleft(n) cright(n) x(n)]],
    evaluated from the first child to the last.

    [cright(i)] is the blend ratio of the right weight matrix for the [i]-th
    child and [cleft(i) = 1 - cright(i)] is the ratio of the left one:
    - for a single child both ratios are [0.5];
    - for [n > 1] children, [cright] grows linearly from [0.0] (first child)
      to [1.0] (last child).

    An empty list is considered a programming error ([assert false]). *)
let map_children f children =
  let blend =
    match List.length children with
    | 0 -> assert false
    | 1 -> fun _ -> 0.5
    | n ->
      let last = float (n - 1) in
      fun i -> float i /. last
  in
  List.mapi
    (fun i child ->
       let cright = blend i in
       f (1.0 -. cright) cright child)
    children
(** [iter_children f [x(1); ...; x(n)]] runs
    [f cleft(1) cright(1) x(1); ...; f cleft(n) cright(n) x(n)]
    for its effects only, using the same blend ratios as [map_children]. *)
let iter_children (f : float -> float -> 'a -> unit) children =
  let (_ : unit list) = map_children f children in
  ()
(** [prop_bottomup on_root on_node on_leaf tree] folds [tree] from the leaves
    up to the root.

    - [on_leaf x x'] is called for every non-root node without children;
    - [on_node x x' results] is called for every other non-root node, where
      [results] are the values already computed for its children;
    - [on_root x x' results] is called once, for the root, in the same way
      (also when the root itself has no children).

    The value returned by [on_root] is the result. *)
let prop_bottomup on_root on_node on_leaf (Node (x0, x0', children0)) =
  let rec prop (Node (x, x', children)) =
    match children with
    | [] -> on_leaf x x'
    | _ :: _ -> on_node x x' (List.map prop children)
  in
  on_root x0 x0' (List.map prop children0)
(** [prop_topdown init on_node tree] walks [tree] from the root downwards.

    For every node that has children, [on_node acc children] is called with
    the accumulator of that node ([init] for the root) and must return one
    accumulator per child, in order; each is then propagated into the
    corresponding subtree. Nodes without children are skipped. *)
let prop_topdown init on_node tree =
  let rec prop acc (Node (_, _, children)) =
    match children with
    | [] -> ()
    | _ :: _ -> List.iter2 prop (on_node acc children) children
  in
  prop init tree
(* ========================================================================== *
 * Feedforward propagation
 * ========================================================================== *)
(** [enc_feedforward params tree] runs the encoder bottom-up over [tree].

    For every node with children (root included) the encoder-output vector of
    the node is overwritten in place with
    [f(bias + sum_i (cleft(i) * wleft + cright(i) * wright) * vec(child i))],
    where [f] is [actv_f]. Leaf vectors are used as they are. Returns the
    encoder-output vector of the root. *)
let enc_feedforward params tree =
  let encode out _ child_vecs =
    ignore (Vec.copy ~y:out params.bias);        (* out := bias *)
    let accumulate = gemv_wmat params ~trans:normal ~y:out in
    iter_children accumulate child_vecs;         (* out += W(i) * child(i) *)
    actv_f ~y:out out                            (* out := f(out) *)
  in
  let keep_leaf x _ = x in
  prop_bottomup encode encode keep_leaf tree
(** [dec_feedforward params tree] runs the decoder top-down over [tree].

    The decoder-output vector of the root is first set to a copy of the
    root's encoder output. Then, for every child [cm] of a node [p], the
    decoder-output vector of [cm] is overwritten in place with
    [g(bias + (cleft * wleft + cright * wright) * vec'(p))], where [g] is
    [actv_g] and the ratios come from [map_children]. *)
let dec_feedforward params (Node (vec_root, vec'_root, _) as tree) =
  let decode_child parent_out cleft cright (Node (_, child_out, _)) =
    ignore (copy ~y:child_out params.bias);      (* vec'(cm) := bias *)
    gemv_wmat params ~trans:normal ~y:child_out cleft cright parent_out;
    actv_g ~y:child_out child_out                (* vec'(cm) := g(vec'(cm)) *)
  in
  let decode_children parent_out children =
    map_children (decode_child parent_out) children
  in
  ignore (copy ~y:vec'_root vec_root);           (* vec'(root) := vec(root) *)
  prop_topdown vec'_root decode_children tree
(** [feedforward dec_params enc_params tree] runs the encoder pass and then
    the decoder pass over [tree], updating the vectors stored in its nodes.
    The root vector returned by the encoder pass is discarded. *)
let feedforward dec_params enc_params tree =
  enc_feedforward enc_params tree |> ignore;
  dec_feedforward dec_params tree
(** [calc_error dec_params enc_params tree] runs a full feedforward pass and
    returns the reconstruction error: the sum, over all leaves reached by
    [prop_bottomup], of half the squared difference between the leaf's
    encoder-side vector and its decoder output. *)
let calc_error dec_params enc_params tree =
  feedforward dec_params enc_params tree;
  let leaf_error x x' = Vec.ssqr_diff x x' /. 2.0 in
  let sum_children _ _ errors =
    List.fold_left (fun total e -> total +. e) 0.0 errors
  in
  prop_bottomup sum_children sum_children leaf_error tree
(* ========================================================================== *
 * Feedback propagation
 * ========================================================================== *)
(** [enc_feedback grads params delta_root tree] back-propagates
    [delta_root] top-down through the encoder and accumulates the encoder
    gradients into [grads] (which is updated in place, not reset).

    For each node [p] with children and delta [delta(p)]:
    - [grads.bias += delta(p)];
    - for each child [cm]: [grads.wleft += cleft * delta(p) * vec(cm)^T] and
      [grads.wright += cright * delta(p) * vec(cm)^T];
    - [delta(cm) = ((cleft * wleft + cright * wright)^T * delta(p))
                   .* actv_df(vec(cm))] is passed down to [cm]. *)
let enc_feedback grads params delta_root tree =
  let backprop_child delta_p cleft cright (Node (vec_cm, _, _)) =
    (* Add to gradients of weights *)
    let add_outer alpha w = ignore (ger ~alpha delta_p vec_cm w) in
    add_outer cleft grads.wleft;
    add_outer cright grads.wright;
    (* Compute delta(cm) from delta(p) *)
    let delta_cm = Vec.make0 (Vec.dim delta_p) in
    gemv_wmat params ~trans:trans ~y:delta_cm cleft cright delta_p;
    Vec.mul ~z:delta_cm delta_cm (actv_df vec_cm)
  in
  let backprop_node delta_p children =
    axpy ~alpha:1.0 ~x:delta_p grads.bias; (* Add to gradients of biases *)
    map_children (backprop_child delta_p) children
  in
  prop_topdown delta_root backprop_node tree
(** [dec_feedback grads params tree] back-propagates the reconstruction
    error bottom-up through the decoder, accumulating the decoder gradients
    into [grads] in place, and returns the delta of the root.

    - At a leaf the delta is [(vec' - vec) .* actv_dg(vec')].
    - At a node [p] with child deltas [delta'(cm)], every child adds
      [delta'(cm)] to [grads.bias] and [c * delta'(cm) * vec'(p)^T] to the
      matching weight gradient; the node's own delta is
      [(sum_cm W(cm)^T * delta'(cm)) .* dactv].
    - [dactv] is [actv_dg vec'(p)] for inner nodes and [actv_df vec(p)] for
      the root. *)
let dec_feedback grads params tree =
  let backprop1 dactv vec'_p delta'_cm_lst =
    let delta'_p = Vec.make0 (Vec.dim vec'_p) in
    let absorb cleft cright delta'_cm =
      (* Add to gradients *)
      axpy ~alpha:1.0 ~x:delta'_cm grads.bias;
      let add_outer alpha w = ignore (ger ~alpha delta'_cm vec'_p w) in
      add_outer cleft grads.wleft;
      add_outer cright grads.wright;
      (* Add to delta'(p) *)
      gemv_wmat params ~trans:trans ~y:delta'_p cleft cright delta'_cm
    in
    iter_children absorb delta'_cm_lst;
    Vec.mul ~z:delta'_p delta'_p dactv
  in
  let on_leaf vec_l vec'_l =
    let residual = Vec.sub vec'_l vec_l in
    Vec.mul residual (actv_dg vec'_l)
  in
  let on_node _ vec'_p deltas = backprop1 (actv_dg vec'_p) vec'_p deltas in
  let on_root vec_p vec'_p deltas = backprop1 (actv_df vec_p) vec'_p deltas in
  prop_bottomup on_root on_node on_leaf tree
(** [feedback dec_grads dec_params enc_grads enc_params tree] runs the
    decoder backward pass, then feeds the resulting root delta into the
    encoder backward pass. Both gradient records are accumulated in place. *)
let feedback dec_grads dec_params enc_grads enc_params tree =
  dec_feedback dec_grads dec_params tree
  |> fun delta_root -> enc_feedback enc_grads enc_params delta_root tree
(* ========================================================================== *
 * Training and Gradient Checking
 * ========================================================================== *)
(** [check_gradient dec_grads dec_params enc_grads enc_params tree] compares
    the given back-propagated gradients with a naive central-difference
    estimate and prints a warning on standard error for every parameter
    whose two values disagree.

    Each parameter is perturbed by [+epsilon] and [-epsilon] in turn, the
    error is recomputed with [calc_error] both times, and the parameter is
    then restored. This is for checking the implementation only; it is much
    slower than back propagation.

    cf. http://ufldl.stanford.edu/wiki/index.php/Gradient_checking_and_advanced_optimization *)
let check_gradient dec_grads dec_params enc_grads enc_params tree =
  let epsilon = 1e-4 in
  (* Check 4 significant digits *)
  let check_digits dE1 dE2 =
    let magnitude = abs_float dE1 in
    if magnitude < 1e-9 then
      (* true if both `dE1' and `dE2' are nearly zero *)
      abs_float dE2 < 1e-9
    else begin
      (* Scale the difference by the decimal magnitude of `dE1' *)
      let scale = 0.1 ** (floor (log10 magnitude) +. 1.0) in
      (* true if 4 significant digits are the same *)
      abs_float ((dE1 -. dE2) *. scale) < epsilon
    end
  in
  (* Error of the network after temporarily setting x[i] to `value' *)
  let error_with x i value =
    Vec.set_dyn x i value;
    calc_error dec_params enc_params tree
  in
  let check_vec label x dx =
    Vec.iteri
      (fun i dE2 ->
         let elm = Vec.get_dyn x i in
         let pos_err = error_with x i (elm +. epsilon) in
         let neg_err = error_with x i (elm -. epsilon) in
         Vec.set_dyn x i elm; (* restore *)
         let dE1 = (pos_err -. neg_err) /. (2.0 *. epsilon) in
         if not (check_digits dE1 dE2) then
           eprintf "WARNING: %s[%d] naive diff = %.6g, backprop = %.6g@."
             label i dE1 dE2)
      dx
  in
  let check_mat label a da =
    let check_row i () ai =
      let label' = label ^ "[" ^ (string_of_int i) ^ "]" in
      check_vec label' ai (Mat.row_dyn da i)
    in
    Mat.fold_topi check_row () a
  in
  check_vec "dE/db" enc_params.bias enc_grads.bias;
  check_mat "dE/dWleft" enc_params.wleft enc_grads.wleft;
  check_mat "dE/dWright" enc_params.wright enc_grads.wright;
  check_vec "dE/db'" dec_params.bias dec_grads.bias;
  check_mat "dE/dWleft'" dec_params.wleft dec_grads.wleft;
  check_mat "dE/dWright'" dec_params.wright dec_grads.wright
(** [train eta dec_grads dec_params enc_grads enc_params tree] performs one
    gradient-descent step on a single tree.

    Both gradient records are reset to zero, a feedforward and a feedback
    pass fill them, and every parameter [p] is then updated in place as
    [p := p - eta * dE/dp]. *)
let train eta dec_grads dec_params enc_grads enc_params tree =
  List.iter fill0_params [dec_grads; enc_grads];
  feedforward dec_params enc_params tree;
  feedback dec_grads dec_params enc_grads enc_params tree;
  (* check_gradient dec_grads dec_params enc_grads enc_params tree; *)
  let step = -. eta in
  let descend grads params =
    Mat.axpy ~alpha:step ~x:grads.wleft params.wleft;
    Mat.axpy ~alpha:step ~x:grads.wright params.wright;
    axpy ~alpha:step ~x:grads.bias params.bias
  in
  descend dec_grads dec_params;
  descend enc_grads enc_params
(** Demo entry point: builds a small fixed tree of 5-dimensional constant
    feature vectors, trains an encoder/decoder pair on it for 30000
    iterations (printing the error and decaying the learning rate every 1000
    iterations), and finally prints every leaf next to its reconstruction. *)
let main () =
  Random.self_init ();
  let module N = (val Size.of_int_dyn 5 : Size.SIZE) in
  let fresh () = Vec.create N.value in
  let node children = Node (fresh (), fresh (), children) in
  let leaf value = Node (Vec.make N.value value, fresh (), []) in
  let tree =
    node [
      node [leaf 0.1; node [leaf 0.2; leaf 0.3]];
      leaf 0.4;
      node [leaf 0.5; leaf 0.6; leaf 0.7];
    ]
  in
  (* The initial parameters *)
  let dec_grads = make_params N.value in
  let dec_params = make_params N.value in
  let enc_grads = make_params N.value in
  let enc_params = make_params N.value in
  let eta = ref 0.01 in (* a learning rate *)
  (* Training *)
  for i = 1 to 30000 do
    train !eta dec_grads dec_params enc_grads enc_params tree;
    if i mod 1000 = 0 then begin
      let err = calc_error dec_params enc_params tree in
      printf "Loop #%d: error = %g@." i err;
      eta := !eta *. 0.999
    end
  done;
  (* Show results *)
  feedforward dec_params enc_params tree;
  let skip _ _ _ = () in
  let show_leaf x x' =
    printf "Leaf: @[original = [ %a]@\n\
            reconstructed = [ %a]@]@."
      pp_rfvec x pp_rfvec x'
  in
  prop_bottomup skip skip show_leaf tree

let () = main ()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment