Skip to content

Instantly share code, notes, and snippets.

@dwf
Created January 27, 2010 09:35
Show Gist options
  • Save dwf/287695 to your computer and use it in GitHub Desktop.
Save dwf/287695 to your computer and use it in GitHub Desktop.
Object-specific NumPy random number generator states
"""
A demonstration of how to deal with object-specific random number
generator states. One application is if you want two objects to
work with the same pseudorandom sequence but don't particularly
want to generate them in advance, or to replicate results after
serializing and de-serializing an object.
By David Warde-Farley, dwf at cs.toronto.edu, January 2010.
Released under BSD license.
"""
from functools import wraps
import numpy as np
def uses_rng(method):
    """
    Decorator that allows methods to interact with the NumPy random
    number generator using an object-specific random number generator
    state.

    Before running the wrapped method, the current global NumPy PRNG state
    is saved and replaced with ``self._rngstate``. Afterwards, even if the
    method raises, two things happen: the advanced PRNG state is stored
    back into ``self._rngstate``, and the saved global state is
    reinstated. Calls on one object therefore continue that object's own
    sequence and leave the global sequence where it was.

    Raises AttributeError (when the decorated method is called) if the
    instance has no ``_rngstate`` attribute.
    """
    @wraps(method)
    def decorated(self, *args, **kwargs):
        if not hasattr(self, '_rngstate'):
            raise AttributeError("object doesn't have saved random state")
        oldstate = np.random.get_state()
        # Installed outside the try: if the saved state is rejected, nothing
        # has been swapped yet, so there is nothing to clean up (and the
        # object's state must not be overwritten with the global one).
        np.random.set_state(self._rngstate)
        try:
            # Pass self through; the wrapped function is an ordinary method.
            return method(self, *args, **kwargs)
        finally:
            # Clean up after ourselves _even in case of an error_: keep the
            # object's advanced state, then restore the caller's global state.
            self._rngstate = np.random.get_state()
            np.random.set_state(oldstate)
    return decorated


class StatefulRandom(object):
    """
    A superclass for objects that use their own RNG state.

    Keyword argument ``rngstate``: a state as returned by
    ``np.random.get_state()`` to start from. For example, pass another
    object's ``_rngstate`` to replay its sequence. If it is omitted or
    None, a snapshot of the current global NumPy RNG state is taken. The
    keyword is consumed here and is not forwarded to the next
    ``__init__`` in the MRO; all other arguments are passed through.
    """
    def __init__(self, *args, **kwargs):
        # pop() reads and removes the keyword in one step.
        rngstate = kwargs.pop('rngstate', None)
        if rngstate is None:
            rngstate = np.random.get_state()
        self._rngstate = rngstate
        super(StatefulRandom, self).__init__(*args, **kwargs)


if __name__ == "__main__":
    class Foo(StatefulRandom):
        """A class that uses a random number generator."""
        @uses_rng
        def bar(self):
            """A method that uses a random number generator."""
            print(np.random.randn())

    a = Foo()
    for _ in range(4):
        a.bar()
    print("-------")
    # Created after several queries to the first object's RNG. Note that they
    # are both using the global state to initialize and at this point the
    # global state should appear the same as it did when the first object
    # was created.
    b = Foo()
    # They started with the same RNG state, so they should generate the same
    # numbers. The same four numbers should be printed before/after the line.
    for _ in range(4):
        b.bar()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment