Skip to content

Instantly share code, notes, and snippets.

@ItsDrike
Created March 15, 2021 18:20
Show Gist options
  • Save ItsDrike/bde2f77efa0efedb070a0dffd0a28e24 to your computer and use it in GitHub Desktop.
Save ItsDrike/bde2f77efa0efedb070a0dffd0a28e24 to your computer and use it in GitHub Desktop.
Shamir Secret Sharing
import io
from contextlib import redirect_stdout
import random
import secrets
import decimal
from decimal import Decimal, InvalidOperation, getcontext
# The Lagrange interpolation in this module is carried out with Decimal
# arithmetic, which rounds every operation to the context precision
# (counted in significant digits). Keep that precision high enough for the
# reconstructed key to come out exact.
getcontext().prec = 500
class Shamir:
    """
    Shamir Secret Sharing (SSS) helpers.

    A secret integer is hidden as the constant term of a random polynomial;
    shares are points (x, y) on that polynomial and the secret is recovered
    by Lagrange interpolation at x = 0.

    The polynomial is evaluated over plain integers (no prime modulus is
    used) and reconstructed with Decimal arithmetic, so reconstruction is
    only exact when the decimal context precision is high enough.
    """

    @classmethod
    def make_shares(
        cls,
        secret_key: int,
        share_amt: int,
        threshold_amt: int,
        field_bits: int = 64
    ) -> list[tuple[int, int]]:
        """
        This generates `share_amt` of keys, which individually can't
        be meaningfully used to reconstruct secret_key, and aren't directly
        connected with it, but when `threshold_amt` of them are combined together,
        the true `secret_key` can be reconstructed.

        This works using Shamir Secret Sharing (SSS) algorithm, which generates
        a secret polynomial with a constant value of our `secret_key`, we can then
        recover this key from `threshold_amt` or more points belonging to this
        polynomial.

        `field_bits` is the amount of bits used for random generation of the
        coefficients and of the `x` coordinates, 64bit values are usually enough
        here. Note that if this value is too small, we won't get far from the
        original value on the graph, and each share will contain the `y` value
        that may be close to the original key value, for that reason, we
        shouldn't make this value too small.

        A list of points (x, y) belonging to our polynomial is returned, which
        represent the individual shares. All `x` values are distinct and non-zero.

        Raises ValueError if `threshold_amt` is smaller than 1, if fewer shares
        than the threshold are requested, or if `field_bits` doesn't provide
        enough distinct non-zero `x` values for `share_amt` shares.
        """
        # With no random coefficient the polynomial is constant and every
        # share's y would be the secret itself.
        if threshold_amt < 1:
            raise ValueError("threshold_amt must be at least 1")
        # Fewer shares than the threshold could never be recombined.
        if share_amt < threshold_amt:
            raise ValueError(
                f"Can't require {threshold_amt} shares for reconstruction "
                f"when only {share_amt} shares are generated"
            )
        # We can't generate n shares with less than n possibilities for points.
        # x = 0 is not a possibility: the polynomial's value there is the secret.
        if 2**field_bits - 1 < share_amt:
            raise ValueError(
                f"Attempted to generate {share_amt} of shares, \n"
                f"with only {2**field_bits - 1} ({field_bits=}) possibilities for points"
            )

        coefficients = cls._make_coefficients(
            degree=threshold_amt,
            constant=secret_key,
            field_bits=field_bits
        )

        # Generate the shares
        # share is a combination of (x, y) points on graph of our polynomial
        shares = []
        # Seeding the set with 0 makes the rejection loop below skip it too,
        # so the share (0, secret_key) can never be handed out.
        used = {0}
        for _ in range(share_amt):
            # Always use different points, we don't want same shares
            x = 0
            while x in used:
                x = secrets.randbits(field_bits)
            used.add(x)
            shares.append((x, cls._evaluate_polynomial(x, coefficients)))
        return shares

    @classmethod
    def reconstruct_key(cls, shares: list[tuple[int, int]]) -> int:
        """
        Reconstruct the secret key from individual `shares`.

        This works by interpolating the polynomial that contains the points
        defined in the `shares` list (lagrange interpolation), with each
        share being a single (x, y) point on the graph of our polynomial,
        and evaluating it at x = 0, where its value equals our key.

        Raises ValueError if `shares` is empty or two shares have the same `x`,
        and RuntimeError if the decimal context precision is too low to round
        the interpolated value to an integer.
        """
        if not shares:
            raise ValueError("At least one share is required to reconstruct the key.")
        try:
            key_decimal = cls._interpolate(shares, 0)
        except ValueError as exc:
            raise ValueError("All shares must be unique, duplicate share found.") from exc
        try:
            # Rounding fails with InvalidOperation when the integer part has
            # more digits than the context precision allows.
            return int(round(key_decimal, 0))
        except InvalidOperation as exc:
            raise RuntimeError(
                "Precision for decimal objects isn't high enough to compute exact interpolation"
            ) from exc

    @staticmethod
    def _interpolate(points: list[tuple[int, int]], x: int) -> Decimal:
        """
        Given the list of `points` belonging to certain polynomial function,
        evaluate this function at given `x` and return the corresponding y value.

        Raises ValueError if two points share the same x coordinate.
        """
        xs = [point_x for point_x, _ in points]
        # Check duplicates up front instead of relying on a division error:
        # a 0 / 0 division raises InvalidOperation, not DivisionByZero.
        if len(set(xs)) != len(xs):
            raise ValueError("All passed points must be unique")

        result = Decimal(0)
        for i, (xi, yi) in enumerate(points):
            # Compute individual terms of the lagrange's interpolation formula
            term = Decimal(yi)
            for j, xj in enumerate(xs):
                if i == j:
                    continue
                # Both differences are computed on the raw values, only the
                # division and the running product are rounded to the
                # decimal context precision.
                term *= (x - xj) / Decimal(xi - xj)
            result += term
        return result

    @staticmethod
    def _evaluate_polynomial(x: int, coefficients: list[int]) -> int:
        """
        This generates a single point on the graph of given polynomial
        in `x`. The polynomial is given by the list of `coefficients`,
        ordered from the highest power down to the constant term.
        """
        # Horner's rule: ((c0 * x + c1) * x + c2) * x + ...
        point = 0
        for coefficient in coefficients:
            point = point * x + coefficient
        return point

    @staticmethod
    def _make_coefficients(degree: int, constant: int, field_bits: int) -> list[int]:
        """
        Generate a list of coefficients for a polynomial with degree of `degree` - 1,
        whose constant is `constant`. Each of the other coefficients is a random
        non-negative integer of at most `field_bits` bits.

        The list holds `degree` values, ordered from the highest power down to
        the constant, which comes last.

        For example with a 3rd degree polynomial like this:
        3x^3 + 4x^2 + 18x + 554
        554 is the secret, and the polynomial degree + 1 is how many points
        are needed, to recover this secret. With 3rd degree polynomial,
        minimum of any 4 points, belonging to this polynomial are needed
        to find all of the coefficients and extract the secret (constant value).
        """
        coefficients = [secrets.randbits(field_bits) for _ in range(degree - 1)]
        coefficients.append(constant)
        return coefficients
def test(key_bits: int, share_amt: int, threshold_amt: int, field_bits: int = 64):
    """
    Run one round trip: split a random `key_bits`-bit key into `share_amt`
    shares, rebuild it from `threshold_amt` randomly picked shares and check
    the result.

    Prints a success message when the rebuilt key matches, raises
    RuntimeError describing the mismatch otherwise.
    """
    secret_key = secrets.randbits(key_bits)
    all_shares = Shamir.make_shares(secret_key, share_amt, threshold_amt, field_bits=field_bits)

    # Only the minimum amount of shares, chosen at random, is used
    pool = random.sample(all_shares, threshold_amt)
    reconstructed_key = Shamir.reconstruct_key(pool)

    if secret_key != reconstructed_key:
        share_lines = [" " * 4 + str(share) for share in pool]
        formatted_shares = "[\n" + ",\n".join(share_lines) + "\n]"
        raise RuntimeError(
            "Fail, algorithm wasn't able to properly reconstruct keys\n"
            f"Secret key: {secret_key}\n"
            f"Reconstructed key: {reconstructed_key}\n"
            f"Shares: {formatted_shares}\n\n"
            f"This most likely happened due to low floating point decimal precision ({getcontext().prec})"
        )

    print(f"Success, key was reconstructed from {threshold_amt}/{share_amt} shares")
def test_precision(
    key_bits: int,
    share_amt: int,
    threshold_amt: int,
    field_bits: int = 64,
    stop_after: int = 10_000
):
    """
    Try to determine the optimal decimal precision value for
    Shamir Secret Sharing.

    `key_bits`, `share_amt` and `threshold_amt` variables are the same
    as parameters for the `test` function.

    `stop_after` parameter guides how many consecutive iterations without
    failure should we perform, the higher this number is, the more confident
    we can be about our minimum precision guess, but it also increases the time.

    The precision of the current decimal context is changed while searching
    and is always put back to its previous value before returning, even when
    an unexpected exception interrupts the search. The found precision is
    printed to STDOUT.
    """
    old_precision = getcontext().prec
    # Start from 1 and make our way up
    p = 1
    # Keep track of how many consecutive iterations without failure occurred
    consecutive_successes = 0
    try:
        while True:
            getcontext().prec = p
            try:
                # We don't want to clutter STDOUT, send output from `test`
                # function into a throwaway buffer instead
                with redirect_stdout(io.StringIO()):
                    test(key_bits, share_amt, threshold_amt, field_bits)
            except RuntimeError:
                # Reconstruction failed at this precision, try a higher one
                p += 1
                consecutive_successes = 0
            else:
                consecutive_successes += 1
                if consecutive_successes >= stop_after:
                    break
    finally:
        getcontext().prec = old_precision
    print(f"For given parameters, decimal precision should be set to at least {p}.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment