Last active
August 16, 2024 14:27
-
-
Save vkobel/10b9ad23b1bb2c6c73e3aa6934f0858b to your computer and use it in GitHub Desktop.
Toy ECDSA implementation (no dependencies in the core) — for educational purposes only; do NOT use in production.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import hashlib | |
import secrets | |
# Parameters for the secp256k1 curve (SEC 2): y^2 = x^3 + a*x + b over GF(p)
p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F  # field prime, 2**256 - 2**32 - 977
a = 0
b = 7
Gx = 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798  # base point G, x coordinate
Gy = 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8  # base point G, y coordinate
n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141  # order of G


class Point:
    """Affine point (x, y) on the curve defined by the module constants a, b, p.

    The point at infinity (the group identity) is represented by ``None``,
    not by a Point instance; ``add`` and ``scalar_multiply`` accept and
    return ``None`` for it.
    """

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return f"Point({self.x}, {self.y})"

    def __eq__(self, other):
        """Points are equal when their coordinates are equal (not by identity)."""
        if not isinstance(other, Point):
            return NotImplemented
        return self.x == other.x and self.y == other.y

    def __hash__(self):
        # Must agree with __eq__: equal coordinates -> equal hashes.
        return hash((self.x, self.y))

    def is_on_curve(self):
        """Return True if (x, y) satisfies y^2 = x^3 + a*x + b (mod p)."""
        return (self.y ** 2 - self.x ** 3 - a * self.x - b) % p == 0

    @staticmethod
    def add(P, Q):
        """Return P + Q using the affine chord-and-tangent group law.

        Either argument may be ``None`` (point at infinity). Returns ``None``
        when the sum is the point at infinity.
        """
        if P is None:
            return Q
        if Q is None:
            return P
        if P.x == Q.x:
            # Same x: either Q == -P (vertical chord) or Q == P (tangent).
            # Coordinates are compared rather than object identity, so two
            # distinct Point objects with equal coordinates are doubled
            # correctly instead of hitting a division by zero in the chord
            # formula. A tangent at y == 0 is vertical as well.
            if P.y != Q.y or P.y % p == 0:
                return None  # Point at infinity
            # Point doubling; pow(v, p - 2, p) is v^-1 mod p (Fermat, p prime)
            s = (3 * P.x ** 2 + a) * pow(2 * P.y, p - 2, p) % p
        else:
            # General case of addition
            s = (Q.y - P.y) * pow(Q.x - P.x, p - 2, p) % p
        x = (s ** 2 - P.x - Q.x) % p
        y = (s * (P.x - x) - P.y) % p
        return Point(x, y)

    @staticmethod
    def scalar_multiply(k, P):
        """Return k * P by binary double-and-add, least-significant bit first.

        Returns ``None`` (point at infinity) when k <= 0 or the result is the
        identity. Not constant-time: the loop branches on the bits of k.
        """
        Q = None  # Point at infinity (running sum)
        while k > 0:
            if k & 1:
                Q = Point.add(Q, P)
            P = Point.add(P, P)
            k >>= 1
        return Q
class ECDSA:
    """Textbook ECDSA over secp256k1 with SHA-256 (educational use only).

    Uses the module-level curve order ``n``, field prime ``p``, base point
    ``G`` and the ``Point`` group operations defined elsewhere in this file.
    """

    @staticmethod
    def hash_message(message):
        """Hash *message* with SHA-256 and return the digest as a big-endian integer.

        ``str`` input is encoded as UTF-8 first; bytes-like input (``bytes``,
        ``bytearray``, ``memoryview``) is hashed as-is.
        """
        if isinstance(message, str):
            message = message.encode('utf-8')
        hash_bytes = hashlib.sha256(message).digest()
        # SHA-256 yields 256 bits, the same bit length as n, so the digest is
        # used whole (no leftmost-bits truncation needed).
        return int.from_bytes(hash_bytes, byteorder='big')

    @staticmethod
    def sign_message(private_key, message):
        """Sign *message* with *private_key* (an int in [1, n-1]).

        Returns (r, s) with both values in [1, n-1]. A fresh nonce k is drawn
        from ``secrets`` for every attempt; attempts yielding r == 0 or
        s == 0 are retried.

        Raises:
            ValueError: if private_key is outside [1, n-1].
        """
        if not 1 <= private_key < n:
            raise ValueError("private key must be in the range [1, n-1]")
        z = ECDSA.hash_message(message)
        r, s = 0, 0
        while r == 0 or s == 0:
            k = secrets.randbelow(n - 1) + 1  # uniform nonce in [1, n-1]
            R = Point.scalar_multiply(k, G)
            r = R.x % n
            # pow(k, n - 2, n) is k^-1 mod n (Fermat; the secp256k1 order n is prime)
            s = ((z + r * private_key) * pow(k, n - 2, n)) % n
        return (r, s)

    @staticmethod
    def verify_signature(public_key, message, signature):
        """Return True iff *signature* = (r, s) is valid for *message* under *public_key*.

        The public key must be a finite point (not ``None``) whose coordinates
        lie in [0, p) and which satisfies the curve equation. The point at
        infinity would let anyone forge signatures. The addition formulas never
        use b, so an off-curve key would be processed on a different curve.
        """
        r, s = signature
        if not (1 <= r < n and 1 <= s < n):
            return False  # Invalid signature range
        if public_key is None:
            return False  # Point at infinity is never a valid public key
        if not (0 <= public_key.x < p and 0 <= public_key.y < p):
            return False  # Coordinates must be reduced field elements
        if not public_key.is_on_curve():
            return False  # Reject invalid-curve keys
        z = ECDSA.hash_message(message)
        w = pow(s, n - 2, n)  # s^-1 mod n
        u1 = (z * w) % n
        u2 = (r * w) % n
        P = Point.add(Point.scalar_multiply(u1, G), Point.scalar_multiply(u2, public_key))
        if P is None or not P.is_on_curve():  # Ensure P is a finite point on the curve
            return False
        return (P.x % n) == r
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This part uses the third-party `cryptography` library, only so keys and signatures can be exported in standard formats (PEM, DER/Base64).
from cryptography.hazmat.primitives.asymmetric import ec | |
from cryptography.hazmat.primitives import serialization | |
from cryptography.hazmat.primitives.asymmetric.utils import encode_dss_signature, decode_dss_signature | |
import base64 | |
class Encoding:
    """Serialize keys and signatures to standard formats via the `cryptography` package."""

    @staticmethod
    def private_key_to_pem(private_key):
        """Encode the integer private key as an unencrypted PKCS#8 PEM string."""
        # Convert the integer private key to an EC private key object
        ec_private_key = ec.derive_private_key(private_key, ec.SECP256K1())
        # Serialize the private key to PEM format using PKCS#8
        pem = ec_private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption()
        )
        return pem.decode('utf-8')

    @staticmethod
    def public_key_to_pem(public_key):
        """Encode a public-key Point as a SubjectPublicKeyInfo PEM string."""
        # Convert the public key point to an EC public key object
        ec_public_key = ec.EllipticCurvePublicNumbers(public_key.x, public_key.y, ec.SECP256K1()).public_key()
        # Serialize the public key to PEM format
        pem = ec_public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        )
        return pem.decode('utf-8')

    @staticmethod
    def signature_to_base64(signature):
        """Encode an (r, s) ECDSA signature as DER, then as a Base64 string."""
        r, s = signature
        der_signature = encode_dss_signature(r, s)
        return base64.b64encode(der_signature).decode('utf-8')

    @staticmethod
    def signature_from_base64(base64_signature):
        """Decode a Base64 DER signature (as produced by signature_to_base64) into (r, s).

        Decoding is strict. Surrounding whitespace is ignored, but any other
        non-Base64 character raises ``binascii.Error`` (a ``ValueError``
        subclass) instead of being silently discarded.
        """
        der_signature = base64.b64decode(base64_signature.strip(), validate=True)
        r, s = decode_dss_signature(der_signature)
        return (r, s)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Initialize the base point G (read as a module-level global by ECDSA)
G = Point(Gx, Gy)
# Explicit check rather than `assert`, which is removed when Python runs with -O.
if not G.is_on_curve():
    raise ValueError("Base point G is not on the curve")

# Example usage
private_key = secrets.randbelow(n - 1) + 1  # uniform in [1, n-1]
public_key = Point.scalar_multiply(private_key, G)
message = "Hello, ECDSA!"
signature = ECDSA.sign_message(private_key, message)
is_valid = ECDSA.verify_signature(public_key, message, signature)
print(f"Signature valid: {is_valid}")

# Encode keys and signature
private_key_pem = Encoding.private_key_to_pem(private_key)
public_key_pem = Encoding.public_key_to_pem(public_key)
signature_base64 = Encoding.signature_to_base64(signature)
print("Private Key in PKCS#8 PEM format:")
print(private_key_pem)
print("Public Key in PEM format:")
print(public_key_pem)
print("Signature in Base64 format:")
print(signature_base64)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
import secrets | |
import time | |
class TestECDSA(unittest.TestCase):
    """Unit tests for the toy secp256k1 Point arithmetic, ECDSA and signature encoding."""

    def setUp(self):
        # Initialize curve parameters and keys for testing
        self.p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
        self.a = 0
        self.b = 7
        self.Gx = 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798
        self.Gy = 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8
        self.n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
        self.G = Point(self.Gx, self.Gy)
        self.private_key = secrets.randbelow(self.n - 1) + 1
        self.public_key = Point.scalar_multiply(self.private_key, self.G)

    def test_point_on_curve(self):
        # Test if the base point is on the curve
        self.assertTrue(self.G.is_on_curve())
        # Test if a random point is not on the curve
        invalid_point = Point(1, 1)
        self.assertFalse(invalid_point.is_on_curve())

    def test_point_addition(self):
        # Test point addition with point at infinity
        P = Point(2, 3)
        self.assertEqual(Point.add(P, None), P)
        self.assertEqual(Point.add(None, P), P)
        # Test point addition resulting in point at infinity
        Q = Point(P.x, -P.y % self.p)
        self.assertIsNone(Point.add(P, Q))
        # Test point doubling
        result = Point.add(P, P)
        s = (3 * P.x**2 + self.a) * pow(2 * P.y, self.p-2, self.p) % self.p
        expected_x = (s**2 - 2 * P.x) % self.p
        expected_y = (s * (P.x - expected_x) - P.y) % self.p
        self.assertEqual(result.x, expected_x)
        self.assertEqual(result.y, expected_y)

    def test_scalar_multiplication(self):
        # Test scalar multiplication with zero
        self.assertIsNone(Point.scalar_multiply(0, self.G))
        # Test scalar multiplication with one
        self.assertEqual(Point.scalar_multiply(1, self.G), self.G)
        # Test scalar multiplication with the order of the group
        self.assertIsNone(Point.scalar_multiply(self.n, self.G))

    def test_sign_message(self):
        message = "Hello"
        signature = ECDSA.sign_message(self.private_key, message)
        self.assertEqual(len(signature), 2)
        r, s = signature
        self.assertTrue(1 <= r < self.n)
        self.assertTrue(1 <= s < self.n)

    def test_verify_signature(self):
        message = "Hello"
        signature = ECDSA.sign_message(self.private_key, message)
        self.assertTrue(ECDSA.verify_signature(self.public_key, message, signature))

    def test_invalid_signature(self):
        message = "Hello"
        invalid_signature = (0, 0)
        self.assertFalse(ECDSA.verify_signature(self.public_key, message, invalid_signature))

    def test_invalid_public_key(self):
        invalid_public_key = Point(0, 0)  # Invalid public key
        message = "Hello"
        signature = ECDSA.sign_message(self.private_key, message)
        self.assertFalse(ECDSA.verify_signature(invalid_public_key, message, signature))

    def test_boundary_conditions(self):
        # Test signing and verifying with boundary values of r and s
        message = "Boundary test"
        signature = (1, 1)
        self.assertFalse(ECDSA.verify_signature(self.public_key, message, signature))
        signature = (self.n - 1, self.n - 1)
        self.assertFalse(ECDSA.verify_signature(self.public_key, message, signature))

    def test_boundary_private_keys(self):
        min_key = 1
        max_key = self.n - 1
        min_pub_key = Point.scalar_multiply(min_key, self.G)
        max_pub_key = Point.scalar_multiply(max_key, self.G)
        self.assertTrue(min_pub_key.is_on_curve())
        self.assertTrue(max_pub_key.is_on_curve())
        # Known answer: (n - 1) * G == -G, i.e. same x and negated y.
        # Compared as tuples so the check does not depend on Point equality semantics.
        self.assertEqual((max_pub_key.x, max_pub_key.y), (self.Gx, (-self.Gy) % self.p))

    def test_message_variations(self):
        messages = ["", "a", "This is a longer message.", "🚀🌟✨"]
        for message in messages:
            signature = ECDSA.sign_message(self.private_key, message)
            self.assertTrue(ECDSA.verify_signature(self.public_key, message, signature))

    def test_character_encodings(self):
        message = "Hello, ECDSA!"
        utf8_signature = ECDSA.sign_message(self.private_key, message.encode('utf-8').decode('utf-8'))
        ascii_signature = ECDSA.sign_message(self.private_key, message.encode('ascii').decode('ascii'))
        self.assertTrue(ECDSA.verify_signature(self.public_key, message, utf8_signature))
        self.assertTrue(ECDSA.verify_signature(self.public_key, message, ascii_signature))

    def test_consistency(self):
        message = "Consistency test"
        signature1 = ECDSA.sign_message(self.private_key, message)
        signature2 = ECDSA.sign_message(self.private_key, message)
        self.assertTrue(ECDSA.verify_signature(self.public_key, message, signature1))
        self.assertTrue(ECDSA.verify_signature(self.public_key, message, signature2))
        # Each signature draws a fresh random nonce, so two signatures of the
        # same message should differ; identical ones would indicate a reused
        # nonce, which leaks the private key.
        self.assertNotEqual(signature1, signature2)

    def test_signature_base64_roundtrip(self):
        message = "Round trip"
        signature = ECDSA.sign_message(self.private_key, message)
        encoded = Encoding.signature_to_base64(signature)
        self.assertEqual(Encoding.signature_from_base64(encoded), signature)
if __name__ == '__main__':
    # argv=[''] stops unittest from parsing this interpreter's command-line
    # arguments; exit=False makes it return instead of calling sys.exit().
    unittest.main(exit=False, argv=[''])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment