Last active: November 28, 2018 23:18
-
-
Save jvmncs/4f5af5307f6a615b7857ca2bfef6fcf3 to your computer and use it in GitHub Desktop.
TensorFlow equivalent of tfe_minimal.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf
# generic functions for loading model weights and input data
def provide_weights():
    """Load model weights as TensorFlow objects.

    Placeholder: replace the body with real loading logic. The script
    unpacks the result as ``w0, b0, w1, b1, w2, b2``, so an implementation
    must return six tensors/variables in that order.

    Raises:
        NotImplementedError: always, until a real loader is supplied. This
            replaces the silent ``None`` return, which only surfaced later as
            an opaque "cannot unpack" TypeError at the call site.
    """
    raise NotImplementedError(
        "provide_weights() is a placeholder; supply code that returns "
        "(w0, b0, w1, b1, w2, b2)"
    )
def provide_input():
    """Load input data as TensorFlow objects.

    Placeholder: replace the body with real loading logic. The script uses
    the result as the left operand of ``tf.matmul(x, w0)``.

    Raises:
        NotImplementedError: always, until a real loader is supplied. This
            replaces the silent ``None`` return, which otherwise failed later
            inside ``tf.matmul`` with a less informative error.
    """
    raise NotImplementedError(
        "provide_input() is a placeholder; supply code that returns the "
        "input tensor x"
    )
# get model weights/input data (both unencrypted)
w0, b0, w1, b1, w2, b2 = provide_weights()
x = provide_input()

# compute prediction: two ReLU hidden layers followed by a linear output layer
layer0 = tf.nn.relu(tf.matmul(x, w0) + b0)
layer1 = tf.nn.relu(tf.matmul(layer0, w1) + b1)
# The output layer consumes the second hidden layer; the original referenced
# an undefined name `layer2` here.
logits = tf.matmul(layer1, w2) + b2

# get result of prediction and print: tf.Print passes its first argument
# through unchanged and logs the tensors in its data list when evaluated,
# so the pass-through tensor must be `logits` (the original referenced an
# undefined name `result`).
prediction_op = tf.Print(logits, [logits], message="prediction: ", summarize=10)

# run graph execution in a tf.Session
with tf.Session() as sess:
    # tf.Session.run accepts no `tag` keyword. It was presumably carried over
    # from the tf-encrypted version of this script, so it is dropped here.
    sess.run(tf.global_variables_initializer())
    sess.run(prediction_op)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment