Skip to content

Instantly share code, notes, and snippets.

@swo
Last active November 13, 2024 16:16
Show Gist options
  • Save swo/7a8b75d6f996933a264ab9d1f445b9f6 to your computer and use it in GitHub Desktop.
Save swo/7a8b75d6f996933a264ab9d1f445b9f6 to your computer and use it in GitHub Desktop.
Why use `np.random.Generator` objects
import numpy as np
class MyRNG:
    """Private random stream built on top of numpy's legacy global RNG.

    Each instance keeps its own snapshot of the ``np.random`` state. A draw
    temporarily swaps that snapshot into the global RNG, draws, saves the
    advanced state back onto the instance, and then puts back whatever global
    state was there before. Code using ``np.random`` directly therefore sees
    no change in the global stream.

    NOTE: the global state is genuinely replaced for the duration of each
    call. Anything else touching ``np.random`` concurrently (e.g. from another
    thread) during a draw would interleave with this stream.

    Args:
        seed: Seed passed to ``np.random.seed`` to initialise this stream.
    """

    def __init__(self, seed):
        # Remember the caller's global state so seeding here does not leak out.
        global_state = np.random.get_state()
        try:
            np.random.seed(seed)
            # Snapshot of the freshly seeded RNG: this instance's private stream.
            self.state = np.random.get_state()
        finally:
            # Restore the global state even if seeding raised (e.g. a bad seed).
            np.random.set_state(global_state)

    def normal(self, *args, **kwargs):
        """Draw via ``np.random.normal`` using this instance's private stream.

        All arguments are forwarded unchanged to ``np.random.normal``. The
        global numpy RNG state is the same after the call as before it, also
        when ``np.random.normal`` raises.

        Returns:
            Whatever ``np.random.normal(*args, **kwargs)`` returns.
        """
        global_state = np.random.get_state()
        np.random.set_state(self.state)
        try:
            value = np.random.normal(*args, **kwargs)
            # Only advance this instance's stream once the draw has succeeded.
            self.state = np.random.get_state()
        finally:
            # Always hand the global RNG back to its previous owner.
            np.random.set_state(global_state)
        return value
class GlobalSeedSimulation:
    """Toy simulation whose randomness comes from numpy's shared global RNG.

    This is the pattern being argued against. ``run`` reseeds the *global*
    generator, and ``step`` draws from it. Any draw made before ``run``
    (including a call to ``step``) therefore lands in ``results`` unseeded.

    Args:
        N: Number of values the simulation collects.
        seed: Seed handed to ``np.random.seed`` at the start of ``run``.
    """

    def __init__(self, N: int = 5, seed: int = 1234):
        self.N = N
        self.seed = seed
        self.results = []

    def step(self):
        """Append one standard-normal draw taken from the global numpy RNG."""
        assert self.N > len(self.results)
        # Deliberately uses the module-level (global) numpy RNG.
        self.results.append(np.random.normal())

    def run(self):
        """Reseed the global RNG with ``self.seed``, then step until N values exist."""
        np.random.seed(self.seed)
        while self.N > len(self.results):
            self.step()
class GeneratorSimulation:
    """Toy simulation that owns a private random number generator.

    Args:
        rng_fun: Factory called as ``rng_fun(seed)``; the object it returns
            must provide a ``normal()`` method.
        N: Number of values the simulation collects.
        seed: Seed passed to ``rng_fun``.
    """

    def __init__(self, rng_fun, N: int = 5, seed: int = 1234):
        self.N = N
        self.results = []
        # The simulation-specific generator; never shared with other objects.
        self.rng = rng_fun(seed)

    def step(self):
        """Append the next value drawn from this simulation's own generator."""
        assert self.N > len(self.results)
        draw = self.rng.normal()
        self.results.append(draw)

    def run(self):
        """Keep stepping until ``N`` results exist; the generator is not reseeded."""
        while self.N > len(self.results):
            self.step()
# Create two instances of each of the three simulation setups:
# global-seed, Generator-based (np.random.default_rng), and MyRNG-based.
seed_sim1 = GlobalSeedSimulation()
seed_sim2 = GlobalSeedSimulation()
gen_sim1 = GeneratorSimulation(np.random.default_rng)
gen_sim2 = GeneratorSimulation(np.random.default_rng)
my_sim1 = GeneratorSimulation(MyRNG)
my_sim2 = GeneratorSimulation(MyRNG)

# Imagine that we took a step for only one of each simulation type...
seed_sim1.step()
gen_sim1.step()
my_sim1.step()

# ...but then finished running all of the simulations.
seed_sim1.run()
seed_sim2.run()
gen_sim1.run()
gen_sim2.run()
my_sim1.run()
my_sim2.run()

# The two generator sims give the same values as each other (starting with
# -1.603), and the same values on every run of the program, even though only
# gen_sim1 took an early step.
print("gen1", gen_sim1.results)
print("gen2", gen_sim2.results)

# The same is true for the two MyRNG-backed sims.
print("my1", my_sim1.results)
print("my2", my_sim2.results)

# The global-seed sims, however, disagree with each other. seed_sim1's first
# value was drawn from the global RNG before anything seeded it, so that value
# changes on every run. Its remaining values restart the seed-1234 stream when
# run() reseeds, which leaves seed_sim1 shifted by one relative to seed_sim2.
print("seed1", seed_sim1.results)
print("seed2", seed_sim2.results)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment