Last active
May 11, 2018 11:25
-
-
Save th0rex/e5c1010a6e203d5afd17f14f23c40b28 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def eea(r0, r1):
    """Extended Euclidean algorithm.

    The inputs are first ordered so that r0 >= r1. Returns ``(s, t, g)`` where
    ``g = gcd(r0, r1)`` and ``s*r0 + t*r1 == g``, with ``s`` reduced mod r1 and
    ``t`` reduced mod r0 (r0/r1 after ordering). When ``g == 1``, ``t`` is
    therefore the inverse of the smaller input modulo the larger one.

    Raises ZeroDivisionError if the smaller input is 0.
    """
    if r0 <= r1:
        r0, r1 = r1, r0
    old_r, r = r0, r1
    old_s, s = 1, 0
    old_t, t = 0, 1
    while r != 0:
        # Integer quotient: true division would push the coefficients through
        # float and lose precision for large moduli.
        q, rem = divmod(old_r, r)
        old_r, r = r, rem
        old_s, s = s, old_s - q * s
        old_t, t = t, old_t - q * t
    # old_r is the last non-zero remainder (the gcd); old_s/old_t its coefficients.
    return old_s % r1, old_t % r0, old_r
class ECInf:
    """The point at infinity O, the identity element of the curve group."""

    def __add__(self, other):
        # O + Q = Q for any Q.
        return other

    def __eq__(self, other):
        return type(other) is ECInf

    def __repr__(self):
        return "(inf)"

    def inverse(self):
        """Return -O, which is O itself (mirrors ECPoint.inverse)."""
        return ECInf()
class ECPoint:
    """An affine point (x, y) on the curve ``ec``.

    ``ec`` must expose the modulus ``p`` and coefficient ``a`` (as ``EC``
    does). Coordinates are presumably already reduced mod p; the inverse
    check in ``__add__`` compares against ``-y % p``.
    """

    def __init__(self, ec, x, y):
        self.ec = ec
        self.x = x
        self.y = y

    def __add__(self, other):
        """Return self + other using the affine chord-and-tangent rule."""
        if type(other) is ECInf:
            # The point at infinity is the group identity.
            return self
        if self.y == -other.y % self.ec.p and self.x == other.x:
            # P + (-P) = O; this also covers doubling a point whose y is 0.
            return ECInf()
        # int() keeps the coordinates integral whatever numeric type get_s produces.
        s = int(self.get_s(other))
        x = int((s ** 2 - self.x - other.x) % self.ec.p)
        y = int((s * (self.x - x) - self.y) % self.ec.p)
        return ECPoint(self.ec, x, y)

    def __eq__(self, other):
        """Points are equal when they share a curve and both coordinates."""
        if not isinstance(other, ECPoint):
            # Let Python try the reflected comparison, then fall back to
            # identity; this avoids AttributeError on e.g. None.
            return NotImplemented
        return self.ec == other.ec and self.x == other.x and self.y == other.y

    def __ne__(self, other):
        return not self == other

    def __repr__(self):
        return "({}, {})".format(self.x, self.y)

    def get_s(self, other):
        """Return the slope used by __add__: tangent if doubling, chord otherwise."""
        if self == other:
            # Tangent slope (3x^2 + a) / (2y) mod p.
            _, inv, _ = eea((2 * self.y) % self.ec.p, self.ec.p)
            return (3 * self.x ** 2 + self.ec.a) * inv % self.ec.p
        else:
            # Chord slope (y2 - y1) / (x2 - x1) mod p.
            _, inv, _ = eea((other.x - self.x) % self.ec.p, self.ec.p)
            return (other.y - self.y) * inv % self.ec.p

    def inverse(self):
        """Return -P = (x, -y mod p)."""
        return ECPoint(self.ec, self.x, -self.y % self.ec.p)
class EC:
    """Curve y^2 = x^3 + a*x + b over the integers mod p."""

    def __init__(self, a, b, p):
        # Reject singular curves: 4a^3 + 27b^2 must be non-zero mod p.
        if (4 * pow(a, 3, p) + 27 * pow(b, 2, p)) % p == 0:
            raise ValueError("4*a^3 + 27*b^2 = 0")
        self.a = a
        self.b = b
        self.p = p

    def __repr__(self):
        return "y^2 = x^3 + {} * x + {} mod {}".format(self.a, self.b, self.p)

    def __eq__(self, other):
        if not isinstance(other, EC):
            # Avoid AttributeError on unrelated operands such as None.
            return NotImplemented
        return self.a == other.a and self.b == other.b and self.p == other.p

    def __ne__(self, other):
        return not self == other

    def on_curve(self, p):
        """Return True if ``p`` is the point at infinity or satisfies the curve equation."""
        # isinstance, not ``p is ECInf``: the latter compares an instance with
        # the class object and is always False.
        if isinstance(p, ECInf):
            return True
        return (p.y ** 2) % self.p == (p.x ** 3 + self.a * p.x + self.b) % self.p
def naf(e):
    """Return the non-adjacent form (NAF) of ``e``, most significant digit first.

    Digits are in {-1, 0, 1}, and no two adjacent digits are both non-zero.
    Summing ``d * 2**k`` over the reversed list gives back ``e``.
    Returns [] when e < 1.
    """
    x, digits = e, []
    while x >= 1:
        if x % 2 == 1:
            # x = 1 (mod 4) -> 1, x = 3 (mod 4) -> -1, so x - digit is divisible by 4.
            digit = 2 - int(x % 4)
            x -= digit
        else:
            digit = 0
        digits.append(digit)
        # Floor division keeps x an exact int; true division would round large
        # values through float and corrupt the expansion.
        x //= 2
    return digits[::-1]
def naf_ec(e, p):
    """Compute e*p with left-to-right NAF double-and-add.

    ``p`` must support ``+`` and, when a -1 digit occurs, ``inverse()``.
    Prints the number of group operations performed: one per doubling and one
    per addition or subtraction, the same counting d_add uses.

    Raises ValueError if e < 1 (its NAF is empty).
    """
    digits = naf(e)
    if not digits:
        raise ValueError("e must be a positive integer")
    if digits[0] != 1:
        raise ValueError("wtf?")
    # The leading digit 1 is represented by starting the accumulator at p.
    q, steps = p, 0
    for digit in digits[1:]:
        q = q + q
        steps += 1
        if digit == 1:
            q = q + p
            steps += 1
        elif digit == -1:
            q = q + p.inverse()
            steps += 1
    print("Took {} steps".format(steps))
    return q
def check_naf(e):
    """Return the integer encoded by the digits of ``naf(e)``.

    For a correct ``naf``, ``check_naf(e) == e`` for every e >= 0. An empty
    NAF (e < 1) yields 0 instead of the TypeError ``reduce`` would raise.
    """
    return sum(digit * 2 ** k for k, digit in enumerate(reversed(naf(e))))
def d_add(p, e):
    """Compute e*p with left-to-right binary double-and-add.

    ``p`` only needs to support ``+``. Works for any positive integer ``e``,
    not just those below 2**33. Prints the number of group operations
    performed: one per doubling and one per addition.

    Raises ValueError if e < 1.
    """
    if e < 1:
        raise ValueError("e must be a positive integer")
    acc, steps = p, 0
    # acc already accounts for the most significant set bit; walk the
    # remaining bits from high to low.
    for bit in range(e.bit_length() - 2, -1, -1):
        acc = acc + acc
        steps += 1
        if (e >> bit) & 1:
            acc = acc + p
            steps += 1
    print("Took {} steps".format(steps))
    return acc
if __name__ == "__main__":
    # split() with no argument tolerates repeated spaces and tabs between numbers.
    a, b, p = map(int, input("a b p: ").split())
    x1, y1 = map(int, input("x1 y1: ").split())
    x2, y2 = map(int, input("x2 y2: ").split())
    c = EC(a, b, p)
    r = ECPoint(c, x1, y1) + ECPoint(c, x2, y2)
    print("Result: ", r)
    print("Point on curve: ", c.on_curve(r))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment