Skip to content

Instantly share code, notes, and snippets.

@vkobel
Last active August 16, 2024 14:27
Show Gist options
  • Save vkobel/10b9ad23b1bb2c6c73e3aa6934f0858b to your computer and use it in GitHub Desktop.
Save vkobel/10b9ad23b1bb2c6c73e3aa6934f0858b to your computer and use it in GitHub Desktop.
Toy ECDSA implementation (no dependencies for the core math) -- for educational purposes only, do NOT use in production
import hashlib
import secrets
# Parameters for the secp256k1 curve: y^2 = x^3 + a*x + b over GF(p),
# with base point (Gx, Gy) of order n.
p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
a = 0
b = 7
Gx = 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798
Gy = 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8
n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141


class Point:
    """An affine point (x, y) for the curve y^2 = x^3 + a*x + b (mod p).

    The point at infinity (the group identity) is represented by ``None``,
    never by a ``Point`` instance.
    """

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return f"Point({self.x}, {self.y})"

    def __eq__(self, other):
        # Value equality on coordinates; without this, == is object identity.
        if not isinstance(other, Point):
            return NotImplemented
        return self.x == other.x and self.y == other.y

    def __hash__(self):
        # Defining __eq__ alone would make Point unhashable; keep hash consistent with it.
        return hash((self.x, self.y))

    def is_on_curve(self):
        """Return True if (x, y) satisfies the curve equation modulo p."""
        return (self.y ** 2 - self.x ** 3 - a * self.x - b) % p == 0

    @staticmethod
    def add(P, Q):
        """Return P + Q under the elliptic-curve group law.

        ``None`` stands for the point at infinity, both as input and as result.
        Modular inverses are computed as x^(p-2) mod p (Fermat; p is prime).
        """
        if P is None:
            return Q
        if Q is None:
            return P
        if (P.x - Q.x) % p == 0:
            # Same x-coordinate: Q is either P itself (doubling) or -P.
            if (P.y - Q.y) % p != 0 or P.y % p == 0:
                # Q == -P, or doubling a point with a vertical tangent (y == 0):
                # the sum is the point at infinity.
                return None
            # Point doubling. Decided by coordinates, not object identity, so two
            # distinct Point objects with equal coordinates are doubled correctly.
            s = (3 * P.x ** 2 + a) * pow(2 * P.y, p - 2, p) % p
        else:
            # General case of addition (distinct x-coordinates).
            s = (Q.y - P.y) * pow(Q.x - P.x, p - 2, p) % p
        x = (s ** 2 - P.x - Q.x) % p
        y = (s * (P.x - x) - P.y) % p
        return Point(x, y)

    @staticmethod
    def scalar_multiply(k, P):
        """Return k * P via least-significant-bit-first double-and-add.

        Returns ``None`` (point at infinity) when k <= 0. The loop branches on
        the bits of k, so it is not constant-time.
        """
        Q = None  # Accumulator, starts at the point at infinity
        while k > 0:
            if k & 1:
                Q = Point.add(Q, P)
            P = Point.add(P, P)
            k >>= 1
        return Q
class ECDSA:
    """Textbook ECDSA signing and verification over the module's curve.

    Relies on the module-level order ``n``, the base point ``G`` (assigned at
    module level after this class) and ``Point`` arithmetic. Educational only:
    nonces come from ``secrets`` rather than RFC 6979, and no low-s
    normalization is applied.
    """

    @staticmethod
    def hash_message(message):
        """Hash ``message`` with SHA-256 and return the digest as an integer.

        A ``str`` is UTF-8 encoded first; a bytes-like object (``bytes``,
        ``bytearray``, ``memoryview``) is hashed as-is. The digest is read
        big-endian.
        """
        if isinstance(message, (bytes, bytearray, memoryview)):
            message_bytes = message
        else:
            message_bytes = message.encode('utf-8')
        hash_bytes = hashlib.sha256(message_bytes).digest()
        return int.from_bytes(hash_bytes, byteorder='big')

    @staticmethod
    def sign_message(private_key, message):
        """Sign ``message`` with the integer ``private_key`` and return ``(r, s)``.

        A fresh nonce is drawn for every attempt; the loop only repeats in the
        (astronomically unlikely) case that r == 0 or s == 0.
        """
        z = ECDSA.hash_message(message)
        r, s = 0, 0
        while r == 0 or s == 0:
            k = secrets.randbelow(n - 1) + 1  # Nonce uniform in [1, n - 1] from a CSPRNG
            R = Point.scalar_multiply(k, G)
            r = R.x % n
            # s = k^-1 * (z + r * d) mod n, with k^-1 computed as k^(n-2) mod n.
            s = ((z + r * private_key) * pow(k, n - 2, n)) % n
        return (r, s)

    @staticmethod
    def verify_signature(public_key, message, signature):
        """Return True iff ``signature`` = (r, s) is valid for ``message`` under ``public_key``.

        Returns False (rather than raising) for an invalid public key, an
        out-of-range r or s, or a mismatching signature.
        """
        # Validate the public key first. With public_key = None (the point at
        # infinity), u2 * Q vanishes and the final check depends only on u1 * G,
        # which lets anyone forge signatures. Off-curve points are not valid keys.
        if public_key is None or not public_key.is_on_curve():
            return False
        r, s = signature
        if not (1 <= r < n and 1 <= s < n):
            return False  # Invalid signature range
        z = ECDSA.hash_message(message)
        w = pow(s, n - 2, n)  # s^-1 mod n
        u1 = (z * w) % n
        u2 = (r * w) % n
        P = Point.add(Point.scalar_multiply(u1, G), Point.scalar_multiply(u2, public_key))
        if P is None or not P.is_on_curve():  # Ensure P is a finite point on the curve
            return False
        return (P.x % n) == r
# The serialization helpers below rely on the third-party `cryptography` library, used only to emit standard PEM/DER encodings
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.utils import encode_dss_signature, decode_dss_signature
import base64
class Encoding:
    """Convert the toy key and signature values into standard serialized forms.

    The private key is an integer scalar, the public key is a ``Point``, and a
    signature is an ``(r, s)`` pair. The ``cryptography`` package is used only
    to produce the encodings.
    """

    @staticmethod
    def private_key_to_pem(private_key):
        """Return the integer ``private_key`` as an unencrypted PKCS#8 PEM string."""
        curve_key = ec.derive_private_key(private_key, ec.SECP256K1())
        pem_bytes = curve_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )
        return pem_bytes.decode('utf-8')

    @staticmethod
    def public_key_to_pem(public_key):
        """Return the public ``Point`` as a SubjectPublicKeyInfo PEM string."""
        numbers = ec.EllipticCurvePublicNumbers(public_key.x, public_key.y, ec.SECP256K1())
        pem_bytes = numbers.public_key().public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )
        return pem_bytes.decode('utf-8')

    @staticmethod
    def signature_to_base64(signature):
        """DER-encode the ``(r, s)`` pair and return it as Base64 text."""
        r, s = signature
        return base64.b64encode(encode_dss_signature(r, s)).decode('utf-8')

    @staticmethod
    def signature_from_base64(base64_signature):
        """Decode Base64 text holding a DER signature back into an ``(r, s)`` tuple."""
        der_bytes = base64.b64decode(base64_signature)
        r, s = decode_dss_signature(der_bytes)
        return r, s
# Initialize the base point G; ECDSA.sign_message and ECDSA.verify_signature
# read it as a module-level global.
G = Point(Gx, Gy)
# Explicit check instead of `assert`, which is stripped under `python -O`.
if not G.is_on_curve():
    raise AssertionError("Base point G is not on the curve")
# Example usage: create a key pair, sign a message and check the signature.
private_key = secrets.randbelow(n - 1) + 1  # uniform in [1, n - 1]
public_key = Point.scalar_multiply(private_key, G)
message = "Hello, ECDSA!"
signature = ECDSA.sign_message(private_key, message)
is_valid = ECDSA.verify_signature(public_key, message, signature)
print(f"Signature valid: {is_valid}")

# Export keys and signature in standard interchange formats, then show them.
private_key_pem, public_key_pem, signature_base64 = (
    Encoding.private_key_to_pem(private_key),
    Encoding.public_key_to_pem(public_key),
    Encoding.signature_to_base64(signature),
)
for heading, body in (
    ("Private Key in PKCS#8 PEM format:", private_key_pem),
    ("Public Key in PEM format:", public_key_pem),
    ("Signature in Base64 format:", signature_base64),
):
    print(heading)
    print(body)
import unittest
import secrets
import time
class TestECDSA(unittest.TestCase):
    """Unit tests for the toy Point arithmetic and ECDSA sign/verify."""

    # Published secp256k1 value of 2 * G, used as a known-answer vector that
    # does not depend on this module's own formulas.
    TWO_G_X = 0xC6047F9441ED7D6D3045406E95C07CD85C778E4B8CEF3CA7ABAC09B95C709EE5
    TWO_G_Y = 0x1AE168FEA63DC339A3C58419466CEAEEF7F632653266D0E1236431A950CFE52A

    def setUp(self):
        # Curve parameters are restated locally rather than read from the
        # module-level constants; a fresh random key pair is made per test.
        self.p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
        self.a = 0
        self.b = 7
        self.Gx = 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798
        self.Gy = 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8
        self.n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
        self.G = Point(self.Gx, self.Gy)
        self.private_key = secrets.randbelow(self.n - 1) + 1
        self.public_key = Point.scalar_multiply(self.private_key, self.G)

    def test_point_on_curve(self):
        # Test if the base point is on the curve
        self.assertTrue(self.G.is_on_curve())
        # Test if a random point is not on the curve
        invalid_point = Point(1, 1)
        self.assertFalse(invalid_point.is_on_curve())

    def test_point_addition(self):
        # Test point addition with point at infinity
        P = Point(2, 3)
        self.assertEqual(Point.add(P, None), P)
        self.assertEqual(Point.add(None, P), P)
        # Test point addition resulting in point at infinity
        Q = Point(P.x, -P.y % self.p)
        self.assertIsNone(Point.add(P, Q))
        # Test point doubling
        result = Point.add(P, P)
        s = (3 * P.x**2 + self.a) * pow(2 * P.y, self.p - 2, self.p) % self.p
        expected_x = (s**2 - 2 * P.x) % self.p
        expected_y = (s * (P.x - expected_x) - P.y) % self.p
        self.assertEqual(result.x, expected_x)
        self.assertEqual(result.y, expected_y)

    def test_known_answer_doubling(self):
        # Compare with the published 2G instead of re-deriving the doubling
        # formula, which would only restate the implementation.
        two_g = Point.scalar_multiply(2, self.G)
        self.assertEqual(two_g.x, self.TWO_G_X)
        self.assertEqual(two_g.y, self.TWO_G_Y)

    def test_order_minus_one_is_negation(self):
        # (n - 1) * G == -G: same x-coordinate, y negated modulo p.
        neg_g = Point.scalar_multiply(self.n - 1, self.G)
        self.assertEqual(neg_g.x, self.Gx)
        self.assertEqual(neg_g.y, -self.Gy % self.p)

    def test_scalar_multiplication(self):
        # Test scalar multiplication with zero
        self.assertIsNone(Point.scalar_multiply(0, self.G))
        # Test scalar multiplication with one
        self.assertEqual(Point.scalar_multiply(1, self.G), self.G)
        # Test scalar multiplication with the order of the group
        self.assertIsNone(Point.scalar_multiply(self.n, self.G))

    def test_sign_message(self):
        message = "Hello"
        signature = ECDSA.sign_message(self.private_key, message)
        self.assertEqual(len(signature), 2)
        r, s = signature
        self.assertTrue(1 <= r < self.n)
        self.assertTrue(1 <= s < self.n)

    def test_verify_signature(self):
        message = "Hello"
        signature = ECDSA.sign_message(self.private_key, message)
        self.assertTrue(ECDSA.verify_signature(self.public_key, message, signature))

    def test_tampered_message_rejected(self):
        signature = ECDSA.sign_message(self.private_key, "Hello")
        self.assertFalse(ECDSA.verify_signature(self.public_key, "Hello!", signature))

    def test_wrong_public_key_rejected(self):
        # Deterministically different from self.private_key, still in [1, n - 1].
        other_key = self.private_key % (self.n - 1) + 1
        other_public_key = Point.scalar_multiply(other_key, self.G)
        signature = ECDSA.sign_message(self.private_key, "Hello")
        self.assertFalse(ECDSA.verify_signature(other_public_key, "Hello", signature))

    def test_tampered_signature_rejected(self):
        r, s = ECDSA.sign_message(self.private_key, "Hello")
        tampered = (r, s + 1 if s + 1 < self.n else 1)
        self.assertFalse(ECDSA.verify_signature(self.public_key, "Hello", tampered))

    def test_invalid_signature(self):
        message = "Hello"
        invalid_signature = (0, 0)
        self.assertFalse(ECDSA.verify_signature(self.public_key, message, invalid_signature))

    def test_invalid_public_key(self):
        invalid_public_key = Point(0, 0)  # Invalid public key (not on the curve)
        message = "Hello"
        signature = ECDSA.sign_message(self.private_key, message)
        self.assertFalse(ECDSA.verify_signature(invalid_public_key, message, signature))

    def test_boundary_conditions(self):
        # Test signing and verifying with boundary values of r and s
        message = "Boundary test"
        signature = (1, 1)
        self.assertFalse(ECDSA.verify_signature(self.public_key, message, signature))
        signature = (self.n - 1, self.n - 1)
        self.assertFalse(ECDSA.verify_signature(self.public_key, message, signature))

    def test_boundary_private_keys(self):
        min_key = 1
        max_key = self.n - 1
        min_pub_key = Point.scalar_multiply(min_key, self.G)
        max_pub_key = Point.scalar_multiply(max_key, self.G)
        self.assertTrue(min_pub_key.is_on_curve())
        self.assertTrue(max_pub_key.is_on_curve())

    def test_message_variations(self):
        messages = ["", "a", "This is a longer message.", "🚀🌟✨"]
        for message in messages:
            signature = ECDSA.sign_message(self.private_key, message)
            self.assertTrue(ECDSA.verify_signature(self.public_key, message, signature))

    def test_character_encodings(self):
        message = "Hello, ECDSA!"
        utf8_signature = ECDSA.sign_message(self.private_key, message.encode('utf-8').decode('utf-8'))
        ascii_signature = ECDSA.sign_message(self.private_key, message.encode('ascii').decode('ascii'))
        self.assertTrue(ECDSA.verify_signature(self.public_key, message, utf8_signature))
        self.assertTrue(ECDSA.verify_signature(self.public_key, message, ascii_signature))

    def test_consistency(self):
        message = "Consistency test"
        signature1 = ECDSA.sign_message(self.private_key, message)
        signature2 = ECDSA.sign_message(self.private_key, message)
        self.assertTrue(ECDSA.verify_signature(self.public_key, message, signature1))
        self.assertTrue(ECDSA.verify_signature(self.public_key, message, signature2))
        # Each call draws a fresh random nonce, so two signatures of the same
        # message should differ (nonce reuse would leak the private key).
        self.assertNotEqual(signature1, signature2)
if __name__ == '__main__':
    # argv=[''] means unittest parses no command-line options (handy in
    # notebooks); exit=False keeps it from calling sys.exit() after the run.
    unittest.main(exit=False, argv=[''])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment