Last active
November 13, 2024 16:16
-
-
Save swo/7a8b75d6f996933a264ab9d1f445b9f6 to your computer and use it in GitHub Desktop.
Why use `np.random.Generator` objects
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
class MyRNG:
    """Private random stream built on top of NumPy's *global* legacy RNG.

    A hand-rolled imitation of ``np.random.Generator``. Each instance keeps its
    own copy of the legacy RNG state and swaps it into ``np.random`` only for
    the duration of a draw. Drawing from a ``MyRNG`` therefore neither depends
    on nor advances the global ``np.random`` stream.

    The swap goes through module-global state. Any other code using
    ``np.random`` at the same time (e.g. another thread) could observe or
    advance this instance's stream mid-call.
    """

    def __init__(self, seed):
        """Capture the state that ``np.random.seed(seed)`` would produce.

        Args:
            seed: Any seed accepted by ``np.random.seed``.
        """
        # Save the caller's global state so that seeding here leaves no trace.
        global_state = np.random.get_state()
        try:
            np.random.seed(seed)
            self.state = np.random.get_state()
        finally:
            # Restore unconditionally, even if ``seed`` is rejected.
            np.random.set_state(global_state)

    def normal(self, *args, **kwargs):
        """Draw via ``np.random.normal`` using this instance's private stream.

        Accepts the same arguments as ``np.random.normal`` and returns whatever
        it returns.

        Raises:
            Whatever ``np.random.normal`` raises (e.g. ``ValueError`` for a
            negative ``scale``). The global RNG state is restored either way,
            and this instance's stream is left unadvanced.
        """
        global_state = np.random.get_state()
        try:
            # Temporarily make our stream the global one for this call.
            np.random.set_state(self.state)
            value = np.random.normal(*args, **kwargs)
            # Only advance our own stream after a successful draw.
            self.state = np.random.get_state()
        finally:
            # Without ``finally``, an exception above would leave np.random
            # running on this instance's stream instead of the caller's.
            np.random.set_state(global_state)
        return value
class GlobalSeedSimulation:
    """Toy simulation whose draws come from NumPy's module-level RNG.

    It holds no random state of its own:

    * ``step`` calls ``np.random.normal`` directly.
    * ``run`` reseeds the global RNG with ``self.seed`` before finishing the
      remaining steps.

    Any other use of ``np.random``, or a ``step`` taken before ``run``,
    therefore affects its results.

    Attributes:
        N: Total number of values the simulation should produce.
        results: Values produced so far, in draw order.
        seed: Seed passed to ``np.random.seed`` at the start of ``run``.
    """

    def __init__(self, N: int = 5, seed: int = 1234):
        self.N = N
        self.results = []
        self.seed = seed

    def step(self):
        """Append one standard-normal draw taken from the global RNG."""
        # Internal sanity check: never produce more than N values.
        assert len(self.results) < self.N
        self.results.append(np.random.normal())

    def run(self):
        """Reseed the global RNG, then step until ``N`` values exist."""
        np.random.seed(self.seed)
        while self.N > len(self.results):
            self.step()
class GeneratorSimulation:
    """Toy simulation that draws from an RNG object it holds itself.

    Args:
        rng_fun: Factory called once as ``rng_fun(seed)``. The object it
            returns must provide a no-argument ``normal()`` method; examples
            are ``np.random.default_rng`` and ``MyRNG``.
        N: Total number of values to produce.
        seed: Seed handed to ``rng_fun``.
    """

    def __init__(self, rng_fun, N: int = 5, seed: int = 1234):
        self.N = N
        self.results = []
        # Seeded exactly once, here. Assuming rng_fun returns a fresh object,
        # draws made elsewhere cannot advance this stream.
        self.rng = rng_fun(seed)

    def step(self):
        """Append the next draw from this simulation's own RNG."""
        assert len(self.results) < self.N
        self.results.append(self.rng.normal())

    def run(self):
        """Step until ``N`` values exist; no reseeding happens here."""
        while self.N > len(self.results):
            self.step()
# Demo: build two instances of each of the three simulation setups, all using
# the default N (5) and seed (1234):
#   - the global legacy seed,
#   - a per-simulation np.random.Generator,
#   - the hand-rolled MyRNG.
seed_sim1 = GlobalSeedSimulation()
seed_sim2 = GlobalSeedSimulation()
gen_sim1 = GeneratorSimulation(np.random.default_rng)
gen_sim2 = GeneratorSimulation(np.random.default_rng)
my_sim1 = GeneratorSimulation(MyRNG)
my_sim2 = GeneratorSimulation(MyRNG)

# Imagine that we took a single step for only the first simulation of each
# type. GlobalSeedSimulation seeds only inside run(), so seed_sim1's draw here
# comes from the not-yet-seeded global RNG.
seed_sim1.step()
gen_sim1.step()
my_sim1.step()

# ...but then we finished running every simulation.
seed_sim1.run()
seed_sim2.run()
gen_sim1.run()
gen_sim2.run()
my_sim1.run()
my_sim2.run()

# The two Generator-backed sims give identical values (starting with -1.603),
# both to each other and across runs of the program. Each owns a stream that
# was seeded once, at construction.
print("gen1", gen_sim1.results)
print("gen2", gen_sim2.results)

# The same holds for the two MyRNG-backed sims.
print("my1", my_sim1.results)
print("my2", my_sim2.results)

# The global-seed sims, however, disagree with each other:
#   - seed_sim1's first value was drawn before any seeding, so (with NumPy's
#     default unseeded start) it changes from one program run to the next.
#   - run() then reseeded the global RNG, restarting the stream for
#     seed_sim1's remaining values.
print("seed1", seed_sim1.results)
print("seed2", seed_sim2.results)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment