Created
August 2, 2023 22:47
-
-
Save ooovi/529c00fc8a7eafd068cd076b78fc424e to your computer and use it in GitHub Desktop.
A script to generate test vectors for the discrete Gaussian sampler using SHA-3-seeded randomness.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
if __name__ == '__main__':
    # Generate a JSON test vector of discrete Gaussian samples drawn with a
    # PrgSha3 stream seeded with all-zero bytes.
    #
    # Usage: <script> STD
    #   STD is any string accepted by fractions.Fraction, e.g. "3", "2.342"
    #   or "7/3". Output goes to test_vec/discrete_gauss_<STD>.json.
    import json
    import os
    import sys
    from fractions import Fraction

    from prg import PrgSha3
    # discrete gaussian sampler from here:
    # https://github.com/IBM/discrete-gaussian-differential-privacy
    from discretegauss import sample_dgauss

    if len(sys.argv) != 2:
        sys.exit(f'usage: {sys.argv[0]} STD   (e.g. 3, 2.342 or 7/3)')
    std_arg = sys.argv[1]

    try:
        std = Fraction(std_arg)
    except (ValueError, ZeroDivisionError) as err:
        sys.exit(f'invalid STD {std_arg!r}: {err}')
    if std <= 0:
        sys.exit(f'STD must be positive, got {std_arg!r}')
    var = std**2

    # Fixed all-zero seed, empty customization string (dst) and empty binder,
    # so the generated vectors are reproducible.
    dst = b''
    seed = bytes(PrgSha3.SEED_SIZE)
    rng = PrgSha3(seed, dst, b'')

    test_vector = {
        'seed': seed.hex(),
        'std_num': std.numerator,
        'std_denom': std.denominator,
        # NOTE(review): range(1, 50) yields 49 samples, not 50. Confirm this
        # is intended before changing it, since it alters every vector.
        'samples': [sample_dgauss(var, rng) for _ in range(1, 50)],
    }

    out_dir = 'test_vec'
    os.makedirs(out_dir, exist_ok=True)
    # Fraction accepts "num/denom"; a raw '/' in the file name would point
    # into a non-existent sub-directory, so replace it.
    file_stem = std_arg.replace('/', '_')
    out_path = os.path.join(out_dir, 'discrete_gauss_' + file_stem + '.json')
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(test_vector, f, indent=4, sort_keys=True)
        f.write('\n')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from Crypto.Hash import cSHAKE128 | |
class PrgSha3:
    """PRG based on SHA-3 (cSHAKE128).

    ``dst`` is passed to cSHAKE128 as the customization string. ``seed``
    followed by ``binder`` is absorbed as the main input. Output is read on
    demand from the same cSHAKE128 object by :meth:`next`.
    """

    # Associated parameters (seed length is not checked by this class).
    SEED_SIZE = 16

    def __init__(self, seed, dst, binder):
        xof = cSHAKE128.new(custom=dst)
        # Absorb seed first, then binder: the input string is seed || binder.
        for piece in (seed, binder):
            xof.update(piece)
        self.shake = xof

    def next(self, length: int) -> bytes:
        """Return the next ``length`` output bytes read from the cSHAKE128 object."""
        return self.shake.read(length)

    def randrange(self, m: int) -> int:
        """Return a random int in ``[0, m)``, emulating Rust BigUint uniform sampling."""
        return gen_biguint_range(self, lbound=0, ubound=m)
####################################################################### | |
# Emulate Rust's num-bigint BigUint uniform sampling on top of PrgSha3. The
# original sampler is documented here:
# https://docs.rs/num-bigint/0.4.3/num_bigint/struct.UniformBigUint.html
# Rust u32 values are represented by Python ints, so a `list[int]` below
# stands for a Rust `Vec<u32>` (i.e. the return type is conceptually `list[u32]`).
def fill_u32_array(rng: "PrgSha3", data_len: int) -> list[int]:
    """Read ``data_len`` u32 values from ``rng``.

    Each value is built from the next 4 bytes returned by ``rng.next(4)``,
    interpreted as little-endian, so every element lies in ``[0, 2**32)``.
    Conceptually the result is a ``list[u32]``.
    """
    # Dispatch through the instance (not ``PrgSha3.next(rng, ...)``) so that
    # subclasses and any object with a compatible ``next(length)`` work.
    return [int.from_bytes(rng.next(4), 'little') for _ in range(data_len)]
# Convert a list[u32] into a "BigUint" (a Python int).
#
# note: index 0 contains the least-significant digit
def u32_array_as_int(xs: list[int]) -> int:
    """Interpret ``xs`` as base-2**32 digits, least-significant first.

    ``xs[i]`` contributes ``xs[i] << (32 * i)``; an empty list yields 0.
    Digits are combined by shift-and-add, so out-of-range values are not
    rejected and simply carry into the neighbouring digits.
    """
    return sum(digit << (32 * i) for i, digit in enumerate(xs))
# Port of num-bigint's `gen_bits`:
# https://github.com/rust-num/num-bigint/blob/6f2b8e0fc218dbd0f49bebb8db2d1a771fe6bafa/src/bigrand.rs#L40
def gen_bits(rng: PrgSha3, data_len: int, rem: int) -> list[int]:
    """Draw ``data_len`` random u32 digits, keeping ``rem`` bits in the top one.

    The digits come from :func:`fill_u32_array`. When ``rem`` is positive,
    the last (most-significant) digit is right-shifted by ``32 - rem``. That
    keeps its top ``rem`` bits, so the digit is below ``2**rem``. When
    ``rem`` is not positive, all digits are returned untouched.
    """
    digits = fill_u32_array(rng, data_len)
    if rem <= 0:
        return digits
    # Drop the low 32 - rem bits of the final, partially used digit.
    digits[-1] >>= 32 - rem
    return digits
# Port of num-bigint's `gen_biguint`:
# https://github.com/rust-num/num-bigint/blob/6f2b8e0fc218dbd0f49bebb8db2d1a771fe6bafa/src/bigrand.rs#L51
#
# the result type is BigUint (represented as a Python int)
def gen_biguint(rng: PrgSha3, bit_size: int) -> int:
    """Return a random integer in ``[0, 2**bit_size)``.

    The value is assembled from ``ceil(bit_size / 32)`` u32 digits. The top
    digit is trimmed to the leftover ``bit_size % 32`` bits when that is
    non-zero. ``bit_size == 0`` yields 0 without reading from ``rng``.
    """
    full_digits, partial_bits = divmod(bit_size, 32)
    # One extra digit is needed to hold the leftover bits, if any.
    digit_count = full_digits + 1 if partial_bits else full_digits
    return u32_array_as_int(gen_bits(rng, digit_count, partial_bits))
# Port of num-bigint's `gen_biguint_below`:
# https://github.com/rust-num/num-bigint/blob/6f2b8e0fc218dbd0f49bebb8db2d1a771fe6bafa/src/bigrand.rs#L111
def gen_biguint_below(rng: "PrgSha3", bound: int) -> int:
    """Return a random integer in ``[0, bound)`` by rejection sampling.

    The function repeatedly draws ``bound.bit_length()``-bit candidates with
    :func:`gen_biguint` and returns the first one below ``bound``. Since
    ``bound >= 2**(bits - 1)``, each draw is accepted with probability above 1/2.

    Raises:
        ValueError: if ``bound`` is not positive.
    """
    if bound <= 0:
        # Candidates are always >= 0, and bound == 0 gives 0-bit candidates
        # that are always 0, so the loop below would never terminate.
        raise ValueError(f'bound must be positive, got {bound}')
    bits = bound.bit_length()
    while True:
        candidate = gen_biguint(rng, bits)
        if candidate < bound:
            return candidate
# Port of num-bigint's `gen_biguint_range`:
# https://github.com/rust-num/num-bigint/blob/6f2b8e0fc218dbd0f49bebb8db2d1a771fe6bafa/src/bigrand.rs#L122
def gen_biguint_range(rng: "PrgSha3", lbound: int, ubound: int) -> int:
    """Return a random integer in the half-open range ``[lbound, ubound)``.

    An offset in ``[0, ubound - lbound)`` is sampled with
    :func:`gen_biguint_below` and added to ``lbound``.

    Raises:
        ValueError: if the range is empty (``lbound >= ubound``).
    """
    if lbound >= ubound:
        # An empty range would otherwise reach the rejection loop with a
        # non-positive bound, which can never accept a candidate.
        raise ValueError(f'empty range [{lbound}, {ubound})')
    # No special case for lbound == 0 is needed: adding 0 changes nothing,
    # and the PRG consumption is identical.
    return lbound + gen_biguint_below(rng, ubound - lbound)
####################################################################### |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment