Created March 27, 2012 19:47
-
-
Save dwf/2219652 to your computer and use it in GitHub Desktop.
Serialization-friendly cached Theano-functions-as-instance-attributes.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import theano | |
| import functools | |
def cached_theano_function(fn):
    """Turn a method taking only ``self`` into a lazily built, per-instance cached property.

    On the first attribute access for a given instance, ``fn(self)`` is called
    (e.g. to compile a Theano function). Its result is stored in
    ``self._function_cache`` under the method's name. Later accesses on the same
    instance return the stored object without calling ``fn`` again.

    All cached values live in one ``_function_cache`` dict. A class's
    ``__getstate__`` can therefore exclude every cached function from pickling
    by dropping that single key. The functions are then rebuilt on first access
    after unpickling.

    :param fn: method taking only ``self`` and returning the object to cache.
    :returns: a read-only ``property`` that builds and caches ``fn(self)``.
    """
    # ``__name__`` works on Python 2 and 3; ``func_name`` is Python 2 only.
    # Computed once here rather than on every attribute access.
    name = fn.__name__

    @functools.wraps(fn)
    def wrapped(self):
        # Create the per-instance cache on first use.
        try:
            cache = self._function_cache
        except AttributeError:
            cache = self._function_cache = {}
        if name not in cache:
            cache[name] = fn(self)
        return cache[name]

    return property(wrapped)
class MyModel(object):
    """Example model whose compiled Theano functions are cached per instance.

    Compiled functions are kept in ``_function_cache`` by the
    ``cached_theano_function`` decorator. They are excluded from pickling
    by ``__getstate__``.
    """

    @cached_theano_function
    def mul(self):
        """Compile and return a Theano function computing ``dot(X, Y)``.

        ``X`` is a symbolic matrix and ``Y`` a symbolic vector, so the compiled
        function is called with a 2-D array and a 1-D array.
        """
        # Runs on the first access per instance. Because the cache is not
        # pickled, it also runs again on an instance loaded from a pickle.
        print("Compiling mul! You will only see this once.")
        X = theano.tensor.matrix()
        Y = theano.tensor.vector()
        Z = theano.tensor.dot(X, Y)
        return theano.function([X, Y], Z)

    def __getstate__(self):
        """Return the instance state for pickling, without cached functions.

        The ``_function_cache`` entry (if present) is dropped so compiled
        functions are not serialized. The rest of ``__dict__`` is pickled
        unchanged.
        """
        # TODO: move into a shared base class so every model gets this behavior.
        state = self.__dict__.copy()
        state.pop('_function_cache', None)
        return state
if __name__ == "__main__":
    # Demo: compile once, pickle the model, load it back, and recompile on
    # first use of the loaded copy.
    try:
        import cPickle as pickle  # Python 2: C-accelerated pickle
    except ImportError:
        import pickle  # Python 3: pickle already uses its C implementation
    from numpy.random import randn

    a = MyModel()
    # Only the first call compiles; later calls reuse the cached function.
    for _ in range(4):
        a.mul(randn(5, 3), randn(3))

    with open('test.pkl', 'wb') as f:
        pickle.dump(a, f)
    print("saved, now loading back...")
    # Read in binary mode to match the binary write above. Text mode fails on
    # Python 3 and can corrupt binary data on Windows.
    with open('test.pkl', 'rb') as f:
        b = pickle.load(f)
    # MyModel.__getstate__ excluded _function_cache, so mul recompiles here.
    for _ in range(4):
        b.mul(randn(5, 3), randn(3))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.