# Gist dustinvtran/ce6c064a8db49ac2d41c (last active March 14, 2016):
# Assignment (pass by reference) in TensorFlow
# This note shows that TensorFlow objects are not copied when one Python
# object is stored inside another: the container keeps a reference to the
# very same object, so neither a shallow nor a deep copy is made.
#
# This is relevant for variational parameters. For example, Inference
# holds Variational, and we train the variational parameters via
# inference.run(). Afterwards we can use the original variational object
# directly, since it holds the trained parameters, instead of having to go
# through inference.variational.
# (This is ordinary Python reference semantics, not something specific to
# TensorFlow.)
from __future__ import print_function
import tensorflow as tf
class ClassWithVar:
    """Own one TensorFlow variable and one TensorFlow constant.

    Attributes:
        var: a ``tf.Variable`` whose initial value is a shape-``[1]`` draw
            from ``tf.random_normal``.
        non_var: a scalar ``tf.constant`` with value 0.0.
    """

    def __init__(self):
        # Build the random initial value first, then wrap it in a Variable;
        # this is the same op-creation order as nesting the calls inline.
        initial_value = tf.random_normal([1])
        self.var = tf.Variable(initial_value)
        self.non_var = tf.constant(0.0)
class Container:
    """Hold a reference to another object without copying it.

    ``Container(obj).x is obj`` is always true, so mutations made through
    either reference are visible through the other.
    """

    def __init__(self, x):
        # Plain attribute binding: store the caller's object itself.
        self.x = x
# Demonstrate that Container shares, rather than copies, the ClassWithVar
# instance it wraps.
x = ClassWithVar()
y = Container(x)
print(y.x is x)  # True: both names refer to one object; nothing was copied.

# A context-managed session is closed when the demo finishes. Inside the
# `with` block it is also the default session, which `.eval()` relies on.
with tf.Session() as sess:
    init = tf.initialize_all_variables()
    sess.run(init)

    print(x.var.eval())
    print(y.x.var.eval())
    print(x.non_var.eval())
    print(y.x.non_var.eval())

    # Update the variable in place, which is the kind of update training
    # performs. The tf.Variable object is kept, so x.var reports the new value.
    # (Writing `y.x.var += ...` would instead rebind the attribute to a plain
    # Tensor and leave the Variable itself unchanged.)
    sess.run(y.x.var.assign_add([1.0]))

    # A constant cannot be updated in place: `+=` rebinds the attribute to a
    # new Tensor. Because y.x is x, x.non_var sees that rebinding too.
    y.x.non_var += tf.constant(1.0)

    print(x.var.eval())
    print(y.x.var.eval())
    print(x.non_var.eval())
    print(y.x.non_var.eval())
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.