Skip to content

Instantly share code, notes, and snippets.

@danijar
Last active November 20, 2021 17:21
Show Gist options
  • Save danijar/720394a9071a03413be8a60852374aa4 to your computer and use it in GitHub Desktop.
Save danijar/720394a9071a03413be8a60852374aa4 to your computer and use it in GitHub Desktop.
TensorFlow decorator to share variables between calls. Works for both functions and methods.
import functools
import tensorflow as tf
class share_variables(object):
    """Decorator that shares TensorFlow variables between calls.

    The decorated callable is wrapped with ``tf.make_template`` so that the
    variables created on the first call are reused on every later call.

    * Plain function: one template for the function.
    * Method: one template per instance, so different instances get separate
      variables while repeated calls on the same instance share them.

    The variable-scope name is built from the module name, the class name,
    the instance's ``name`` attribute (or ``id(instance)`` if it has none) and
    the callable's name. Underscores are stripped from each part and the
    parts are joined with ``_`` (e.g. ``main_Model_foo_method``).
    """

    def __init__(self, callable_):
        self._callable = callable_
        # Maps id(instance) -> (instance, template). Keying by id() instead of
        # by the instance itself works for unhashable instances and prevents
        # instances that merely compare equal from sharing variables. Storing
        # the instance in the value keeps it alive, so its id cannot be
        # recycled by another object while the entry exists.
        self._wrappers = {}
        # Template used when decorating a plain function; created lazily.
        self._wrapper = None

    def __call__(self, *args, **kwargs):
        return self._function_wrapper(*args, **kwargs)

    def __get__(self, instance, owner):
        if instance is None:
            # Accessed on the class, e.g. ``Model.method(obj, data)``: the
            # instance arrives as the first positional argument of the call.
            decorator = functools.partial(self._method_wrapper)
        else:
            decorator = functools.partial(self._method_wrapper, instance)
        return functools.wraps(self._callable)(decorator)

    def _method_wrapper(self, instance, *args, **kwargs):
        """Call the method through the template belonging to ``instance``."""
        key = id(instance)
        if key not in self._wrappers:
            name = self._create_name(
                type(instance).__module__,
                type(instance).__name__,
                instance.name if hasattr(instance, 'name') else id(instance),
                self._callable.__name__)
            template = tf.make_template(
                name, self._callable, create_scope_now_=True)
            self._wrappers[key] = (instance, template)
        template = self._wrappers[key][1]
        return template(instance, *args, **kwargs)

    def _function_wrapper(self, *args, **kwargs):
        """Call the function through its single, lazily created template."""
        if self._wrapper is None:
            name = self._create_name(
                self._callable.__module__,
                self._callable.__name__)
            self._wrapper = tf.make_template(
                name, self._callable, create_scope_now_=True)
        return self._wrapper(*args, **kwargs)

    def _create_name(self, *words):
        """Join ``words`` with ``_`` after stringifying them and removing
        their own underscores (so ``'__main__'`` becomes ``'main'``)."""
        words = [str(word) for word in words]
        words = [word.replace('_', '') for word in words]
        return '_'.join(words)
class Model(object):
    """Minimal example model whose dense layer is shared per instance."""

    def __init__(self, name):
        # share_variables reads this attribute when building the scope name.
        self.name = name

    @share_variables
    def method(self, data):
        """Project ``data`` to 100 units with a dense layer."""
        hidden = tf.layers.dense(data, 100)
        return hidden
@share_variables
def function(data):
    """Project ``data`` to 50 units with a dense layer."""
    projected = tf.layers.dense(data, 50)
    return projected
# Placeholder batch of 50-dimensional inputs.
data = tf.placeholder(tf.float32, [None, 50])

# Two calls to the decorated function reuse a single set of variables.
for _ in range(2):
    function(data)

# Each Model instance gets its own variables, reused across its calls.
foo = Model('foo')
for _ in range(2):
    foo.method(data)
bar = Model('bar')
bar.method(data)

# List every trainable variable created above.
for var in tf.trainable_variables():
    print(var.name)
# Output:
# main_function/dense/kernel:0
# main_function/dense/bias:0
# main_Model_foo_method/dense/kernel:0
# main_Model_foo_method/dense/bias:0
# main_Model_bar_method/dense/kernel:0
# main_Model_bar_method/dense/bias:0
@danijar
Copy link
Author

danijar commented Jun 26, 2018

I've updated the code to include a fix and to use the `name` attribute of model instances if available, falling back to `id(instance)` otherwise.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment