This page lists examples of using tfgraph to convert computations in Owl so that they can be executed on TensorFlow.
```python
#!/usr/bin/python
import numpy as np
import random

# first condition: average: 4.5
def crossed1(p1, p2):
    x1, y1 = p1
    x2, y2 = p2
    # p1 should be within;
```
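```ocaml
open Owl
open Algodiff.D

(* Gradient descent on a scalar function [f]: step against the derivative
   until its magnitude falls below [eps]. Comparing the absolute value
   matters, otherwise any negative derivative would stop the descent at once. *)
let rec desc ?(eta=F 0.01) ?(eps=1e-6) f x =
  let g = (diff f) x in
  if abs_float (unpack_flt g) < eps then x
  else desc ~eta ~eps f Maths.(x - eta * g)
```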
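For readers more familiar with Python, the same descent loop can be sketched as below. This is an illustration, not Owl code: a central-difference derivative stands in for Owl's algorithmic differentiation (`diff`), and the `max_iter` safety cap is an assumption added here that the OCaml version does not have. The names `numerical_diff` and `desc` are hypothetical.

```python
def numerical_diff(f, x, h=1e-6):
    """Central-difference approximation of f'(x) (stand-in for Owl's `diff`)."""
    return (f(x + h) - f(x - h)) / (2 * h)


def desc(f, x, eta=0.01, eps=1e-6, max_iter=100_000):
    """Gradient descent on a scalar function f, mirroring the OCaml `desc`.

    Stops once |f'(x)| < eps; max_iter is an extra safety net (assumption).
    """
    for _ in range(max_iter):
        g = numerical_diff(f, x)
        if abs(g) < eps:       # derivative close enough to zero: converged
            return x
        x = x - eta * g        # step against the derivative
    return x
```

For example, `desc(lambda x: (x - 3.0) ** 2, 0.0)` moves towards the minimum at 3.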
```html
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Owl - OCaml Scientific and Engineering Computing</title>
<script src="https://code.jquery.com/jquery-3.4.1.slim.min.js" integrity="sha384-J6qa4849blE2+poT4WnyKhv5vZF5SrPo0iEjwBvKU7imGFAV0wwj1yYfoRSJoZ+n" crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/popper.js@1.16.0/dist/umd/popper.min.js" integrity="sha384-Q6E9RHvbIyZFJoft+2mJbHaEWldlvI9IOYy5n3zV9zzTtmI3UksdQRVvoxMfooAo" crossorigin="anonymous"></script>
<script src="https://stackpath.bootstrapcdn.com/bootstrap/4.4.1/js/bootstrap.min.js" integrity="sha384-wfSDF2E50Y2D1uUdj0O3uMBJnjuUD4Ih7YwaYd1iqfktj0Uod8GCExl3Og8ifwB6" crossorigin="anonymous"></script>
```
Files:
- `config.ml`: MirageOS configuration file.
- `simple_mnist.ml`: main logic of the MNIST neural network.
- `simple_mnist_weight.ml`: pre-trained weights of the neural network.

Information:
- Weight file size: 146 KB.
- Model test accuracy: 92%.
- MirageOS: compiled with the Unix backend, the generated binary is 10 MB. The other backends have not been tested yet.
Files:
- `simple_mnist.ml`: main logic of the MNIST neural network.
- `simple_mnist_weight.ml`: pre-trained weights of the neural network.

Information:
- Weight file size: 146 KB.
- Model test accuracy: 92%.
- MirageOS: compiled with the Unix backend, the generated binary is 10 MB. The other backends have not been tested yet.

A MirageOS version of this simple MNIST example is also available.
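The examples above keep the network logic and its pre-trained weights in separate files. As a rough illustration of how stored weights are applied at inference time, here is a NumPy sketch that **assumes** a single fully connected layer (784 inputs, 10 classes) followed by softmax; that architecture is an assumption for illustration and is not necessarily what `simple_mnist.ml` implements. The names `softmax` and `predict` are hypothetical.

```python
import numpy as np


def softmax(z):
    """Row-wise softmax; subtracting the row max keeps exp() from overflowing."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def predict(x, w, b):
    """Hypothetical dense layer: x (n, 784) @ w (784, 10) + b (10,) -> class ids."""
    probs = softmax(x @ w + b)
    return probs.argmax(axis=-1)  # most probable digit for each image
```

The argmax does not actually need the softmax; it is kept so that `probs` could also be reported as class probabilities.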
Usage:
This example is an MNIST-based CNN.
- Step 1: run the OCaml script `tfgraph_train.ml`, which generates a file `tf_convert_mnist.pbtxt` in the current directory.
- Step 2: make sure `tf_convert_mnist.pbtxt` and `tfgraph_train.py` are in the same directory, and that TensorFlow, NumPy, etc. are installed.
- Step 3: execute `python tf_converter_mnist.py`. The expected on-screen output is the training progress: after every 100 steps, the loss value and the model accuracy are shown.
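The progress report in Step 3 pairs a loss value with an accuracy figure. The NumPy sketch below shows one common way to compute both for a classifier and to print them only every 100 steps. It **assumes** a softmax cross-entropy loss and arg-max accuracy, which are typical for MNIST but not confirmed by the training script; `cross_entropy`, `accuracy` and `maybe_report` are hypothetical names.

```python
import numpy as np


def cross_entropy(logits, labels):
    """Mean softmax cross-entropy; labels are integer class ids (assumed loss)."""
    z = logits - logits.max(axis=1, keepdims=True)           # stabilise exp()
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())


def accuracy(logits, labels):
    """Fraction of samples whose arg-max prediction matches the label."""
    return float((logits.argmax(axis=1) == labels).mean())


def maybe_report(step, logits, labels, every=100):
    """Return a progress line every `every` steps, otherwise None."""
    if step % every != 0:
        return None
    return (f"step {step}: loss={cross_entropy(logits, labels):.4f}, "
            f"accuracy={accuracy(logits, labels):.2f}")
```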
Here we only assume that the author of the Python script knows where to find the output node (in the collection "result") and the placeholder names (`x:0`).
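In TensorFlow, a tensor name such as `x:0` means "output 0 of the operation named `x`", which is why the placeholder is addressed as `x:0` rather than `x`. The small helper below (a hypothetical name, not part of tfgraph or TensorFlow) splits such a name into its two parts; treating a bare name as output 0 is a convenience choice of this sketch.

```python
def split_tensor_name(name):
    """Split a tensor name "<op_name>:<output_index>" into (op_name, index).

    A name without ":" is treated as output 0 (a choice made in this sketch).
    """
    op_name, sep, index = name.rpartition(":")
    if not sep:
        return name, 0
    return op_name, int(index)
```

For instance, `split_tensor_name("scope/y:1")` gives `("scope/y", 1)`. Using `rpartition` splits on the last colon, so scoped names that contain `/` are handled too.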
This example is an MNIST-based CNN.
- Step 1: run the OCaml script `tfgraph_inf.ml`, which generates a file `tf_convert_mnist.pbtxt` in the current directory.
- Step 2: make sure `tf_convert_mnist.pbtxt` and `tfgraph_inf.py` are in the same directory, and that TensorFlow, NumPy, etc. are installed.
- Step 3: execute `python tf_converter_mnist.py`. The expected on-screen output is an array of size [100] whose elements are boolean values, together with a value indicating the inference accuracy.
Here we only assume that the author of the Python script knows where to find the output node (in the collection "result") and the placeholder names (`x:0`).
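If each boolean in the inference output marks whether the prediction for one test image was correct (an assumption, since the text does not say so explicitly), the accuracy value is simply the fraction of `True` entries. A minimal NumPy sketch, with a hypothetical function name:

```python
import numpy as np


def inference_accuracy(correct):
    """Fraction of True entries in a boolean array of per-sample results."""
    correct = np.asarray(correct, dtype=bool)
    return float(correct.mean())  # True counts as 1, False as 0
```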
```ocaml
open Printf

(* Print every command-line argument with its index;
   Sys.argv.(0) is the program name. *)
let () =
  for i = 0 to Array.length Sys.argv - 1 do
    printf "[%i] %s\n" i Sys.argv.(i)
  done
```
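For comparison, the OCaml argument printer above translates to Python as follows. As with `Sys.argv` in OCaml, `sys.argv[0]` in Python holds the script name. The `format_args` helper is a hypothetical name introduced so that the formatting can be reused and checked separately from printing.

```python
import sys


def format_args(argv):
    """Mirror the OCaml loop: one "[i] arg" line per element, from index 0."""
    return [f"[{i}] {arg}" for i, arg in enumerate(argv)]


if __name__ == "__main__":
    print("\n".join(format_args(sys.argv)))
```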