Created
June 1, 2022 17:40
-
-
Save rgommers/997ab5a287e9a19c345771d2f5712574 to your computer and use it in GitHub Desktop.
Comparing JAX and NumPy APIs for random number generation - serial and parallel
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Implement `jax.random` APIs with NumPy, and `numpy.random` APIs with JAX. | |
The purpose of this is to be able to compare APIs more easily, and clarify | |
where they are and aren't similar. | |
""" | |
import secrets | |
import multiprocessing | |
import numpy as np | |
import jax | |
# Flip to True to make every example below reproducible.
USE_FIXED_SEED = False

# `jax.random.PRNGKey` does not accept None to pick entropy automatically, so
# a random seed is drawn here by hand. It is limited to 32 bits because that
# is the width shown to work in the IPython session at the bottom of this
# file (64- and 128-bit seeds raised errors there).
seed = 38968222334307 if USE_FIXED_SEED else secrets.randbits(32)
# NumPy serial: a single Generator is created once and then drawn from
# repeatedly; it advances its own internal state between calls.
rng = np.random.default_rng(seed)
vals = rng.uniform(size=3)
val = rng.uniform(size=1)

# NumPy parallel: spawn one child SeedSequence per worker from a common
# parent and build a separate Generator from each child.
sseq = np.random.SeedSequence(seed)
child_seeds = sseq.spawn(4)
rngs = list(map(np.random.default_rng, child_seeds))
def use_rngs_numpy(rng):
    """Draw a size-3 and then a size-1 uniform sample from `rng` and print both.

    Parameters
    ----------
    rng : object with a ``uniform(size=...)`` method
        In this file, a `numpy.random.Generator`. It is advanced by the two
        draws.
    """
    triple = rng.uniform(size=3)
    single = rng.uniform(size=1)
    print(triple, single)
def main_numpy():
    """Run `use_rngs_numpy` on each module-level generator in a 4-process pool."""
    with multiprocessing.Pool(4) as workers:
        # Each element of `rngs` is sent to a worker as the function argument.
        workers.map(use_rngs_numpy, rngs)
# JAX serial (also auto-parallelizes fine by design)
key = jax.random.PRNGKey(seed)
# The first split could be skipped, but always splitting before drawing is
# probably the best practice.
key, subkey = jax.random.split(key, num=2)
vals = jax.random.uniform(subkey, (3,))
# Don't forget to split again before the next draw!
key, subkey = jax.random.split(key, num=2)
val = jax.random.uniform(subkey, (1,))
# JAX parallel with multiprocessing | |
def use_rngs_jax(key):
    """Draw a shape-(3,) and then a shape-(1,) uniform sample and print both.

    Parameters
    ----------
    key : JAX PRNG key
        Split before each draw; the first half of every split is carried
        forward and the second half is used for sampling.
    """
    carry, draw_key = jax.random.split(key)
    vals = jax.random.uniform(draw_key, shape=(3,))
    carry, draw_key = jax.random.split(carry)
    val = jax.random.uniform(draw_key, shape=(1,))
    print(vals, val)
def main_jax():
    """Run `use_rngs_jax` on four sub-keys in a 4-process pool."""
    root = jax.random.PRNGKey(seed)
    # Gotcha: asking for 5 keys yields only 4 sub-keys for the workers,
    # because the first key is kept back as the carried-forward key.
    root, *worker_keys = jax.random.split(root, num=5)
    with multiprocessing.Pool(4) as workers:
        workers.map(use_rngs_jax, worker_keys)
# An API matching JAX on top of `numpy.random` | |
############################################## | |
def PRNGKey(seed):
    """
    Build a JAX-style key from an integer seed.

    `seed` is forwarded unchanged to `np.random.SeedSequence`. The original
    note asked for a 32-bit (or possibly 64-bit) integer; this function does
    not check the width itself.

    Returns
    -------
    key : tuple
        ``(seed_sequence, generator)``, where `generator` is the default
        NumPy Generator built from `seed_sequence`.
    """
    # Note: selecting a non-default PRNG algorithm is done via a global
    # config flag (not good, should be a keyword or similar ...)
    seed_seq = np.random.SeedSequence(seed)
    return seed_seq, np.random.default_rng(seed_seq)
def split(key, num=2):
    """
    Split a key into `num` new keys.

    Parameters
    ----------
    key : tuple
        Size-2 tuple, the first element a `SeedSequence` instance, the second
        a `Generator` whose bit generator type acts as the algorithm
        selector.
    num : int, optional
        The number of keys to produce (default: 2).

    Returns
    -------
    keys : tuple of 2-tuples
        `num` number of keys (each key being a 2-tuple). Every new key holds
        a freshly spawned child `SeedSequence` and shares the `Generator`
        object of the input key.
    """
    seed, rng = key
    child_seeds = seed.spawn(num)
    # Return a real tuple, as documented, rather than a one-shot generator
    # expression: the result can then be indexed, measured with `len` and
    # iterated more than once.
    return tuple((child, rng) for child in child_seeds)
def uniform(key, shape=(), dtype=np.float64, minval=0.0, maxval=1.0):
    """
    Sample uniform values for `key`.

    The same key always produces the same values, because a fresh Generator
    is seeded from the key's `SeedSequence` on every call.

    Parameters
    ----------
    key : tuple
        Size-2 tuple ``(seed_sequence, generator)``. Only the bit generator
        *type* of `generator` is used; its state is neither read nor
        advanced.
    shape : tuple of int, optional
        Passed to `Generator.uniform` as `size` (default: ``()``).
    dtype : dtype, optional
        The samples are cast to this dtype with `astype`
        (default: `np.float64`).
    minval, maxval : float, optional
        Passed to `Generator.uniform` as `low` and `high`.

    Returns
    -------
    values : ndarray
        Samples of shape `shape` and dtype `dtype`.
    """
    seed, rng = key
    # `Generator.bit_generator` is the public accessor for the underlying
    # BitGenerator, so no private attribute is needed to find its class.
    bit_generator_cls = type(rng.bit_generator)
    fresh_rng = np.random.Generator(bit_generator_cls(seed))
    return fresh_rng.uniform(low=minval, high=maxval, size=shape).astype(dtype)
def use_jaxlike_api(key=None):
    """Draw and print two uniform samples through the JAX-like NumPy API.

    Parameters
    ----------
    key : tuple, optional
        A key as produced by `PRNGKey` or `split`. When None, a key is
        created from the module-level `seed`.
    """
    state = PRNGKey(seed) if key is None else key
    state, draw_key = split(state)
    vals = uniform(draw_key, shape=(3,))
    # Don't forget to split again before the next draw!
    state, draw_key = split(state)
    val = uniform(draw_key, shape=(1,))
    print(vals, val)
def use_jaxlike_api_mp():
    """Run `use_jaxlike_api` on four sub-keys in a 4-process pool."""
    root = PRNGKey(seed)
    # Five keys are requested: the first is kept back, the remaining four
    # go to the workers.
    root, *worker_keys = split(root, num=5)
    with multiprocessing.Pool(4) as workers:
        workers.map(use_jaxlike_api, worker_keys)
if __name__ == '__main__':
    # JAX does not work with the default `fork` (due to internal threading).
    # `forkserver` is preferred, but it is not offered on every platform
    # (e.g. Windows), where `set_start_method` would raise ValueError; fall
    # back to `spawn`, which also avoids `fork`.
    start_method = 'forkserver'
    if start_method not in multiprocessing.get_all_start_methods():
        start_method = 'spawn'
    multiprocessing.set_start_method(start_method)

    print('\nNumPy with multiprocessing:\n')
    main_numpy()

    print('\n\nJAX with multiprocessing:\n')
    main_jax()

    print('\n\nUse JAX-like API (serial):\n')
    use_jaxlike_api()

    print('\n\nUse JAX-like API (multiprocessing):\n')
    use_jaxlike_api_mp()
# Gotcha with seed creation: | |
""" | |
In [24]: seed = secrets.randbits(64) | |
In [25]: jax.random.PRNGKey(seed) | |
--------------------------------------------------------------------------- | |
OverflowError Traceback (most recent call last) | |
<ipython-input-25-7a8d328c270c> in <module> | |
----> 1 jax.random.PRNGKey(seed) | |
~/anaconda3/envs/many-libs/lib/python3.7/site-packages/jax/_src/random.py in PRNGKey(seed) | |
57 # Explicitly cast to int64 for JIT invariance of behavior on large ints. | |
58 if isinstance(seed, int): | |
---> 59 seed = np.int64(seed) | |
60 # Converting to jnp.array may truncate bits when jax_enable_x64=False, but this | |
61 # is necessary for the sake of JIT invariance of the result for such values. | |
OverflowError: Python int too large to convert to C long | |
In [26]: seed = secrets.randbits(128) | |
In [27]: jax.random.PRNGKey(seed) | |
--------------------------------------------------------------------------- | |
TypeError Traceback (most recent call last) | |
<ipython-input-27-7a8d328c270c> in <module> | |
----> 1 jax.random.PRNGKey(seed) | |
~/anaconda3/envs/many-libs/lib/python3.7/site-packages/jax/_src/random.py in PRNGKey(seed) | |
53 raise TypeError(f"PRNGKey seed must be a scalar; got {seed!r}.") | |
54 if not np.issubdtype(np.result_type(seed), np.integer): | |
---> 55 raise TypeError(f"PRNGKey seed must be an integer; got {seed!r}") | |
56 | |
57 # Explicitly cast to int64 for JIT invariance of behavior on large ints. | |
TypeError: PRNGKey seed must be an integer; got 67681183633192462759155065893448052088 | |
In [28]: seed = secrets.randbits(64) | |
In [29]: jax.random.PRNGKey(seed) | |
--------------------------------------------------------------------------- | |
OverflowError Traceback (most recent call last) | |
<ipython-input-29-7a8d328c270c> in <module> | |
----> 1 jax.random.PRNGKey(seed) | |
~/anaconda3/envs/many-libs/lib/python3.7/site-packages/jax/_src/random.py in PRNGKey(seed) | |
57 # Explicitly cast to int64 for JIT invariance of behavior on large ints. | |
58 if isinstance(seed, int): | |
---> 59 seed = np.int64(seed) | |
60 # Converting to jnp.array may truncate bits when jax_enable_x64=False, but this | |
61 # is necessary for the sake of JIT invariance of the result for such values. | |
OverflowError: Python int too large to convert to C long | |
In [30]: seed = secrets.randbits(32) | |
In [31]: jax.random.PRNGKey(seed) | |
Out[31]: DeviceArray([ 0, 3279739543], dtype=uint32) | |
""" |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment