Skip to content

Instantly share code, notes, and snippets.

@ryanrhymes
Last active April 4, 2017 11:02
Show Gist options
  • Save ryanrhymes/aaff8fb7ab0d94813e0ca81421793631 to your computer and use it in GitHub Desktop.
Save ryanrhymes/aaff8fb7ab0d94813e0ca81421793631 to your computer and use it in GitHub Desktop.
#require "owl_neural";;
open Owl_neural;;
(* Network configuration: 784 inputs -> 300 hidden units (tanh) -> 10 outputs
   (softmax). Both weight layers share the same uniform initialiser over
   [-0.075, 0.075]. Layers and activations are appended to [nn] in order. *)
let nn = Feedforward.create ();;
let uniform_init = Init.(Uniform (-0.075,0.075));;
let hidden_layer = linear ~inputs:784 ~outputs:300 ~init_typ:uniform_init;;
let output_layer = linear ~inputs:300 ~outputs:10 ~init_typ:uniform_init;;
Feedforward.add_layer nn hidden_layer;;
Feedforward.add_activation nn Activation.Tanh;;
Feedforward.add_layer nn output_layer;;
Feedforward.add_activation nn Activation.Softmax;;
(* Show the assembled network. *)
print nn;;
(* Training data: from the triple returned by the MNIST loader, keep the
   first component as the input matrix [x] and the third as the training
   target [y]; the second is discarded. Both are wrapped as Algodiff
   matrices so they can be passed to [train]. *)
let x, y =
  let images, _, targets = Dataset.load_mnist_train_data () in
  Algodiff.AD.Mat images, Algodiff.AD.Mat targets;;
(* Train [nn] on the training data, then plot the returned loss history:
   the x-axis is the sequence 0, 1, ... (one point per array entry) and the
   y-axis is the loss array itself. *)
let loss_history = train nn x y;;
Plot.plot
  (Vec.sequential (Array.length loss_history))
  (Vec.of_array loss_history);;
(* Evaluate the trained network on 10 samples drawn from the MNIST test set.
   The [Mat] constructor is qualified as [Algodiff.AD.Mat], matching how the
   training data is wrapped. Otherwise this would depend on [open Owl_neural]
   re-exporting the constructor, or on out-of-scope type-directed
   disambiguation (warning 40).
   NOTE(review): training uses the THIRD component of the MNIST triple as its
   target, whereas this uses the SECOND. Confirm which form [test_model]
   expects for its label argument. *)
let x, y, _ = Dataset.load_mnist_test_data () in
let x, y = Dataset.draw_samples x y 10 in
test_model nn (Algodiff.AD.Mat x) (Algodiff.AD.Mat y);;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment