#!/usr/bin/env python3
# Demo of how (not) to use numpy.random
# GitHub Gist jongbinjung/9de032bfcc0e2f8d7a01b9791e214c32, created June 28, 2024 21:58
"""Demonstration of why you shouldn't set numpy.random.seed, and what you should do instead"""
import numpy as np
# When setting a seed, you usually mean "make things deterministic".
# However, because the global numpy seed is _shared_, it cannot reliably guarantee deterministic outcomes.
# For example, assume we want to use numpy's choice() to pick five random values from an array (with replacement,
# which is the default):
some_array = np.linspace(1, 10, num=10)
print(np.random.choice(a=some_array, size=5))
# Note that the above will return different random values every time it is run, by design.
# If we set a random seed, it will return the same values each time:
## Take 1:
np.random.seed(666)
print(np.random.choice(a=some_array, size=5))
# => [ 3.  7. 10.  5.  4.]
## Take 2:
np.random.seed(666)
print(np.random.choice(a=some_array, size=5))
# => [ 3.  7. 10.  5.  4.]
# Same values, because the seed is the same
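# Why this works, and why it is fragile: np.random.seed() re-seeds one hidden, module-level RandomState instance that
# every np.random.* function shares. A minimal check that uses only the legacy public API: seeding the global state
# and building a separate RandomState with the same seed give identical draws. (RandomState is numpy's legacy class
# and appears here only for comparison.)
import numpy as np

values = np.linspace(1, 10, num=10)

np.random.seed(666)
from_global = np.random.choice(a=values, size=5)  # draws from the shared, module-level instance

private_legacy = np.random.RandomState(666)  # a separate legacy instance that no other code can reach
from_private = private_legacy.choice(a=values, size=5)

print(np.array_equal(from_global, from_private))  # True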
# What not to do, and why ---------------------------------------------------------------------------------------------
# A typical data pipeline involves many different components, each of which might have its own mechanism for setting
# a seed. Let's assume we have some function from an open source library, called run_me(), which just sets its "own"
# seed, using the same method we used above:
def run_me():
    np.random.seed(123456)
# Seems harmless at first, but now the result of our own np.random.choice() call depends on when we call run_me(),
# even if we set our "own" seed!
# Take 1:
# This does what we expect
np.random.seed(666)
print(np.random.choice(a=some_array, size=5))  # [ 3.  7. 10.  5.  4.]
run_me()
# Take 2:
# Despite setting the same random seed, this returns a _different_ set of values, because run_me() silently replaces
# our seed with its own.
np.random.seed(666)
run_me()
print(np.random.choice(a=some_array, size=5))  # [2. 3. 2. 9. 1.]
# Take 3:
# Even worse, in this order our explicit seed has no effect at all (i.e., even if we were trying to get a _new_ set of
# random numbers, we'd still get the same ones, because our seed is being overwritten):
np.random.seed(777)  # Different seed
run_me()
print(np.random.choice(a=some_array, size=5))  # STILL [2. 3. 2. 9. 1.]
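# To check whether a call clobbers the shared state the way run_me() does, one option is to compare the hidden global
# state before and after the call. A minimal sketch that assumes only np.random.get_state(); touches_global_rng() and
# the three example functions are hypothetical. It also shows that merely *drawing* from the global RNG, without
# re-seeding it, shifts the shared stream for everyone. (A call that happens to leave exactly the same state is missed.)
import numpy as np


def touches_global_rng(func, *args, **kwargs):
    """Hypothetical helper: True if calling func changed numpy's legacy global RNG state."""
    before = np.random.get_state()
    func(*args, **kwargs)
    after = np.random.get_state()
    # Legacy state tuple: (name, key array, position, has_gauss, cached_gaussian)
    same = before[0] == after[0] and np.array_equal(before[1], after[1]) and before[2:] == after[2:]
    return not same


def reseeds_global():  # hypothetical stand-in for run_me()
    np.random.seed(123456)


def draws_from_global():  # hypothetical: never seeds, but still consumes the shared stream
    np.random.random()


def uses_own_generator():  # hypothetical: draws only from a private Generator
    np.random.default_rng(123456).random()


np.random.seed(777)  # start from a state that differs from what reseeds_global() sets
print(touches_global_rng(reseeds_global))  # True
print(touches_global_rng(draws_from_global))  # True
print(touches_global_rng(uses_own_generator))  # False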
# This is obvious in this trivial example, but remember: a REAL project will involve dozens of different packages
# that we won't audit and can't control.
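# If we must call an unaudited function like run_me(), one defensive stopgap (not a substitute for avoiding the global
# RNG altogether) is to snapshot the legacy global state and restore it afterwards. A minimal sketch that assumes only
# np.random.get_state() and np.random.set_state(). preserved_global_rng_state() is a hypothetical helper name, and it
# only protects the calls we know about and wrap.
import contextlib

import numpy as np


@contextlib.contextmanager
def preserved_global_rng_state():
    """Hypothetical helper: snapshot numpy's legacy global RNG state and restore it on exit."""
    saved = np.random.get_state()
    try:
        yield
    finally:
        np.random.set_state(saved)  # undo any re-seeding or draws made inside the block


values = np.linspace(1, 10, num=10)

np.random.seed(666)
baseline = np.random.choice(a=values, size=5)  # no interference

np.random.seed(666)
with preserved_global_rng_state():
    np.random.seed(123456)  # stands in for a third-party call such as run_me()
guarded = np.random.choice(a=values, size=5)

print(np.array_equal(baseline, guarded))  # True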
# The solution --------------------------------------------------------------------------------------------------------
# The ideal solution is for EVERYONE to use their own sandbox of random number generators. We can't stop other
# packages from using the global RNG, but we can avoid many problems if WE don't use it, and hopefully everyone in the
# community does the same, so that eventually nobody uses the global RNG 😅
# numpy supports this via a Generator object, which, in practice, you can think of as a private _instance_ of
# np.random.
# You use np.random.default_rng() to create one, and then call it in place of np.random. Most sampling methods keep
# their names (e.g., choice), although a few were renamed (e.g., randint became integers):
my_rng = np.random.default_rng(seed=666)
print(my_rng.choice(a=some_array, size=5))  # [9. 8. 3. 6. 1.]
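# A short sketch of common legacy calls and their Generator counterparts, reusing seed 666 only for illustration. The
# pairings in the comments are the usual ones from the numpy docs; the printed values are not meaningful in themselves.
import numpy as np

rng = np.random.default_rng(seed=666)

ints = rng.integers(0, 10, size=5)  # legacy np.random.randint(0, 10, size=5); high (10) is exclusive
floats = rng.random(size=3)  # legacy np.random.rand(3) / np.random.random_sample(3); values in [0, 1)
normals = rng.normal(size=3)  # same name as legacy np.random.normal
shuffled = rng.permutation(10)  # same name as legacy np.random.permutation
# choice() samples with replacement by default, which is why results like [2. 3. 2. 9. 1.] can repeat a value;
# replace=False forces distinct picks.
picked = rng.choice(np.linspace(1, 10, num=10), size=5, replace=False)

print(ints, floats, normals, shuffled, picked, sep="\n")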
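# Different from what we got before (by default a Generator uses a different underlying algorithm, the PCG64 bit
# generator, rather than the legacy MT19937), but still _consistent_: re-creating the generator with the same seed
# gives the same values.
my_rng = np.random.default_rng(seed=666)
print(my_rng.choice(a=some_array, size=5))  # [9. 8. 3. 6. 1.]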
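# When several components each need their own sandbox, one possible pattern (assuming numpy >= 1.17, which introduced
# SeedSequence) is to derive independent child seeds from one root seed. Each component then gets a reproducible
# stream that no other component can disturb. The component names below are hypothetical.
import numpy as np

# spawn(2) derives two independent child SeedSequences from the root seed 666
loader_rng, model_rng = (np.random.default_rng(child) for child in np.random.SeedSequence(666).spawn(2))

# Re-deriving from the same root seed reproduces each child's stream exactly
loader_again, model_again = (np.random.default_rng(child) for child in np.random.SeedSequence(666).spawn(2))

loader_match = np.array_equal(loader_rng.random(3), loader_again.random(3))
model_match = np.array_equal(model_rng.random(3), model_again.random(3))
print(loader_match, model_match)  # True True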
# And external functions that set the global seed will not affect our results.
my_rng = np.random.default_rng(seed=666)
run_me()
print(my_rng.choice(a=some_array, size=5))  # STILL gets [9. 8. 3. 6. 1.]
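# The same idea from the library author's side: a hypothetical, better-behaved rewrite of run_me() that never touches
# the global RNG and instead accepts an optional `rng` argument. np.random.default_rng() returns an existing
# Generator unchanged and turns None or an int seed into a new Generator, so callers can pass whichever they have.
# run_me_politely() and its parameters are illustrative names only.
import numpy as np


def run_me_politely(size=5, rng=None):
    """Hypothetical rewrite of run_me(): draws only from a caller-supplied or freshly created Generator."""
    rng = np.random.default_rng(rng)  # Generator -> same object; int or None -> new Generator
    return rng.choice(a=np.linspace(1, 10, num=10), size=size)


global_key_before = np.random.get_state()[1]  # snapshot of the legacy global state's key array

caller_rng = np.random.default_rng(seed=666)
from_generator = run_me_politely(rng=caller_rng)  # advances the caller's own generator
from_int_seed = run_me_politely(rng=666)  # an int seed builds an equivalent fresh generator

print(np.array_equal(from_generator, from_int_seed))  # True
print(np.array_equal(global_key_before, np.random.get_state()[1]))  # True: global RNG untouched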
# For official numpy documentation, see:
# https://numpy.org/doc/stable/reference/random/index.html#random-sampling-numpy-random