Created
September 18, 2010 04:17
-
-
Save mikeboers/585345 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division
import hashlib
import hmac
import math
import os
import random
def str_to_int(input):
    """Decode a string of byte characters as a big-endian unsigned integer.

    Each character contributes its ``ord`` value as one base-256 digit, most
    significant first; the two-character string with ordinals (1, 0) maps to
    256. Leading NUL characters therefore do not change the result, and the
    empty string maps to 0.
    """
    value = 0
    for char in input:
        value = value * 256 + ord(char)
    return value
def int_to_str(input):
    """Encode a non-negative integer as big-endian base-256 characters.

    This is the inverse of ``str_to_int`` except that no leading NUL
    characters are ever produced. Returns ``''`` for 0 and for negative input.
    """
    digits = []
    while input > 0:
        input, low = divmod(input, 256)
        digits.append(chr(low))
    digits.reverse()
    return ''.join(digits)
def get_coefs(n, size):
    """Return a list of ``n`` random integers, each in ``[0, size]`` inclusive.

    These are the non-constant coefficients of the sharing polynomial, i.e.
    the values that hide the secret. They are therefore drawn from
    ``random.SystemRandom`` (backed by the OS randomness source) instead of
    the module-level Mersenne Twister generator, which the Python docs state
    is not suitable for security purposes.
    """
    rng = random.SystemRandom()
    # randint is inclusive at both ends, preserving the original contract.
    return [rng.randint(0, size) for _ in range(n)]
def get_point(val, coefs, n):
    """Evaluate a polynomial at ``x = n`` and return the point ``(n, y)``.

    ``val`` is the constant term; ``coefs[k]`` multiplies ``n ** (k + 1)``.
    """
    y = val
    for power, coef in enumerate(coefs, 1):
        y += coef * n ** power
    return (n, y)
def solve(points):
    """Evaluate, at x=0, the Lagrange interpolating polynomial through ``points``.

    ``points`` is a sequence of ``(x, y)`` pairs with distinct ``x`` values.
    With integer coordinates the arithmetic is exact; the final value is
    rounded down (floor division), so the result is the true interpolated
    value only when that value is an integer. Returns 0 for an empty
    sequence. Duplicate ``x`` values make a denominator zero and raise
    ``ZeroDivisionError``.
    """
    # For each basis polynomial l_j, compute y_j * l_j(0) as a fraction
    # num/den (variable names follow the Wikipedia article on Lagrange
    # polynomials) and fold it into a running sum total_num/total_den.
    # This replaces a separate O(n^2) cross-multiplication pass; the
    # resulting numerator and denominator are the same products as before.
    total_num = 0
    total_den = 1
    for j, (xj, yj) in enumerate(points):
        num = yj
        den = 1
        for f, (xf, _) in enumerate(points):
            if f == j:
                continue
            num *= -xf
            den *= xj - xf
        # a/b + c/d == (a*d + c*b) / (b*d)
        total_num = total_num * den + num * total_den
        total_den *= den
    return total_num // total_den
def split(secret, threshold, num=None, modlen=None):
    """Split the byte string ``secret`` into ``num`` polynomial shares.

    The secret, read big-endian via ``str_to_int`` (so leading NUL bytes are
    not preserved), is the constant term of a polynomial with
    ``threshold - 1`` random coefficients. Each share is that polynomial
    evaluated at x = 1..num, reduced modulo ``2 ** modlen``.

    Args:
        secret: byte string to share.
        threshold: number of coefficients plus one; must be at least 1.
        num: number of shares to produce (falsy -> ``threshold``).
        modlen: modulus bit length (falsy -> ``8 * len(secret)``).

    Returns:
        ``(points, modlen)``; ``points`` is a list of ``(x, y)`` pairs with
        ``0 <= y < 2 ** modlen``.

    Raises:
        AssertionError: on inconsistent arguments. It is raised explicitly
            instead of via ``assert`` so the checks survive ``python -O``
            while keeping the original exception type.

    NOTE(review): the arithmetic is modulo a power of two, not a prime. A
    share whose x is divisible by 2**k therefore satisfies
    ``y mod 2**k == secret mod 2**k`` (it leaks low bits of the secret), and
    reconstruction is only guaranteed for share sets whose Lagrange weights
    at 0 are integers (e.g. x = 1..threshold).
    """
    num = num or threshold
    if threshold < 1:
        raise AssertionError('threshold must be at least 1')
    if num < threshold:
        raise AssertionError('num must be >= threshold')
    modlen = modlen or 8 * len(secret)
    mod = 2 ** modlen
    value = str_to_int(secret)
    if value >= mod:
        raise AssertionError('secret does not fit in modlen bits')
    coefs = get_coefs(threshold - 1, mod)
    points = []
    for x in range(1, num + 1):
        _, y = get_point(value, coefs, x)
        points.append((x, y % mod))
    return points, modlen
def combine(points, modlen):
    """Reconstruct the byte string shared by ``split``.

    Interpolates ``points`` at x=0 with ``solve``, reduces the result modulo
    ``2 ** modlen`` and converts it back with ``int_to_str`` (which produces
    no leading NUL bytes).

    NOTE(review): ``solve`` floors a rational result and the shares live
    modulo a power of two, so reconstruction is only guaranteed for share
    sets whose Lagrange weights at 0 are integers (e.g. x = 1..threshold, or
    all shares returned by ``split``). Other subsets can silently yield wrong
    bytes.
    """
    modulus = 2 ** modlen
    value = solve(points) % modulus
    return int_to_str(value)
def pad(msg, salt_length=None, hash_length=None, hash_func=None):
    """Frame ``msg`` as ``marker + msg + salt + tag`` before splitting.

    ``marker`` is a single 0x01 byte, ``salt`` is ``salt_length`` bytes from
    ``os.urandom``, and ``tag`` is the first ``hash_length`` bytes of
    ``HMAC(key=salt, msg)`` using ``hash_func``. Defaults are
    ``hashlib.sha1``, its block size for the salt and its digest size for the
    tag (a falsy argument also selects the default). ``unpad`` reverses this.

    ``os`` must be imported at module level; it was previously only imported
    under ``__main__``, so calling ``pad`` from an importing module failed
    with NameError.
    """
    hash_func = hash_func or hashlib.sha1
    salt_length = salt_length or hash_func().block_size
    hash_length = hash_length or hash_func().digest_size
    salt = os.urandom(salt_length)
    tag = hmac.new(salt, msg, hash_func).digest()[:hash_length]
    # A non-zero leading byte keeps leading NUL bytes of msg from being lost
    # in the str_to_int / int_to_str round trip. The b'' literal is identical
    # to the old '\x01' on Python 2 and also works with bytes on Python 3.
    return b'\x01' + msg + salt + tag
def unpad(msg, salt_length=None, hash_length=None, hash_func=None):
    """Strip and verify the framing added by ``pad``; return the payload.

    Expects ``marker + payload + salt + tag``. ``marker`` is one byte and is
    dropped without being checked. ``salt`` is ``salt_length`` bytes, and
    ``tag`` is the first ``hash_length`` bytes of ``HMAC(key=salt, payload)``
    under ``hash_func``. The parameters and their defaults mirror ``pad``
    (``hashlib.sha1``, its block size, its digest size).

    Raises:
        AssertionError: if the tag does not match. It is raised explicitly
            because an ``assert`` statement would be stripped by
            ``python -O``, silently disabling the integrity check.
    """
    hash_func = hash_func or hashlib.sha1
    salt_length = salt_length or hash_func().block_size
    hash_length = hash_length or hash_func().digest_size
    body = msg[1:]  # drop the one-byte marker prepended by pad
    digest = body[-hash_length:]
    salt = body[-(hash_length + salt_length):-hash_length]
    payload = body[:-(hash_length + salt_length)]
    expected = hmac.new(salt, payload, hash_func).digest()[:hash_length]
    # Use a constant-time comparison when the hmac module provides one
    # (Python 2.7.7+ / 3.3+); otherwise fall back to plain equality.
    compare = getattr(hmac, 'compare_digest', None)
    if compare is not None:
        ok = compare(expected, digest)
    else:
        ok = expected == digest
    if not ok:
        raise AssertionError('HMAC check failed: data is corrupt or tampered')
    return payload
if __name__ == '__main__':
    # Self-test: pad a message, split it into 3 shares, recombine all of
    # them, unpad, and compare with the original. Repeated 4 times because
    # the salt and the coefficients are random on each run.
    import binascii
    import os
    import sys

    for _ in range(4):
        original = b'this is my message'
        padded = pad(original, 8)
        points, modlen = split(padded, 3)
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print(binascii.hexlify(padded).decode('ascii'))
        for x, y in points:
            print('\t%d-%x' % (x, y))
        recovered = combine(points, modlen)
        # unpad raises AssertionError on an HMAC mismatch, so a failed
        # reconstruction normally stops here before the comparison below.
        recovered = unpad(recovered, 8)
        if recovered != original:
            print(binascii.hexlify(original).decode('ascii'))
            print(len(points))
            print(binascii.hexlify(recovered).decode('ascii'))
            # Non-zero status makes the failure visible. sys.exit does not
            # depend on the site module the way the exit() builtin does.
            sys.exit(1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment