-
-
Save rkern/9361aa15a28ae8c6dced01840209cdbb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from itertools import cycle | |
import re | |
from secrets import randbits | |
import numpy as np | |
cimport numpy as np | |
np.import_array() | |
# Pattern of one or more ASCII decimal digits, used to recognise decimal
# seed strings.
DECIMAL_RE = re.compile(r'[0-9]+')

# Default number of uint32 words held in a SplitSeed entropy pool.
cdef uint32_t DEFAULT_POOL_SIZE = 4  # Appears also in docstring for pool_size

# Starting value and per-step multiplier of the running hash constant used
# while mixing entropy into the pool.
cdef uint32_t INIT_A = 0x43b0d7e5
cdef uint32_t MULT_A = 0x931e8875

# Starting value and per-step multiplier of the running hash constant used
# while generating output state words from the pool.
cdef uint32_t INIT_B = 0x8b51f9dd
cdef uint32_t MULT_B = 0x58f38ded

# Multipliers applied to the left and right operands when two words are
# combined.
cdef uint32_t MIX_MULT_L = 0xca01f9dd
cdef uint32_t MIX_MULT_R = 0x4973f715

# Xor-shift distance: half the bit width of a uint32 word.
cdef uint32_t XSHIFT = np.dtype(np.uint32).itemsize * 8 // 2

# All 32 bits set; masks an integer down to its low 32 bits.
cdef uint32_t MASK32 = 0xFFFFFFFF
def _int_to_uint32_array(n):
    """ Split a non-negative integer into `uint32` words, lowest bits first.

    Parameters
    ----------
    n : int or numpy integer
        The value to decompose. Must be non-negative.

    Returns
    -------
    words : 1D uint32 array
        The base-2**32 digits of ``n``, least significant first. Zero is
        represented as a single zero word.

    Raises
    ------
    ValueError
        If ``n`` is negative.
    """
    if n < 0:
        raise ValueError("expected non-negative integer")
    if n == 0:
        return np.zeros(1, dtype=np.uint32)
    if isinstance(n, np.integer):
        # Any fixed-width NumPy integer (signed or unsigned) may fail to
        # combine with the 32-bit mask or to shift by 32 bits, so do the
        # arithmetic on an arbitrary-precision Python int instead.
        n = int(n)
    words = []
    while n > 0:
        # Peel off the low 32 bits, then drop them.
        words.append(n & 0xFFFFFFFF)
        n >>= 32
    return np.array(words, dtype=np.uint32)
def _coerce_to_uint32_array(x):
    """ Turn a seed-like input into a 1D array of `uint32` words.

    The accepted inputs are handled as follows:

    - a `uint32` array is returned as a copy;
    - a non-negative integer is split into `uint32` words, lowest bits
      first;
    - a string starting with "0x" is parsed as a hex integer and then split
      as above;
    - a string of decimal digits is parsed as a decimal integer and then
      split as above;
    - any other sized iterable has each element coerced by these same rules
      and the pieces are concatenated.

    Arrays of `int64` or `uint64` are therefore *not* reinterpreted as raw
    `uint32` views: an element that fits in one `uint32` contributes exactly
    one word. That keeps the result for a sequence of integers independent
    of numpy's platform-dependent default integer type.

    Parameters
    ----------
    x : int, str, sequence of int or str

    Returns
    -------
    seed_array : uint32 array

    Raises
    ------
    ValueError
        If a string is neither "0x"-prefixed nor starts with decimal digits.
    TypeError
        If a floating point value is given.

    Examples
    --------
    >>> import numpy as np
    >>> from numpy.random.bit_generator import _coerce_to_uint32_array
    >>> _coerce_to_uint32_array(12345)
    array([12345], dtype=uint32)
    >>> _coerce_to_uint32_array('12345')
    array([12345], dtype=uint32)
    >>> _coerce_to_uint32_array('0x12345')
    array([74565], dtype=uint32)
    >>> _coerce_to_uint32_array([12345, '67890'])
    array([12345, 67890], dtype=uint32)
    >>> _coerce_to_uint32_array(np.array([12345, 67890], dtype=np.uint32))
    array([12345, 67890], dtype=uint32)
    >>> _coerce_to_uint32_array(np.array([12345, 67890], dtype=np.int64))
    array([12345, 67890], dtype=uint32)
    >>> _coerce_to_uint32_array([12345, 0x10deadbeef, 67890, 0xdeadbeef])
    array([     12345, 3735928559,         16,      67890, 3735928559],
          dtype=uint32)
    >>> _coerce_to_uint32_array(1234567890123456789012345678901234567890)
    array([3460238034, 2898026390, 3235640248, 2697535605,          3],
          dtype=uint32)
    """
    # Already in the target representation: hand back an independent copy.
    if isinstance(x, np.ndarray) and x.dtype == np.dtype(np.uint32):
        return x.copy()

    # Strings are first reduced to a Python int, then handled like any int.
    if isinstance(x, str):
        if x.startswith('0x'):
            x = int(x, 16)
        elif DECIMAL_RE.match(x) is None:
            raise ValueError("unrecognized seed string")
        else:
            x = int(x)

    if isinstance(x, (int, np.integer)):
        return _int_to_uint32_array(x)
    if isinstance(x, (float, np.inexact)):
        raise TypeError('seed must be integer')

    # Anything left is treated as a sized collection of seed-like items.
    if len(x) == 0:
        return np.array([], dtype=np.uint32)
    pieces = [_coerce_to_uint32_array(item) for item in x]
    return np.concatenate(pieces)
cdef uint32_t hashmix(uint32_t value, uint32_t * hash_const):
    # Hash one 32-bit word using the running multiplier stored in
    # hash_const[0]. The multiplier is advanced in place on every call, so
    # hash_const is both an input and an output.
    cdef uint32_t mixed = value ^ hash_const[0]
    hash_const[0] = hash_const[0] * MULT_A
    mixed = mixed * hash_const[0]
    # Fold the high half of the word into the low half.
    return mixed ^ (mixed >> XSHIFT)
cdef uint32_t mix(uint32_t x, uint32_t y):
    # Combine two 32-bit words: scale each by its own multiplier, subtract
    # (wrapping modulo 2**32), then fold the high half into the low half.
    cdef uint32_t left = MIX_MULT_L * x
    cdef uint32_t right = MIX_MULT_R * y
    cdef uint32_t combined = left - right
    return combined ^ (combined >> XSHIFT)
cdef class SplitSeed():
    """
    SplitSeed(entropy=None, *, pool_size=4, pool=None, hash_const=None)

    `SplitSeed` mixes sources of entropy in a reproducible way to set the
    initial state for independent and very probably non-overlapping
    BitGenerators.

    Once the `SplitSeed` is instantiated, you can call the `generate_state`
    method to get an appropriately sized seed. Calling `split(n) <split>` will
    create ``n`` SplitSeeds that can be used to seed independent
    BitGenerators, i.e. for different threads. Unlike `SeedSequence.spawn`,
    calling `split(n) <split>` multiple times will return the *same* results
    for a more pure functional API.

    Parameters
    ----------
    entropy : {None, int, sequence[int]}, optional
        The entropy for initially creating a `SplitSeed`. If omitted, fresh
        random bits are drawn from `secrets.randbits`. If `pool` is
        provided, this will be stored but not used, and will simply reflect the
        value that was used at the root of the split tree. The splitting path
        is not stored.
    pool_size : {int}, optional
        Size of the pooled entropy to store. Default is 4 to give a 128-bit
        entropy pool. 8 (for 256 bits) is another reasonable choice if working
        with larger PRNGs, but there is very little to be gained by selecting
        another value.
    pool : uint32 array, optional
        The internal hash pool. Only pass this if reconstructing a `SplitSeed`
        from a serialized form. A copy is stored.
    hash_const : uint32, optional
        The internal hash constant for mixing in new entropy. Only pass this if
        reconstructing a `SplitSeed` from a serialized form. When `pool` is
        not given, the pool is mixed from `entropy` and this value is
        replaced by the constant left over from that mixing.

    Raises
    ------
    TypeError
        If `entropy` is neither an int nor one of the accepted sequence types.
    """

    def __init__(self, entropy=None, *, pool_size=DEFAULT_POOL_SIZE, pool=None,
                 hash_const=None):
        # FIXME: ignore this for now so we can experiment with smaller pool
        # sizes.
        # if pool_size < DEFAULT_POOL_SIZE:
        #     raise ValueError("The size of the entropy pool should be at least "
        #                      f"{DEFAULT_POOL_SIZE}")
        if entropy is None:
            # No entropy supplied: draw enough random bits to fill the pool.
            entropy = randbits(pool_size * 32)
        elif not isinstance(entropy, (int, np.integer, list, tuple, range,
                                      np.ndarray)):
            # Report the actual class name rather than a hard-coded one.
            raise TypeError(
                f'{type(self).__name__} expects int or sequence of ints for '
                f'entropy not {entropy}')
        self.entropy = entropy
        self.pool_size = pool_size
        if hash_const is None:
            hash_const = INIT_A
        self.hash_const = hash_const
        if pool is None:
            # Fresh instance: build the pool by hashing the entropy. This
            # also updates self.hash_const.
            self.pool = np.zeros(pool_size, dtype=np.uint32)
            self.mix_entropy(self.pool, self.get_assembled_entropy())
        else:
            # Reconstructed instance: take the pool as given, but never share
            # the caller's array.
            self.pool = pool.copy()

    def __repr__(self):
        lines = [
            f'{type(self).__name__}(',
            f'    entropy={self.entropy!r},',
            f'    pool={self.pool!r},',
            f'    hash_const={self.hash_const!r},',
        ]
        # Omit some entries if they are left as the defaults in order to
        # simplify things.
        if self.pool_size != DEFAULT_POOL_SIZE:
            lines.append(f'    pool_size={self.pool_size!r},')
        lines.append(')')
        text = '\n'.join(lines)
        return text

    @property
    def state(self):
        """ Dict of the constructor arguments describing this instance.

        Entries whose value is None are left out. The ``pool`` entry is the
        instance's own array, not a copy.
        """
        return {k: getattr(self, k) for k in
                ['entropy', 'pool_size', 'pool',
                 'hash_const']
                if getattr(self, k) is not None}

    cdef mix_entropy(self, np.ndarray[np.npy_uint32, ndim=1] mixer,
                     np.ndarray[np.npy_uint32, ndim=1] entropy_array):
        """ Mix in the given entropy to mixer.

        The hash constant always starts from ``INIT_A``; the value it ends
        with is stored on ``self.hash_const``.

        Parameters
        ----------
        mixer : 1D uint32 array, modified in-place
        entropy_array : 1D uint32 array
        """
        cdef uint32_t hash_const[1]
        # Typed indices let the loops below run as plain C loops.
        cdef Py_ssize_t i, i_src, i_dst
        cdef Py_ssize_t n_pool = len(mixer)
        cdef Py_ssize_t n_entropy = len(entropy_array)
        hash_const[0] = INIT_A

        # Add in the entropy up to the pool size.
        for i in range(n_pool):
            if i < n_entropy:
                mixer[i] = hashmix(entropy_array[i], hash_const)
            else:
                # Our pool size is bigger than our entropy, so just keep
                # running the hash out.
                mixer[i] = hashmix(0, hash_const)

        # Mix all bits together so late bits can affect earlier bits.
        for i_src in range(n_pool):
            for i_dst in range(n_pool):
                if i_src != i_dst:
                    mixer[i_dst] = mix(mixer[i_dst],
                                       hashmix(mixer[i_src], hash_const))

        # Add any remaining entropy, mixing each new entropy word with each
        # pool word.
        for i_src in range(n_pool, n_entropy):
            for i_dst in range(n_pool):
                mixer[i_dst] = mix(mixer[i_dst],
                                   hashmix(entropy_array[i_src], hash_const))

        self.hash_const = hash_const[0]

    cdef mix_split_key(self, uint32_t i_split):
        """ Mix the child index ``i_split`` into every word of ``self.pool``.

        The pool is modified in place and ``self.hash_const`` is advanced
        once per pool word.
        """
        cdef int i_dst
        cdef uint32_t hash_const[1]
        cdef np.ndarray[np.npy_uint32, ndim=1] mixer = self.pool

        hash_const[0] = self.hash_const
        for i_dst in range(len(mixer)):
            mixer[i_dst] = mix(mixer[i_dst],
                               hashmix(i_split, hash_const))
        self.hash_const = hash_const[0]

    cpdef get_assembled_entropy(self):
        """ Convert the stored entropy into a uniform uint32 array.

        The result is zero-padded on the right so that it is at least
        ``pool_size`` words long.

        Returns
        -------
        entropy_array : 1D uint32 array
        """
        # The entropy is the only input here; the constructor guarantees it
        # is set.
        assert self.entropy is not None
        run_entropy = _coerce_to_uint32_array(self.entropy)
        if len(run_entropy) < self.pool_size:
            # Explicitly fill out the entropy with 0s to the pool size so
            # that short entropy always occupies a full pool's worth of
            # words.
            diff = self.pool_size - len(run_entropy)
            run_entropy = np.concatenate(
                [run_entropy, np.zeros(diff, dtype=np.uint32)])
        entropy_array = run_entropy
        return entropy_array

    @np.errstate(over='ignore')
    def generate_state(self, n_words, dtype=np.uint32):
        """
        generate_state(n_words, dtype=np.uint32)

        Return the requested number of words for PRNG seeding.

        A BitGenerator should call this method in its constructor with
        an appropriate `n_words` parameter to properly seed itself.

        Parameters
        ----------
        n_words : int
        dtype : np.uint32 or np.uint64, optional
            The size of each word. This should only be either `uint32` or
            `uint64`. Strings (`'uint32'`, `'uint64'`) are fine. Note that
            requesting `uint64` will draw twice as many bits as `uint32` for
            the same `n_words`. This is a convenience for `BitGenerator`s that
            express their states as `uint64` arrays.

        Returns
        -------
        state : uint32 or uint64 array, shape=(n_words,)

        Raises
        ------
        ValueError
            If `dtype` is neither `uint32` nor `uint64`.
        """
        cdef uint32_t hash_const = INIT_B
        cdef uint32_t data_val

        out_dtype = np.dtype(dtype)
        if out_dtype == np.dtype(np.uint32):
            pass
        elif out_dtype == np.dtype(np.uint64):
            # Each uint64 output word is assembled from two uint32 words.
            n_words *= 2
        else:
            raise ValueError("only support uint32 or uint64")
        state = np.zeros(n_words, dtype=np.uint32)
        # Walk the pool repeatedly; the advancing hash constant makes each
        # pass produce different output words.
        src_cycle = cycle(self.pool)
        for i_dst in range(n_words):
            data_val = next(src_cycle)
            data_val ^= hash_const
            hash_const *= MULT_B
            data_val *= hash_const
            data_val ^= data_val >> XSHIFT
            state[i_dst] = data_val
        if out_dtype == np.dtype(np.uint64):
            # For consistency across different endiannesses, view first as
            # little-endian then convert the values to the native endianness.
            state = state.astype('<u4').view('<u8').astype(np.uint64)
        return state

    def split(self, n_children):
        """
        split(n_children)

        Split off a number of child `SplitSeed` s by mixing in different
        numbers into the entropy pool for each sub-stream.

        Unlike `SeedSequence.spawn`, this method is idempotent. Calling it
        multiple times will return the same `SplitSeed` values.

        Parameters
        ----------
        n_children : int

        Returns
        -------
        seqs : list of `SplitSeed` s
        """
        cdef uint32_t i_split
        cdef SplitSeed ss

        seqs = []
        for i_split in range(n_children):
            # Each child starts from a copy of this instance's pool and hash
            # constant, then has its own index mixed in. This instance is
            # left unchanged.
            ss = type(self)(
                self.entropy,
                pool=self.pool,
                hash_const=self.hash_const,
                pool_size=self.pool_size,
            )
            ss.mix_split_key(i_split)
            seqs.append(ss)
        return seqs
# Register SplitSeed with NumPy's ISeedSequence through its register() hook.
# NOTE(review): presumably this makes isinstance/issubclass checks against
# ISeedSequence accept SplitSeed without inheriting from it — confirm against
# the NumPy documentation.
np.random.bit_generator.ISeedSequence.register(SplitSeed)
cdef inline uint32_t rotate_left32(uint32_t x, uint32_t r):
    # Rotate the 32-bit word x left by r bits.
    #
    # Both shift counts are reduced to the range 0..31. Without that, r == 0
    # (or any multiple of 32) would shift a 32-bit value by 32 bits, which
    # is undefined behaviour in C. For r in 1..31 the result is unchanged.
    r &= 31
    return (x << r) | (x >> ((32 - r) & 31))
cdef apply_round(uint32_t *x1, uint32_t *x2, uint32_t r):
    # One add-rotate-xor round on the pair of words behind x1 and x2,
    # updating both in place:
    #     x1 <- x1 + x2            (wrapping modulo 2**32)
    #     x2 <- rotl(x2, r) ^ x1   (using the updated x1)
    # Both inputs are read before either output is written.
    cdef uint32_t total = x1[0] + x2[0]
    cdef uint32_t rotated = rotate_left32(x2[0], r)
    x1[0] = total
    x2[0] = total ^ rotated
cpdef threefry2x32(uint32_t key1, uint32_t key2, uint32_t x1, uint32_t x2):
    """Encrypt the word pair ``(x1, x2)`` under the key ``(key1, key2)``.

    Runs 20 add-rotate-xor rounds in five groups of four. A key word is
    added to each data word before the first group and after every group;
    the second data word additionally receives the 1-based group number.

    Parameters
    ----------
    key1, key2 : uint32
        The two key words.
    x1, x2 : uint32
        The two data words.

    Returns
    -------
    (x1, x2) : tuple of int
        The two transformed words.
    """
    cdef uint32_t ks[3]
    cdef int group

    # Key schedule: the two key words plus a third derived from them.
    ks[0] = key1
    ks[1] = key2
    ks[2] = key1 ^ key2 ^ <uint32_t>(0x1BD11BDA)

    # Initial key injection.
    x1 += ks[0]
    x2 += ks[1]

    for group in range(5):
        # Even and odd groups alternate between two rotation schedules.
        if group % 2 == 0:
            apply_round(&x1, &x2, 13)
            apply_round(&x1, &x2, 15)
            apply_round(&x1, &x2, 26)
            apply_round(&x1, &x2, 6)
        else:
            apply_round(&x1, &x2, 17)
            apply_round(&x1, &x2, 29)
            apply_round(&x1, &x2, 16)
            apply_round(&x1, &x2, 24)
        # Key injection after the group: the schedule rotates by one word
        # each time and the group counter is added to the second word.
        x1 += ks[(group + 1) % 3]
        x2 += ks[(group + 2) % 3] + <uint32_t>(group + 1)

    return x1, x2
cpdef iterate_jax_key(np.ndarray[np.uint32_t, ndim=1] key):
    """Faithful implementation of ``jax.random.split(key)[0]``

    Parameters
    ----------
    key : 1D uint32 array of exactly two words

    Returns
    -------
    new_key : uint32 array, shape=(2,)
        The first output word of ``threefry2x32`` on the counter pair
        ``(0, 2)`` followed by the first output word on ``(1, 3)``.
    """
    cdef uint32_t word0, word1
    # Unpacking also enforces that the key holds exactly two words.
    word0, word1 = key
    first = threefry2x32(word0, word1, 0, 2)[0]
    second = threefry2x32(word0, word1, 1, 3)[0]
    return np.array([first, second], dtype=np.uint32)
cpdef fixed_iterate_jax_key(np.ndarray[np.uint32_t, ndim=1] key):
    """JAX key iteration if the ``threefry_random_bits()`` quirk were fixed.

    Parameters
    ----------
    key : 1D uint32 array of exactly two words

    Returns
    -------
    new_key : uint32 array, shape=(2,)
        Both output words of ``threefry2x32`` on the counter pair ``(0, 0)``
        under the given key.
    """
    cdef uint32_t word0, word1
    # Unpacking also enforces that the key holds exactly two words.
    word0, word1 = key
    return np.array(threefry2x32(word0, word1, 0, 0), dtype=np.uint32)
cpdef bijective_iterate_jax_key(np.ndarray[np.uint32_t, ndim=1] key):
    """Putative bijective key splitting (left branch).

    Parameters
    ----------
    key : 1D uint32 array of exactly two words

    Returns
    -------
    new_key : uint32 array, shape=(2,)
        Both output words of ``threefry2x32`` with an all-zero key and the
        input key words used as the data pair.
    """
    cdef uint32_t word0, word1
    # Unpacking also enforces that the key holds exactly two words.
    word0, word1 = key
    return np.array(threefry2x32(0, 0, word0, word1), dtype=np.uint32)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment