Last active
April 17, 2019 06:46
-
-
Save abhayraw1/af7ab2bb524b0392c6495c4d8d90c3f4 to your computer and use it in GitHub Desktop.
Simple feed-forward neural net using tf.keras, written to reproduce issue #27316 reported in the TensorFlow GitHub issue tracker.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
import tensorflow.keras.layers as L | |
import tensorflow.keras.activations as Z | |
from tensorflow.keras import Model | |
from pprint import pprint | |
def feed_forward_nn(input_shape):
    """Build a small two-layer feed-forward tf.keras model.

    Architecture: ``Input(shape=input_shape)`` -> ``Dense(5, relu)`` ->
    ``Dense(1, tanh)``. All layers are created inside
    ``tf.variable_scope("keras_")``.

    Args:
        input_shape: Shape passed to ``tf.keras.layers.Input`` (``main``
            calls this with ``(3, )``).

    Returns:
        A ``tf.keras.Model`` mapping the input tensor to the single tanh
        output unit.
    """
    with tf.variable_scope("keras_"):
        inputs = L.Input(shape=input_shape)
        # Hidden layer: 5 ReLU units.
        hidden = L.Dense(5, activation=Z.relu)(inputs)
        # Output layer: 1 tanh unit.
        outputs = L.Dense(1, activation=Z.tanh)(hidden)
        model = Model(inputs=inputs, outputs=outputs)
    return model
def _weights_equal(weights_a, weights_b):
    """Return True if two weight lists match array-for-array.

    Unlike a bare ``zip`` + ``==`` comparison, this returns False (instead of
    silently truncating or raising) when the lists differ in length or when
    a pair of arrays differs in shape.
    """
    if len(weights_a) != len(weights_b):
        return False
    return all(np.array_equal(a, b) for a, b in zip(weights_a, weights_b))


def main():
    """Copy weights between two identically-built models and compare them.

    Steps:
      1. Build ``model1`` and ``model2`` with ``feed_forward_nn((3, ))`` and
         print their weights.
      2. Copy ``model2``'s weights into ``model1`` via ``set_weights`` and
         report whether ``get_weights()`` now agrees for both models.
      3. Run both models on the same random ``(1, 3)`` input; the printed
         predictions are expected to match if the copy took effect.
      4. Open a separate ``tf.Session``, run the global variable initializer
         in it, then print predictions and ``trainable_weights`` again.

    NOTE(review): ``model.predict`` is never handed ``sess``, so it presumably
    keeps using the Keras backend session, whereas ``sess.run(...)`` reads the
    variables freshly initialized in ``sess``. Confirm this discrepancy is
    what the issue is meant to demonstrate.

    Returns:
        tuple: ``(model1, model2)``.
    """
    model1 = feed_forward_nn((3, ))
    model2 = feed_forward_nn((3, ))

    print("Model weights before:\n")
    pprint(model1.get_weights())
    pprint(model2.get_weights())
    print()

    model1.set_weights(model2.get_weights())
    are_wts_equal = _weights_equal(model1.get_weights(), model2.get_weights())

    print("Model weights after:\n")
    pprint(model1.get_weights())
    pprint(model2.get_weights())
    print()
    print("\nAre Model Weights Equal (using <tf.keras.Model>.get_weights()): {}".format(are_wts_equal))
    print("\nIf YES then these should be equal!!!")

    some_input = np.random.random((1, 3))
    print(model1.predict(some_input))
    print(model2.predict(some_input))

    # Close the extra session explicitly. ``with tf.Session() as sess:`` is
    # deliberately avoided: entering the session would make it the default
    # session, which could change which session Keras uses for ``predict``.
    sess = tf.Session()
    try:
        sess.run(tf.global_variables_initializer())
        # Fetches every global variable; the result is discarded.
        sess.run(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
        print(model1.predict(some_input))
        print(model2.predict(some_input))
        print(sess.run(model1.trainable_weights))
        print(sess.run(model2.trainable_weights))
    finally:
        sess.close()
    return model1, model2


if __name__ == "__main__":
    # Bind both models at module level, presumably so they can be inspected
    # after the run (e.g. under ``python -i``).
    model1, model2 = main()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment