Skip to content

Instantly share code, notes, and snippets.

@lincoln-lm
Created October 19, 2024 10:17
Show Gist options
  • Save lincoln-lm/a7be1e81171218775399dda3da963030 to your computer and use it in GitHub Desktop.
Save lincoln-lm/a7be1e81171218775399dda3da963030 to your computer and use it in GitHub Desktop.
Finds rng states of the form (seed_0, 0x82A2B175229D6A5B) that produce a given sequence of 64+ rand(2) observations when rand(121)s happen before each.
"""
Finds rng states of the form (seed_0, 0x82A2B175229D6A5B) that produce a given sequence
of 64+ rand(2) observations when rand(121)s happen before each.
The probability that an arbitrary sequence can be solved for its initial state is roughly
P(X <= max jump count) for X ~ Bin(64, 7/128), where 7/128 is the chance that a rand(128) draw
is rejected by rand(121) (values 121..127). This is not exact, because rand(121) can reject more
than once in a row.
The amount of work needed to solve the sequence scales with CR(64, max jump count)
(combinations with replacement; technically sum_{0 <= i <= max jump count} CR(64, i), but the
terms before the last are comparatively small).
"""
from itertools import combinations_with_replacement
import numpy as np
import numba
from scipy.stats import binom
from scipy.special import factorial
from numba_progress import ProgressBar
from numba_progress.progress import progressbar_type
from numba_progress.numba_atomic import atomic_add
if __name__ == "__main__":
    # Printed before the rest of the module loads: `search` below is decorated with an explicit
    # numba signature, which numba compiles eagerly at definition time, so this can take a while.
    print("Compiling...")
# Must be imported at module level: test_sequence's signature annotation and search's body
# both reference this class when the module is loaded.
from numba_pokemon_prngs.xorshift import Xoroshiro128PlusRejection
# Number of contiguous chunks `search` splits the combinations into (one prange iteration
# each); taken from numba's configured thread count.
NUM_THREADS = numba.config.NUMBA_NUM_THREADS
def find_working_example(max_jumps: int):
    """Find an example sequence & state that has max_jumps or fewer jumps.

    Draws random seed_0 values (paired with the fixed seed_1 0x82A2B175229D6A5B) and simulates
    64 rounds of rand(121) followed by rand(2).  Every rejected rand(128) draw inside rand(121)
    records the current round index as a "jump".  The first seed whose total number of jumps
    is at most ``max_jumps`` is returned as (seed_0, seed_1, jumps, observations), where
    ``observations`` holds the 64 rand(2) results.
    """
    generator = Xoroshiro128PlusRejection(0, 0)
    observed = np.empty(64, np.uint8)
    while True:
        rejections = []
        low_seed = np.random.randint(0, 1 << 64, dtype=np.uint64)
        high_seed = 0x82A2B175229D6A5B
        generator.re_init(low_seed, high_seed)
        for step in range(64):
            # equivalent to generator.next_rand(121): each rejected draw is one extra jump
            while generator.next_rand(128) >= 121:
                rejections.append(step)
            if len(rejections) > max_jumps:
                break  # too many jumps already; retry with a new seed
            observed[step] = generator.next_rand(2)
        else:
            # all 64 rounds finished without exceeding max_jumps
            return low_seed, high_seed, rejections, observed
@numba.njit
def test_sequence(
    rng: Xoroshiro128PlusRejection,
    observations,
    jumps,
    result_count,
    results
):
    """Test a given combination of jumps to see if it produces the expected sequence.

    The solver treats each of the first 64 rand(2) observations as a GF(2)-linear function of
    seed_0 plus a constant contributed by seed_1 = 0x82A2B175229D6A5B.  It builds that map
    for the given jump placement and solves it by Gauss-Jordan elimination.  Every candidate
    seed_0 (particular solution XOR any null-space combination) is then re-simulated with the
    real rand(121)/rand(2) sequence, so only genuinely matching seeds are recorded.

    Args:
        rng: scratch rng instance; it is re-initialised repeatedly and its state is overwritten.
        observations: observed rand(2) values.  Only the first 64 build the linear system;
            all of them are used when verifying a candidate.
        jumps: step indices (padded with -1, which never matches a step) at which one extra
            rng advance happens, i.e. one rejected rand(128) draw inside rand(121).
        result_count: one-element counter array, incremented through atomic_add per match.
        results: output buffer for confirmed seed_0 values.  Matches beyond its capacity
            are still counted and printed but not stored.
    """
    # construct a matrix that maps from an initial state to the observations:
    # row `bit` holds the 64 rand(2) outputs produced by seed_0 = 1 << bit, seed_1 = 0
    mat = np.zeros((64, 64), np.uint8)
    for bit in range(64):
        rng.re_init(1 << bit, 0)
        for i in range(64):
            # the accepted rand(121) draw
            rng.next()
            # one extra advance per rejected draw scheduled at this step
            for j in jumps:
                if i == j:
                    rng.next()
            mat[bit, i] = rng.next_rand(2)
    # construct vector that stores the influence of 0x82A2B175229D6A5B on the observations
    const_influence = np.empty(64, np.uint8)
    rng.re_init(0, np.uint64(0x82A2B175229D6A5B))
    for i in range(64):
        rng.next()
        for j in jumps:
            if i == j:
                rng.next()
        const_influence[i] = rng.next_rand(2)
    # invert the matrix to map from observations to the state (Gauss-Jordan over GF(2));
    # inverse[r] is a bitmask over the original rows recording which rows were XORed into row r
    inverse = np.empty(64, np.uint64)
    # identity
    for bit in range(64):
        inverse[bit] = 1 << bit
    rank = 0
    pivots = []
    for col in range(64):
        for row in range(rank, 64):
            if mat[row, col]:
                # clear this column from every other row
                for other_row in range(64):
                    if (other_row != row) and mat[other_row, col]:
                        mat[other_row] ^= mat[row]
                        inverse[other_row] ^= inverse[row]
                # move the pivot row up to position `rank`
                mat_tmp = np.copy(mat[row])
                mat[row] = mat[rank]
                mat[rank] = mat_tmp
                inv_tmp = inverse[row]
                inverse[row] = inverse[rank]
                inverse[rank] = inv_tmp
                pivots.append(col)
                rank += 1
                break
    # store the null basis in case the observations are not determinative:
    # rows rank..63 of mat are all zero, so these combinations do not affect the observations
    nullbasis = np.copy(inverse[rank:])
    # undo pivots: move the row owning pivot column pivots[i] to index pivots[i]
    for i in range(rank - 1, -1, -1):
        pivot = pivots[i]
        swap_tmp = inverse[i]
        inverse[i] = inverse[pivot]
        inverse[pivot] = swap_tmp
    # (observations - const_influence) @ inverse
    principal_result = np.uint64(0)
    for i in range(64):
        if observations[i] ^ const_influence[i]:
            principal_result ^= inverse[i]
    # loop over other solutions: every subset of the null-space basis
    for i in range(1 << len(nullbasis)):
        result = principal_result
        subset = i
        for bit in range(64):
            if subset == 0:
                break
            if subset & 1:
                result ^= nullbasis[bit]
            subset >>= 1
        # test result against the real rand(121) / rand(2) sequence
        rng.re_init(result, np.uint64(0x82A2B175229D6A5B))
        valid = True
        for observation in observations:
            rng.next_rand(121)
            valid &= (rng.next() & 1) == observation
            if not valid:
                break
        if valid:
            slot = atomic_add(result_count, 0, 1)
            # results has a fixed capacity and njit code does not bounds-check by default:
            # skip the store instead of writing past the end of the buffer
            if slot < results.shape[0]:
                results[slot] = result
            print(f"\nFound state: {result} jumps: {', '.join(map(str, jumps))}")
@numba.njit(
    numba.uint64[:](numba.uint8[:], numba.int8[:, :], progressbar_type),
    nogil=True,
    parallel=True
)
def search(observations, combinations, progress_proxy):
    """Test all given combinations of jumps against ``observations`` in parallel.

    The rows of ``combinations`` are split into NUM_THREADS contiguous chunks; the last chunk
    also takes the remainder.  Each chunk is handled by one prange iteration with its own rng
    instance.  Returns the seed_0 values that test_sequence recorded, in recording order
    (which depends on thread scheduling).
    """
    found = np.zeros(1, np.uint64)
    # huge overestimation of the number of results
    buffer = np.zeros(0x10000, np.uint64)
    total = combinations.shape[0]
    chunk = total // NUM_THREADS
    for worker in numba.prange(NUM_THREADS):
        # a private rng per worker so parallel iterations never share state
        scratch_rng = Xoroshiro128PlusRejection(0, 0)
        first = chunk * worker
        last = total if worker == NUM_THREADS - 1 else chunk * (worker + 1)
        for index in range(first, last):
            test_sequence(scratch_rng, observations, combinations[index], found, buffer)
            progress_proxy.update(1)
    return buffer[:atomic_add(found, 0, 0)]
def build_combinations(max_jumps):
    """Generate all possible combinations of jumps as an int8 array.

    Each row is a multiset of at most ``max_jumps`` step indices from range(64), listed in
    non-decreasing order and right-padded with -1 to ``max_jumps`` columns.  Rows are ordered
    by jump count (0 first), then lexicographically.
    """
    rows = [
        combo + (-1,) * (max_jumps - count)
        for count in range(max_jumps + 1)
        for combo in combinations_with_replacement(range(64), count)
    ]
    return np.array(rows, dtype=np.int8)
def calc_p_and_work(max_jumps: int, n_observations: int = 64):
    """Calculate an estimation of the probability of success & work needed for a given max_jumps.

    Args:
        max_jumps: maximum number of rejected rand(128) draws considered.
        n_observations: number of rand(121)/rand(2) rounds modelled (default 64).

    Returns:
        (p, work) where
        p is P(X <= max_jumps) for X ~ Bin(n_observations, 7/128).  Here 7/128 is the chance
        that one rand(128) draw is rejected by rand(121) (values 121..127).
        work is the exact number of jump combinations to test,
        sum_{r=0}^{max_jumps} CR(n_observations, r), as a Python int.
    """
    from math import comb

    p = binom.cdf(max_jumps, n_observations, 7 / 128)
    # CR(n, r) = C(n + r - 1, r).  Exact integer arithmetic: the previous float factorial
    # ratio lost precision and overflowed to inf/nan once n_observations + r - 1 exceeded 170.
    work = sum(comb(n_observations + r - 1, r) for r in range(max_jumps + 1))
    return p, work
if __name__ == "__main__":
    # Show the estimated success rate and cost of each max_jumps choice before asking for one.
    for max_jumps in range(10):
        p, work = calc_p_and_work(max_jumps)
        # rough estimation for work per sec per thread
        est_time = work / (1542 * NUM_THREADS)
        print(f"Max jumps: {max_jumps} Success rate: {p*100:.2f}% Work: {work} Estimated time:{est_time:.1f}s")
    max_jumps = int(input("Max jumps: "))
    if max_jumps < 0:
        raise SystemExit(f"Max jumps must be non-negative: {max_jumps}")
    # Debugging aid: uncomment to generate a solvable example state instead of real input.
    # seed_0, seed_1, jumps, observations = find_working_example(max_jumps)
    # print(f"Sample state: {seed_0=:X} {jumps=}")
    # Validate explicitly (an `assert` is stripped under `python -O`), and reject characters
    # other than p/d instead of silently reading them as "deep sleep".
    raw = input("p = Sleeping peacefully, d = Deep sleep: ").strip().lower()
    unexpected = sorted(set(raw) - {"p", "d"})
    if unexpected:
        raise SystemExit(
            f"Unexpected observation character(s): {''.join(unexpected)!r} (use only 'p' or 'd')"
        )
    if len(raw) < 64:
        raise SystemExit(f"Not enough observations: {len(raw)}<64")
    # 1 = sleeping peacefully, 0 = deep sleep
    observations = np.array([o == "p" for o in raw], np.uint8)
    combinations = build_combinations(max_jumps)
    print(f"Checking {combinations.shape[0]} combinations on {NUM_THREADS} threads")
    with ProgressBar(total=combinations.shape[0]) as progress_proxy:
        results = set(search(observations, combinations, progress_proxy))
    print(f"Found {len(results)} unique initial seed(s)")
    # sorted for reproducible output; seeds are 64-bit, so pad to 16 hex digits
    for result in sorted(results):
        print(f"{result:016X}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment