Created
March 15, 2021 18:20
-
-
Save ItsDrike/bde2f77efa0efedb070a0dffd0a28e24 to your computer and use it in GitHub Desktop.
Shamir Secret Sharing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import io | |
from contextlib import redirect_stdout | |
import random | |
import secrets | |
import decimal | |
from decimal import Decimal, InvalidOperation, getcontext | |
# Lagrange interpolation below runs in Decimal arithmetic; 500 significant
# digits keeps the large intermediate products from being rounded away.
decimal.getcontext().prec = 500
class Shamir:
    """
    Shamir Secret Sharing (SSS) carried out over the integers.

    The secret is the constant term of a random polynomial and every share
    is one (x, y) point on that polynomial. Any `threshold_amt` shares with
    distinct x determine the polynomial, and therefore the secret, via
    Lagrange interpolation computed with `Decimal` at the current decimal
    context precision.

    NOTE(review): no modulus is applied, so y = secret + x * (...), meaning
    every share satisfies y ≡ secret (mod x) and leaks `secret mod x`.
    Textbook SSS works in a prime field to avoid this; treat this class as
    an educational implementation.
    """

    @classmethod
    def make_shares(
        cls,
        secret_key: int,
        share_amt: int,
        threshold_amt: int,
        field_bits: int = 64
    ) -> list[tuple[int, int]]:
        """
        Split `secret_key` into `share_amt` shares, any `threshold_amt` of
        which are enough to reconstruct it with `reconstruct_key`.

        A random polynomial of degree `threshold_amt - 1` is generated whose
        constant term is `secret_key`; each share is an (x, y) point on its
        graph, with all x coordinates distinct.

        `field_bits` is the bit length used to draw both the random
        coefficients and the x coordinates. If it is too small, the y values
        stay close to the secret itself, so it shouldn't be made too small.

        The x coordinate 0 is never used: the point at x = 0 is
        (0, secret_key), so handing it out would reveal the secret directly.

        Returns:
            A list of `share_amt` (x, y) tuples.

        Raises:
            ValueError: if `threshold_amt` is not between 1 and `share_amt`,
                or if `field_bits` leaves fewer than `share_amt` distinct
                non-zero x coordinates to choose from.
        """
        if not 1 <= threshold_amt <= share_amt:
            raise ValueError(
                f"threshold_amt must be between 1 and share_amt ({share_amt}), "
                f"got {threshold_amt}"
            )

        # x is drawn from [1, 2**field_bits), i.e. 2**field_bits - 1 candidates;
        # asking for more distinct points would make the loop below spin forever
        available_points = 2**field_bits - 1
        if available_points < share_amt:
            raise ValueError(
                f"Attempted to generate {share_amt} of shares, \n"
                f"with only {available_points} ({field_bits=}) possibilities for points"
            )

        coefficients = cls._make_coefficients(
            degree=threshold_amt,
            constant=secret_key,
            field_bits=field_bits
        )

        shares = []
        # 0 is pre-seeded so it is never picked: f(0) is the secret itself
        used = {0}
        for _ in range(share_amt):
            # Every share needs its own x, otherwise two shares carry the same point
            x = secrets.randbits(field_bits)
            while x in used:
                x = secrets.randbits(field_bits)
            used.add(x)
            shares.append((x, cls._evaluate_polynomial(x, coefficients)))
        return shares

    @classmethod
    def reconstruct_key(cls, shares: list[tuple[int, int]]) -> int:
        """
        Reconstruct the secret key from `shares`.

        Each share is an (x, y) point on the secret polynomial. The
        polynomial through them is evaluated at x = 0 with Lagrange
        interpolation, and the result is rounded to the nearest integer.
        With fewer shares than the threshold used in `make_shares`, the
        returned value is simply wrong; this cannot be detected here.

        Raises:
            ValueError: if `shares` is empty or two shares share an x value.
            RuntimeError: if the decimal context precision is too low to
                round the interpolated value to an integer.
        """
        if not shares:
            raise ValueError("At least one share is required to reconstruct the key.")
        try:
            key_decimal = cls._interpolate(shares, 0)
        except ValueError as exc:
            raise ValueError("All shares must be unique, duplicate share found.") from exc
        try:
            return int(round(key_decimal, 0))
        except InvalidOperation as exc:
            # round() quantizes, which signals InvalidOperation when the integer
            # part needs more digits than the context precision allows
            raise RuntimeError(
                "Precision for decimal objects isn't high enough to compute exact interpolation"
            ) from exc

    @staticmethod
    def _interpolate(points: list[tuple[int, int]], x: int) -> Decimal:
        """
        Evaluate, at `x`, the lowest-degree polynomial passing through `points`.

        Uses the Lagrange form
            L(x) = sum_i y_i * prod_{j != i} (x - x_j) / (x_i - x_j)
        in `Decimal` arithmetic, so the result is only as exact as the
        current decimal context precision allows.

        Raises:
            ValueError: if two points have the same x coordinate.
        """
        result = Decimal(0)
        for i, (xi, yi) in enumerate(points):
            term = Decimal(yi)
            for j, (xj, _yj) in enumerate(points):
                if i == j:
                    continue
                if xi == xj:
                    raise ValueError("All passed points must be unique")
                # Subtract as exact integers so that only the division rounds
                term *= Decimal(x - xj) / Decimal(xi - xj)
            result += term
        return result

    @staticmethod
    def _evaluate_polynomial(x: int, coefficients: list[int]) -> int:
        """
        Return the value of the polynomial described by `coefficients` at `x`.

        `coefficients` run from the highest power down to the constant term
        (the layout `_make_coefficients` produces). Horner's rule is used;
        all arithmetic is exact integer arithmetic.
        """
        point = 0
        for coefficient in coefficients:
            point = point * x + coefficient
        return point

    @staticmethod
    def _make_coefficients(degree: int, constant: int, field_bits: int) -> list[int]:
        """
        Generate the coefficients of a random polynomial with `degree` terms,
        i.e. of degree `degree - 1`, whose constant term is `constant`.

        Coefficients are ordered from the highest power down, so `constant`
        is the last element. Each random coefficient comes from
        `secrets.randbits(field_bits)`, so it lies in [0, 2**field_bits).

        For example, for the polynomial
            3x^3 + 4x^2 + 18x + 554
        (degree=4) the list is [3, 4, 18, 554]. 554 is the secret, and any 4
        distinct points on the graph are needed to recover it.
        """
        coefficients = [secrets.randbits(field_bits) for _ in range(degree - 1)]
        coefficients.append(constant)
        return coefficients
def test(key_bits: int, share_amt: int, threshold_amt: int, field_bits: int = 64):
    """
    Round-trip check of `Shamir`: split a random `key_bits`-bit secret into
    `share_amt` shares, then rebuild it from a random pick of exactly
    `threshold_amt` of them.

    Prints a success line when the rebuilt key matches. Otherwise raises
    RuntimeError showing the secret, the rebuilt value and the shares used.
    """
    secret_key = secrets.randbits(key_bits)
    all_shares = Shamir.make_shares(secret_key, share_amt, threshold_amt, field_bits=field_bits)

    # Only the minimum number of shares is handed to the reconstruction
    chosen = random.sample(all_shares, threshold_amt)
    reconstructed_key = Shamir.reconstruct_key(chosen)

    if secret_key != reconstructed_key:
        indented = [f"    {share}" for share in chosen]
        formatted_shares = "[\n" + ",\n".join(indented) + "\n]"
        raise RuntimeError(
            "Fail, algorithm wasn't able to properly reconstruct keys\n"
            f"Secret key: {secret_key}\n"
            f"Reconstructed key: {reconstructed_key}\n"
            f"Shares: {formatted_shares}\n\n"
            f"This most likely happened due to low floating point decimal precision ({getcontext().prec})"
        )

    print(f"Success, key was reconstructed from {threshold_amt}/{share_amt} shares")
def test_precision(
    key_bits: int,
    share_amt: int,
    threshold_amt: int,
    field_bits: int = 64,
    stop_after: int = 10_000
) -> int:
    """
    Estimate the minimum decimal precision needed for Shamir Secret Sharing.

    Starting from a precision of 1, the `test` round trip is run repeatedly.
    Each failure (RuntimeError) bumps the precision by one and resets the
    success streak. Once `stop_after` consecutive runs succeed at the same
    precision, that precision is printed and returned.

    `key_bits`, `share_amt`, `threshold_amt` and `field_bits` are passed
    straight through to `test`.

    `stop_after` is how many consecutive failure-free iterations are required
    before settling on a value. The higher it is, the more confident we can
    be about the minimum precision guess, but the longer it takes.

    The decimal context precision that was active before the call is always
    restored, even if `test` raises an exception other than RuntimeError.

    Returns:
        The estimated minimum precision.
    """
    old_precision = getcontext().prec
    # Start from 1 and make our way up
    p = 1
    # Number of consecutive successful iterations at the current precision
    consecutive_successes = 0
    try:
        while consecutive_successes < stop_after:
            getcontext().prec = p
            try:
                # `test` prints on every success; send that to a throwaway
                # buffer per iteration so STDOUT isn't cluttered and output
                # doesn't accumulate in memory
                with redirect_stdout(io.StringIO()):
                    test(key_bits, share_amt, threshold_amt, field_bits)
            except RuntimeError:
                p += 1
                consecutive_successes = 0
            else:
                consecutive_successes += 1
    finally:
        # Never leave the global decimal context at an artificially low precision
        getcontext().prec = old_precision
    print(f"For given parameters, decimal precision should be set to at least {p}.")
    return p
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment