@dustinvtran
Last active March 14, 2016 01:56
Assignment (pass by reference) in TensorFlow
# This note shows that Python objects holding TensorFlow objects are stored
# by reference: when one class instance contains another, no deep copy is
# made.
#
# This matters for variational parameters. For example, Inference holds a
# Variational object, and inference.run() trains the variational parameters.
# Afterwards the original variational object already has the trained
# parameters, so there is no need to go through inference.variational.
# (This is standard Python behavior, not something specific to TensorFlow.)
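from __future__ import print_function

import tensorflow as tf


class ClassWithVar:
    def __init__(self):
        self.var = tf.Variable(tf.random_normal([1]))
        self.non_var = tf.constant(0.0)


class Container:
    def __init__(self, x):
        self.x = x  # stores a reference to x, not a copy


sess = tf.InteractiveSession()
x = ClassWithVar()
y = Container(x)
init = tf.initialize_all_variables()
sess.run(init)

# Before the update: x and y.x give the same values.
print(x.var.eval())
print(y.x.var.eval())
print(x.non_var.eval())
print(y.x.non_var.eval())

# Rebind both attributes through the container. Each += builds a new tensor
# (old value + 1) and assigns it to the attribute on the shared object.
y.x.var += tf.constant(1.0)
y.x.non_var += tf.constant(1.0)

# After the update: x sees the change too, because y.x is x.
print(x.var.eval())
print(y.x.var.eval())
print(x.non_var.eval())
print(y.x.non_var.eval())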
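
The same sharing can be checked without TensorFlow. The following is a minimal sketch in plain Python, assuming two hypothetical classes (`Params` and `Holder`, which are not part of the original script) that stand in for `ClassWithVar` and `Container`. It also shows that `copy.deepcopy` is what would be needed to get an independent copy.

```python
import copy


class Params:
    """Hypothetical stand-in for ClassWithVar: holds one plain Python value."""

    def __init__(self):
        self.value = 0.0


class Holder:
    """Hypothetical stand-in for Container: stores whatever it is given."""

    def __init__(self, x):
        self.x = x  # stores a reference to x, not a copy


x = Params()
y = Holder(x)

# y.x and x are the same object, so rebinding an attribute through y.x
# is visible through x as well.
same_object = y.x is x
y.x.value += 1.0
shared_value = x.value

# copy.deepcopy builds an independent object graph, so later changes
# made through the copy do not reach x.
z = copy.deepcopy(y)
z.x.value += 1.0
value_after_copy_edit = x.value

print(same_object, shared_value, value_after_copy_edit)
```

Here `y.x is x` holds because `Holder` only keeps a reference, which is the same reason the trained parameters show up on the original object in the script above.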