Created January 27, 2010 09:35
-
-
Save dwf/287695 to your computer and use it in GitHub Desktop.
Object-specific NumPy random number generator states
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
A demonstration of how to deal with object-specific random number
generator states. One application is when you want two objects to
work with the same pseudorandom sequence but don't want to generate
that sequence in advance, or when you want to replicate results after
serializing and de-serializing an object.
By David Warde-Farley, dwf at cs.toronto.edu, January 2010.
Released under BSD license.
"""
from functools import wraps | |
import numpy as np | |
def uses_rng(method):
    """
    Decorator that lets a method use NumPy's global random number
    generator with an object-specific generator state.

    Before the wrapped method runs, the current global PRNG state is saved
    and replaced with ``self._rngstate``. Afterwards -- even if the method
    raises -- the advanced generator state is stored back into
    ``self._rngstate`` and the saved global state is reinstated, so calls
    on the object do not disturb the global random stream.

    The wrapped method receives ``self`` and any other arguments unchanged.

    Raises AttributeError if the instance has no ``_rngstate`` attribute.
    """
    @wraps(method)
    def decorated(self, *args, **kwargs):
        if not hasattr(self, '_rngstate'):
            raise AttributeError("object doesn't have saved random state")
        oldstate = np.random.get_state()
        # Install the object's state *before* entering the try block: if
        # this call fails, the finally clause below must not overwrite
        # self._rngstate with the unrelated global state.
        np.random.set_state(self._rngstate)
        try:
            # Forward self so the wrapped function behaves as a real method.
            return method(self, *args, **kwargs)
        finally:
            # Clean up after ourselves _even in case of an error_: keep the
            # object's advanced state and restore the caller's global state.
            self._rngstate = np.random.get_state()
            np.random.set_state(oldstate)
    return decorated


class StatefulRandom(object):
    """
    A superclass for objects that use their own RNG state.

    Pass ``rngstate=...`` (a state accepted by ``np.random.set_state``,
    e.g. a tuple returned by ``np.random.get_state()``) to start from a
    specific state. Otherwise the current global state is captured. The
    ``rngstate`` keyword is consumed here and is not forwarded to the next
    ``__init__`` in the MRO.
    """
    def __init__(self, *args, **kwargs):
        if 'rngstate' in kwargs:
            # pop() both reads the state and removes the keyword, so the
            # cooperative super().__init__ call below does not receive it.
            self._rngstate = kwargs.pop('rngstate')
        else:
            self._rngstate = np.random.get_state()
        super(StatefulRandom, self).__init__(*args, **kwargs)


if __name__ == "__main__":
    class Foo(StatefulRandom):
        """A class that uses a random number generator."""

        @uses_rng
        def bar(self):
            """A method that uses a random number generator."""
            print(np.random.randn())

    a = Foo()
    for _ in range(4):
        a.bar()
    print("-------")
    # Created after several queries to the first object's RNG. Both objects
    # initialize from the global state, and because uses_rng restores the
    # global state after every call, it is still the same as it was when
    # the first object was created.
    b = Foo()
    # They started with the same RNG state, so they should generate the
    # same numbers: the same four numbers are printed before/after the line.
    for _ in range(4):
        b.bar()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.