Skip to content

Instantly share code, notes, and snippets.

@dwf
Created March 27, 2012 19:47
Show Gist options
  • Save dwf/2219652 to your computer and use it in GitHub Desktop.
Save dwf/2219652 to your computer and use it in GitHub Desktop.
Serialization-friendly caching of compiled Theano functions as instance attributes.
import theano
import functools
def cached_theano_function(fn):
    """Turn a method taking only ``self`` into a lazily built, cached property.

    The first time the attribute is read on an instance, ``fn(self)`` is
    called (in this file, to compile a Theano function). The result is stored
    in ``self._function_cache`` under the method's name. Later reads on the
    same instance return the stored object without calling ``fn`` again.

    All results live in one ``_function_cache`` dict. A class can therefore
    drop every cached result at once when pickling, by removing that key in
    ``__getstate__``. The results are then rebuilt on first access after
    unpickling, because the cache attribute is missing again.

    Args:
        fn: a method whose only parameter is ``self``.

    Returns:
        A read-only ``property`` wrapping ``fn``.
    """
    # ``__name__`` exists on functions in both Python 2 and 3, whereas
    # ``func_name`` is Python 2 only. Read it once, at decoration time.
    name = fn.__name__

    @functools.wraps(fn)
    def wrapped(self):
        try:
            cache = self._function_cache
        except AttributeError:
            # No cache on this instance yet: either nothing has been cached
            # so far, or __getstate__ dropped it before pickling.
            cache = self._function_cache = {}
        if name not in cache:
            cache[name] = fn(self)
        return cache[name]

    return property(wrapped)
class MyModel(object):
    """Example model whose compiled Theano functions are cached per instance.

    ``mul`` is built lazily through ``cached_theano_function`` and stored in
    ``self._function_cache``. ``__getstate__`` leaves that cache out of the
    pickled state, so after unpickling ``mul`` is rebuilt on first access.
    """

    @cached_theano_function
    def mul(self):
        """Compiled Theano function computing ``dot(X, Y)``.

        ``X`` is a symbolic matrix and ``Y`` a symbolic vector, so the compiled
        function returns their matrix-vector product. It is compiled on first
        access per instance, and again after unpickling.
        """
        # A single-argument print(...) behaves the same on Python 2 and 3.
        print("Compiling mul! You will only see this once.")
        X = theano.tensor.matrix()
        Y = theano.tensor.vector()
        Z = theano.tensor.dot(X, Y)
        return theano.function([X, Y], Z)

    def __getstate__(self):
        """Return a copy of the instance dict for pickling, minus the cache.

        The compiled functions in ``_function_cache`` are not serialized.
        The cached properties recompile them on first access after unpickling.
        Consider moving this method into a shared base class.
        """
        state = self.__dict__.copy()
        state.pop('_function_cache', None)
        return state
if __name__ == "__main__":
    # Demo: call a cached compiled function repeatedly, pickle the model,
    # load it back, and call the function again on the restored copy.
    try:
        import cPickle as pickle  # Python 2: C-accelerated pickle
    except ImportError:
        import pickle  # Python 3: no cPickle module; use the standard one

    from numpy.random import randn

    a = MyModel()
    # Only the first access to `a.mul` compiles; later ones hit the cache.
    for _ in range(4):
        a.mul(randn(5, 3), randn(3))

    with open('test.pkl', 'wb') as f:
        pickle.dump(a, f)
    print("saved, now loading back...")
    # Pickle data is binary. Read it back in binary mode to match the 'wb'
    # used above; text mode ('r') fails on Python 3.
    with open('test.pkl', 'rb') as f:
        b = pickle.load(f)
    # __getstate__ dropped `_function_cache`, so the first call on `b`
    # recompiles `mul` and the remaining calls reuse it.
    for _ in range(4):
        b.mul(randn(5, 3), randn(3))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment