Skip to content

Instantly share code, notes, and snippets.

@ryanrhymes
Created March 11, 2017 22:35
Show Gist options
  • Save ryanrhymes/d195c209f50d3224477db0885186cde4 to your computer and use it in GitHub Desktop.
Save ryanrhymes/d195c209f50d3224477db0885186cde4 to your computer and use it in GitHub Desktop.
(** [test_model nn x y] walks the rows of [x] and [y] in lockstep via
    [Mat.iter2_rows]. For every row of [x] it prints the input image, then
    the matrix obtained by running that row through [nn] with [run_network],
    and finally a line of the form [prediction: <index>].

    The rows of [y] are passed to the callback but are not consulted, so no
    accuracy is computed here; the parameter is kept so existing callers
    still work.

    NOTE(review): the printed index is the third component of the triple
    returned by [Owl.Mat.max_i]; presumably that is the column of the
    largest output, i.e. the predicted digit. Confirm against the Owl
    version in use. *)
let test_model nn x y =
  Mat.iter2_rows (fun u _v ->
    (* Show the raw input first so the prediction can be checked by eye. *)
    Dataset.print_mnist_image (unpack_mat u);
    let p = run_network u nn |> unpack_mat in
    Owl.Mat.print p;
    let _, _, j = Owl.Mat.max_i p in
    Printf.printf "prediction: %i\n" j
  ) x y
(* Script entry point: run 500 training steps on [nn], each on a freshly
   drawn batch of 100 samples, printing the value returned by [backprop]
   after every step; then draw 10 samples from the test set and print the
   network output for each with [test_model].

   Bound with [let ()] instead of [let _] so the compiler checks that the
   whole body has type unit rather than silently discarding some other
   value. *)
let () =
  (* The loaders return a triple. Training keeps the first and third
     components, the test phase below keeps the first and second.
     NOTE(review): presumably the third is a one-hot encoding of the second;
     confirm against Dataset. *)
  let x, _, y = Dataset.load_mnist_train_data () in
  for i = 1 to 500 do
    let batch_x, batch_y = Dataset.draw_samples x y 100 in
    let loss = backprop nn (F 0.01) (Mat batch_x) (Mat batch_y) in
    Printf.printf "#%i : loss=%g\n" i loss;
    (* Flush on every step so progress shows up while training runs. *)
    flush_all ()
  done;
  let x, y, _ = Dataset.load_mnist_test_data () in
  let x, y = Dataset.draw_samples x y 10 in
  test_model nn (Mat x) (Mat y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment