Skip to content

Instantly share code, notes, and snippets.

@tarassh
Last active July 13, 2025 20:14
Show Gist options
  • Save tarassh/43aec62ad64e73f574d131cef2e92b10 to your computer and use it in GitHub Desktop.
Save tarassh/43aec62ad64e73f574d131cef2e92b10 to your computer and use it in GitHub Desktop.
Oblivious Transfer
#!/usr/bin/env python3
"""
P-256 Elliptic Curve Cryptography Utilities
This module provides a centralized implementation of P-256 elliptic curve
operations, constants, and utilities used across the MPC protocol implementations.
It eliminates code duplication and provides a consistent interface for all
ECC-related operations.
This module includes:
- P-256 curve parameters and constants
- Point representation and validation
- Elliptic curve point arithmetic (addition, doubling, scalar multiplication)
- Modular arithmetic utilities
- Point generation and validation functions
Used by:
- mta_protocol.py: Field modulus and modular arithmetic
- ectf_protocol.py: Complete elliptic curve operations and point validation
- oblivious_transfer.py: Curve parameters and legacy point operations
All implementations follow the NIST P-256 / secp256r1 standard specifications.
This centralized module was created by refactoring duplicated ECC code from
multiple protocol implementations to improve maintainability and consistency.
"""
import random
# P-256 elliptic curve parameters (NIST P-256 / secp256r1).
# Curve equation: y² = x³ + ax + b (mod p) with a = -3, i.e. y² = x³ - 3x + b.
#
# Prime field modulus p: point coordinates are integers modulo this value.
# NOTE: p is NOT the group order; the order of the base point is P256_N below.
P256_P = 0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF
# Curve coefficients a and b
P256_A = -3  # a = -3, congruent to p - 3 (mod p); used directly in the doubling slope
P256_B = 0x5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B
# Affine coordinates of the standard base point G
P256_GX = 0x6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296
P256_GY = 0x4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5
# Order n of the base point G (P-256 has cofactor 1, so this is also the
# number of points on the curve). Scalars are meaningful modulo n.
P256_N = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551
# Alias kept for existing code that uses MODULUS_Q. It is the FIELD prime p,
# not the group order n.
# NOTE(review): code that reduces scalars or secret shares that are later
# multiplied with curve points usually needs P256_N instead — confirm against
# callers (e.g. the MtA protocol) before relying on this value.
MODULUS_Q = P256_P
def mod_q(x):
    """
    Reduce ``x`` modulo ``MODULUS_Q`` (an alias of the field prime ``P256_P``).

    Kept so that MtA protocol code can keep calling a ``mod_q`` helper while
    the modulus itself lives in one place.

    Args:
        x (int): Value to reduce. It may be negative or larger than the modulus.

    Returns:
        int: The representative of ``x`` in ``[0, MODULUS_Q)``.
    """
    remainder = x % MODULUS_Q
    return remainder
def modinv(a, m=None):
    """
    Return the multiplicative inverse of ``a`` modulo ``m``.

    The result ``r`` satisfies ``(a * r) % m == 1``. The computation is
    delegated to the built-in three-argument ``pow`` with exponent ``-1``.

    Args:
        a (int): Value to invert.
        m (int, optional): Modulus. ``P256_P`` is used when omitted.

    Returns:
        int: The inverse of ``a``, in ``[0, m)``.

    Raises:
        ValueError: If ``a`` has no inverse modulo ``m`` (``gcd(a, m) != 1``).
            For the prime ``P256_P`` this happens only when
            ``a % P256_P == 0``.
    """
    modulus = P256_P if m is None else m
    return pow(a, -1, modulus)
class ECPoint:
    """
    A point on the P-256 curve in affine coordinates, or the point at infinity.

    Coordinates are stored reduced modulo ``P256_P``. The constructor does not
    check the curve equation; call :meth:`is_on_curve` for that.

    Attributes:
        x (int or None): Affine x-coordinate (``None`` for the point at infinity).
        y (int or None): Affine y-coordinate (``None`` for the point at infinity).
        is_infinity (bool): ``True`` if this is the point at infinity (identity).
    """

    def __init__(self, x=None, y=None):
        """
        Create a point.

        Args:
            x (int, optional): x-coordinate; ``None`` for the point at infinity.
            y (int, optional): y-coordinate; ``None`` for the point at infinity.

        Note:
            If *either* coordinate is ``None`` the result is the point at
            infinity and the other coordinate is discarded.
        """
        if x is None or y is None:
            self.x = None
            self.y = None
            self.is_infinity = True
        else:
            self.x = x % P256_P
            self.y = y % P256_P
            self.is_infinity = False

    def __eq__(self, other):
        """Compare coordinates and the infinity flag with another ECPoint."""
        if not isinstance(other, ECPoint):
            # Defer to the other operand; `==` still evaluates to False for
            # unrelated types via Python's identity fallback.
            return NotImplemented
        return (
            self.x == other.x
            and self.y == other.y
            and self.is_infinity == other.is_infinity
        )

    def __hash__(self):
        """
        Hash consistently with ``__eq__`` so points can be used in sets and as dict keys.

        Defining ``__eq__`` alone made instances unhashable. Attributes are
        still assignable, so do not mutate a point while it is stored in a
        hashed container.
        """
        return hash((self.x, self.y, self.is_infinity))

    def __repr__(self):
        """String representation of the point."""
        if self.is_infinity:
            return "ECPoint(∞)"
        return f"ECPoint({self.x}, {self.y})"

    def is_on_curve(self):
        """
        Check whether this point lies on the P-256 curve.

        Returns:
            bool: ``True`` if the curve equation holds, ``False`` otherwise.
            The point at infinity always counts as on the curve.
        """
        if self.is_infinity:
            return True
        # Resolves to the module-level is_on_curve(x, y) function; method names
        # are not in scope inside method bodies.
        return is_on_curve(self.x, self.y)
def is_on_curve(x, y):
    """
    Test whether the affine point (x, y) satisfies the P-256 curve equation.

    Checks ``y^2 == x^3 - 3x + b (mod p)`` with ``p = P256_P`` and
    ``b = P256_B``. Both coordinates are reduced modulo ``p`` first, so
    negative or out-of-range integers are accepted.

    Args:
        x (int): x-coordinate.
        y (int): y-coordinate.

    Returns:
        bool: ``True`` when the equation holds, ``False`` otherwise.
    """
    prime = P256_P
    xr = x % prime
    yr = y % prime
    rhs = (xr * xr % prime * xr - 3 * xr + P256_B) % prime
    lhs = yr * yr % prime
    return lhs == rhs
def ec_point_add(p1, p2):
    """
    Add two points on the P-256 curve and return ``p1 + p2``.

    Cases handled, in order:

    * Either operand is the point at infinity: the other operand is returned
      unchanged and is not validated.
    * Both finite operands are checked with :meth:`ECPoint.is_on_curve`.
    * Different x: chord slope ``(y2 - y1) / (x2 - x1)``.
    * Same x and same y: tangent (doubling) slope ``(3x^2 + a) / 2y``.
    * Same x, different y: the operands are inverses, so infinity is returned.

    Args:
        p1 (ECPoint): Left operand. Any object with ``x`` and ``y``
            attributes is first converted to an ECPoint.
        p2 (ECPoint): Right operand, converted the same way.

    Returns:
        ECPoint: The affine sum.

    Raises:
        ValueError: If an operand is neither an ECPoint nor has ``x``/``y``
            attributes, or if a finite operand is not on the curve.
    """
    # Normalise duck-typed inputs (p1 first, so p1 errors are reported first).
    if not isinstance(p1, ECPoint):
        if not (hasattr(p1, "x") and hasattr(p1, "y")):
            raise ValueError("p1 must be an ECPoint or have x,y attributes")
        p1 = ECPoint(p1.x, p1.y)
    if not isinstance(p2, ECPoint):
        if not (hasattr(p2, "x") and hasattr(p2, "y")):
            raise ValueError("p2 must be an ECPoint or have x,y attributes")
        p2 = ECPoint(p2.x, p2.y)

    # Identity element: O + Q = Q and P + O = P.
    if p1.is_infinity:
        return p2
    if p2.is_infinity:
        return p1

    if not p1.is_on_curve():
        raise ValueError(f"Point P1({p1.x}, {p1.y}) is not on the P-256 curve")
    if not p2.is_on_curve():
        raise ValueError(f"Point P2({p2.x}, {p2.y}) is not on the P-256 curve")

    prime = P256_P
    if p1.x != p2.x:
        # Chord through two distinct points.
        rise = (p2.y - p1.y) % prime
        run = (p2.x - p1.x) % prime
        slope = rise * modinv(run, prime) % prime
    elif p1.y == p2.y:
        # Tangent at a single point (doubling); a = -3 on P-256.
        rise = (3 * p1.x * p1.x + P256_A) % prime
        run = (2 * p1.y) % prime
        slope = rise * modinv(run, prime) % prime
    else:
        # Same x, opposite y: P + (-P) = O.
        return ECPoint()

    x3 = (slope * slope - p1.x - p2.x) % prime
    y3 = (slope * (p1.x - x3) - p1.y) % prime
    return ECPoint(x3, y3)
def ec_point_double(p):
    """
    Return ``2P`` by adding the point to itself.

    This is a thin convenience wrapper: it calls :func:`ec_point_add` with
    both operands equal, so it performs the same validation and costs the
    same as that call.

    Args:
        p (ECPoint): Point to double.

    Returns:
        ECPoint: The point ``2P``.
    """
    doubled = ec_point_add(p, p)
    return doubled
def ec_scalar_mult(k, p):
    """
    Compute ``k * P`` on the P-256 curve using right-to-left double-and-add.

    The bits of ``|k|`` are scanned from least to most significant. A running
    ``addend`` (``2^i * P``) is added into the result whenever bit ``i`` is
    set. A negative ``k`` is handled as ``|k| * (-P)``.

    Args:
        k (int): Scalar multiplier of any sign. It is not reduced modulo
            ``P256_N``.
        p (ECPoint): Base point. An object with ``x``/``y`` attributes is
            converted to an ECPoint first.

    Returns:
        ECPoint: ``k * P``. This is the point at infinity when ``k == 0`` or
        ``P`` is the point at infinity.

    Raises:
        ValueError: Propagated from :func:`ec_point_add` when ``k != 0`` and a
            finite ``P`` is not on the curve.

    Warning:
        The branches taken and the number of additions depend on the bits of
        ``k``, so this is NOT constant-time and leaks timing information about
        secret scalars. Use a vetted library for production keys.
    """
    if not isinstance(p, ECPoint):
        p = ECPoint(p.x, p.y)
    # 0 * P = O and k * O = O. Checking infinity here also avoids negating a
    # None y-coordinate below when k < 0 (previously a TypeError).
    if k == 0 or p.is_infinity:
        return ECPoint()
    if k < 0:
        # k * P = (-k) * (-P)
        k = -k
        p = ECPoint(p.x, (-p.y) % P256_P)

    result = ECPoint()  # identity element
    addend = p
    while k:
        if k & 1:
            result = ec_point_add(result, addend)
        addend = ec_point_double(addend)
        k >>= 1
    return result
def generate_valid_point():
    """
    Return a random point on the P-256 curve (intended for tests).

    Random x-coordinates are drawn until ``x^3 - 3x + b`` is a quadratic
    residue. Its square root is then taken, and one of the two roots
    ``y`` / ``p - y`` is picked at random so that both "signs" of y can occur.

    Returns:
        ECPoint: A point satisfying the curve equation. If no x works within
        1000 attempts (each succeeds with probability about 1/2), the
        generator point is returned instead.

    Note:
        Uses the non-cryptographic, seedable ``random`` module. Do not use
        the result as a secret value; use a proper cryptographic library for
        production point generation.
    """
    for _ in range(1000):  # bounded so the loop always terminates
        x = random.randint(1, P256_P - 1)
        # Right side of the curve equation: x^3 - 3x + b (mod p)
        y_squared = (x * x * x - 3 * x + P256_B) % P256_P
        # P256_P ≡ 3 (mod 4), so r^((p+1)/4) is a square root of r whenever r
        # is a quadratic residue.
        y = pow(y_squared, (P256_P + 1) // 4, P256_P)
        # The candidate is a genuine root only if y_squared is a residue.
        if (y * y) % P256_P == y_squared:
            # (p+1)/4 is even, so the root computed above is always itself a
            # quadratic residue; without this flip only half of the curve's
            # points could ever be produced.
            if random.getrandbits(1):
                y = P256_P - y
            return ECPoint(x, y)
    # Fallback if random generation fails
    return ECPoint(P256_GX, P256_GY)
def get_generator():
    """
    Return the standard P-256 base point G as a new ECPoint.

    A fresh object is built on every call, so callers may modify the result
    without affecting later calls.

    Returns:
        ECPoint: The point ``(P256_GX, P256_GY)``.
    """
    generator = ECPoint(x=P256_GX, y=P256_GY)
    return generator
# Compatibility functions for existing code
def point_add(P, Q):
    """
    Add two tuple-encoded points via :func:`ec_point_add` (legacy interface).

    Args:
        P (tuple or None): ``(x, y)``, or ``None`` for the point at infinity.
        Q (tuple or None): ``(x, y)``, or ``None`` for the point at infinity.

    Returns:
        tuple or None: ``(x, y)`` of ``P + Q``, or ``None`` for infinity.

    Raises:
        ValueError: Propagated from :func:`ec_point_add` for off-curve input.
    """
    def _as_point(pair):
        # None encodes the identity; otherwise use the first two entries.
        return ECPoint() if pair is None else ECPoint(pair[0], pair[1])

    total = ec_point_add(_as_point(P), _as_point(Q))
    return None if total.is_infinity else (total.x, total.y)
def inv_mod(x, m):
    """
    Legacy alias of :func:`modinv` that takes the modulus positionally.

    Args:
        x (int): Value to invert.
        m (int): Modulus. ``None`` falls back to ``modinv``'s default
            (``P256_P``).

    Returns:
        int: ``x^-1 mod m``.

    Raises:
        ValueError: If ``x`` is not invertible modulo ``m``.
    """
    inverse = modinv(x, m)
    return inverse
# Public API of this module: the names exported by a star-import
# (`from <this module> import *`). Anything not listed stays importable by
# name but is not re-exported.
__all__ = [
    # Curve constants
    "P256_P",
    "P256_A",
    "P256_B",
    "P256_GX",
    "P256_GY",
    "P256_N",
    "MODULUS_Q",
    # Point representation
    "ECPoint",
    # Elliptic curve point arithmetic
    "ec_point_add",
    "ec_point_double",
    "ec_scalar_mult",
    # Validation and modular-arithmetic helpers
    "is_on_curve",
    "modinv",
    "mod_q",
    "generate_valid_point",
    "get_generator",
    # Legacy tuple-based / alias helpers kept for older callers
    "point_add",
    "inv_mod",
]
#!/usr/bin/env python3
"""
Oblivious Transfer implementation using Elliptic Curve Cryptography and XOR.
This implements a simplified version of the Chou-Orlandi OT protocol (keys hash only the shared point and a fixed label, and the curve arithmetic is not constant-time):
"The Simplest Protocol for Oblivious Transfer" (2015)
https://eprint.iacr.org/2015/267.pdf
The protocol allows a sender with two messages (m0, m1) and a receiver with
a choice bit b to perform OT such that:
- Receiver learns m_b but nothing about m_(1-b)
- Sender learns nothing about the choice bit b
This implementation uses the P-256 curve parameters from the centralized ecc_utils
module for consistency with other cryptographic protocols in this project.
The OT protocol has been fully refactored to use centralized ECC operations:
- Generator point from get_generator()
- Scalar multiplication via ec_scalar_mult() with tuple compatibility wrapper
- Point validation using ECPoint.is_on_curve()
Dependencies:
- ecc_utils: Centralized P-256 curve parameters, point operations, and validation
"""
import hashlib
import os
import secrets
import sys

# Import centralized ECC utilities for P-256 curve operations
# Uses the refactored ecc_utils module for consistent curve parameters
sys.path.append(os.path.dirname(__file__))
from ecc_utils import P256_P as p, point_add, ec_scalar_mult, get_generator, ECPoint
from ecc_utils import P256_N
# Base point G of P-256, obtained from ecc_utils.get_generator() and converted
# to an (x, y) tuple, because the OT helpers in this module represent points
# as tuples (with None for the point at infinity).
_generator_point = get_generator()
G = (_generator_point.x, _generator_point.y)
# Modular inversions happen inside ecc_utils' point arithmetic (ec_point_add);
# this module does not import or call modinv directly.
def point_multiply(point, scalar):
    """
    Compute ``scalar * point`` for tuple-encoded points via ``ec_scalar_mult``.

    This bridges the OT code's ``(x, y)`` / ``None`` representation and the
    ECPoint objects used by ecc_utils.

    Args:
        point (tuple or None): ``(x, y)``, or ``None`` for the point at infinity.
        scalar (int): Multiplier. A negative value multiplies the negated point.

    Returns:
        tuple or None: ``(x, y)`` of the product, or ``None`` for infinity.
        A scalar of 1 returns ``point`` itself, without validation.
    """
    # Short-circuits that never touch curve arithmetic.
    if scalar == 0 or point is None:
        return None
    if scalar == 1:
        return point

    x, y = point[0], point[1]
    if scalar < 0:
        # (-k) * P == k * (-P); negate the y-coordinate instead of the scalar.
        scalar, y = -scalar, (-y) % p

    product = ec_scalar_mult(scalar, ECPoint(x, y))
    if product.is_infinity:
        return None
    return product.x, product.y
def point_to_bytes(point):
    """
    Serialise a tuple point as 65 bytes: ``0x04 || X || Y`` (uncompressed).

    The point at infinity (``None``) is written as 65 zero bytes. Note that
    :func:`bytes_to_point` does not accept that encoding.

    Args:
        point (tuple or None): ``(x, y)`` with non-negative coordinates below
            ``2**256``, or ``None``.

    Returns:
        bytes: The 65-byte encoding.

    Raises:
        OverflowError: If a coordinate is negative or does not fit in 32 bytes.
    """
    if point is None:
        return b"\x00" * 65
    x, y = point
    return b"".join((b"\x04", x.to_bytes(32, "big"), y.to_bytes(32, "big")))
def bytes_to_point(point_bytes):
    """
    Decode an uncompressed point and check that it lies on the P-256 curve.

    Args:
        point_bytes (bytes): 65 bytes ``0x04 || X || Y``, with X and Y as
            32-byte big-endian integers.

    Returns:
        tuple: ``(x, y)`` with both coordinates in ``[0, p)``.

    Raises:
        ValueError: If the length or prefix byte is wrong, if a coordinate is
            not a canonical field element (``>= p``), or if the point does not
            satisfy the curve equation.
    """
    if len(point_bytes) != 65 or point_bytes[0] != 0x04:
        raise ValueError("Invalid point encoding")
    x = int.from_bytes(point_bytes[1:33], "big")
    y = int.from_bytes(point_bytes[33:65], "big")
    # Reject non-canonical coordinates. ECPoint reduces modulo p, so x + p
    # would otherwise pass the curve check and be returned unreduced, giving
    # one point several distinct encodings.
    if x >= p or y >= p:
        raise ValueError("Invalid point encoding")
    # ECPoint only stores coordinates; the curve equation is checked here.
    if not ECPoint(x, y).is_on_curve():
        raise ValueError("Point not on curve")
    return (x, y)
def hash_point_to_key(point, label=b""):
    """
    Derive 64 bytes of key material from a point and a domain-separation label.

    ``first = SHA256(point_to_bytes(point) || label)`` and
    ``second = SHA256(first || b"expand")``; the key is ``first || second``.

    Args:
        point (tuple or None): Point to hash (``None`` is the point at infinity).
        label (bytes): Suffix that separates keys derived for different uses.

    Returns:
        bytes: 64 bytes of key material.
    """
    first = hashlib.sha256(point_to_bytes(point) + label).digest()
    second = hashlib.sha256(first + b"expand").digest()
    return first + second
def xor_bytes(data1, data2):
    """
    XOR two byte strings position by position.

    The output has the length of the SHORTER input; extra bytes of the longer
    one are ignored.
    """
    out = bytearray()
    for left, right in zip(data1, data2):
        out.append(left ^ right)
    return bytes(out)
def pad_message(message):
    """
    Convert a message to bytes for XOR encryption.

    Despite the name, no padding is added. A ``str`` is UTF-8 encoded and any
    other value is returned unchanged, so the ciphertext length equals the
    message length.
    """
    return message.encode("utf-8") if isinstance(message, str) else message
def unpad_message(padded_message):
    """
    Undo :func:`pad_message` on decrypted bytes.

    ``pad_message`` adds no padding (it only UTF-8 encodes ``str`` input), and
    the XOR cipher preserves length, so the decrypted bytes already equal the
    original message. Trailing NUL (zero) bytes are therefore part of the
    message and are kept. Stripping them, as this function used to, corrupted
    binary messages such as big-endian integers that happen to end in 0x00.

    Args:
        padded_message (bytes): Decrypted message bytes.

    Returns:
        bytes: The message, unchanged.
    """
    return padded_message
class OTSender:
    """
    Sender side of the simplified Chou-Orlandi 1-out-of-2 Oblivious Transfer.

    Keys are derived from the shared point and a fixed label only; A and B are
    not included in the hash.

    Attributes:
        private_scalar (int or None): Secret scalar ``a``; set by step 1.
        public_point (tuple or None): ``A = a*G`` as ``(x, y)``; set by step 1.
    """

    def __init__(self):
        self.private_scalar = None
        self.public_point = None

    def step1_generate_public_point(self):
        """
        Step 1: sample a secret scalar ``a`` and compute ``A = a*G``.

        Returns:
            bytes: The 65-byte uncompressed encoding of ``A``, to send to the
            receiver.
        """
        # Scalars live modulo the group order n (not the field prime p), so
        # sample uniformly from [1, n-1]; this also guarantees A != O.
        self.private_scalar = secrets.randbelow(P256_N - 1) + 1
        self.public_point = point_multiply(G, self.private_scalar)
        return point_to_bytes(self.public_point)

    def step3_encrypt_messages(self, B_bytes, message0, message1):
        """
        Step 3: receive ``B`` and return both messages encrypted.

        The receiver sends ``B = b*G`` (choice 0) or ``B = b*G + A`` (choice 1).
        The sender derives:

        - ``k0 = H(a*B)``, which equals the receiver's ``H(b*A)`` for choice 0.
        - ``k1 = H(a*(B - A))``, which equals ``H(b*A)`` for choice 1.

        Args:
            B_bytes (bytes): Encoded receiver point ``B``.
            message0 (str or bytes): Message for choice 0 (at most 64 bytes).
            message1 (str or bytes): Message for choice 1 (at most 64 bytes).

        Returns:
            tuple[bytes, bytes]: Ciphertexts ``(c0, c1)``.

        Raises:
            ValueError: If step 1 has not run, if ``B`` is not a valid point,
                if ``B == A`` (k1 would derive from the public point at
                infinity), or if a message is longer than the 64-byte key.
        """
        if self.private_scalar is None:
            raise ValueError("Must call step1_generate_public_point first")

        B_point = bytes_to_point(B_bytes)

        # k0: a*B, which matches the receiver's b*A when B = b*G.
        shared_point_0 = point_multiply(B_point, self.private_scalar)
        k0 = hash_point_to_key(shared_point_0, b"k0")

        # k1: a*(B - A), which matches b*A when B = b*G + A.
        neg_A = (self.public_point[0], (-self.public_point[1]) % p)
        B_minus_A = point_add(B_point, neg_A)
        if B_minus_A is None:
            # B == A makes B - A the point at infinity, so k1 would be a
            # public constant and anyone could decrypt c1.
            raise ValueError("Invalid B: equals sender public point A")
        shared_point_1 = point_multiply(B_minus_A, self.private_scalar)
        k1 = hash_point_to_key(shared_point_1, b"k1")

        m0_padded = pad_message(message0)
        m1_padded = pad_message(message1)
        # xor_bytes stops at the shorter input, so a message longer than its
        # key would be silently truncated; refuse it explicitly instead.
        if len(m0_padded) > len(k0) or len(m1_padded) > len(k1):
            raise ValueError(f"Messages must be at most {len(k0)} bytes long")

        # One-time-pad style encryption with a key prefix of matching length.
        c0 = xor_bytes(m0_padded, k0[: len(m0_padded)])
        c1 = xor_bytes(m1_padded, k1[: len(m1_padded)])
        return c0, c1
class OTReceiver:
    """
    Receiver side of the simplified Chou-Orlandi 1-out-of-2 Oblivious Transfer.

    Attributes:
        choice_bit (bool): Which message the receiver wants (False -> m0,
            True -> m1).
        private_scalar (int or None): Secret scalar ``b``; set by step 2.
        sender_public_point (tuple or None): The sender's ``A``; set by step 2.
    """

    def __init__(self, choice_bit):
        # Any truthy value selects message 1.
        self.choice_bit = bool(choice_bit)
        self.private_scalar = None
        self.sender_public_point = None

    def step2_compute_B(self, A_bytes):
        """
        Step 2: receive ``A`` and compute ``B`` according to the choice bit.

        - choice 0: ``B = b*G``, so the sender's ``a*B`` equals ``b*A``.
        - choice 1: ``B = b*G + A``, so the sender's ``a*(B - A)`` equals ``b*A``.

        Args:
            A_bytes (bytes): Encoded sender point ``A``.

        Returns:
            bytes: The 65-byte encoding of ``B``.

        Raises:
            ValueError: If ``A`` is not a valid encoded P-256 point.
        """
        self.sender_public_point = bytes_to_point(A_bytes)
        # Scalars live modulo the group order n, so sample b uniformly from
        # [1, n-1] rather than from the field range [1, p-1].
        self.private_scalar = secrets.randbelow(P256_N - 1) + 1

        bG = point_multiply(G, self.private_scalar)
        if self.choice_bit == 0:
            B_point = bG
        else:
            B_point = point_add(bG, self.sender_public_point)
        return point_to_bytes(B_point)

    def step4_decrypt_message(self, c0, c1):
        """
        Step 4: decrypt the chosen ciphertext with the key ``H(b*A)``.

        ``b*A = a*b*G`` for either choice bit. It matches the sender's ``a*B``
        (label ``k0``) when the choice is 0, and ``a*(B - A)`` (label ``k1``)
        when the choice is 1.

        Args:
            c0 (bytes): Ciphertext of message 0.
            c1 (bytes): Ciphertext of message 1.

        Returns:
            bytes: The chosen message, passed through :func:`unpad_message`.

        Raises:
            ValueError: If step 2 has not been run.
        """
        if self.private_scalar is None or self.sender_public_point is None:
            raise ValueError("Must call step2_compute_B first")

        # Shared secret b*A = a*b*G (same point for both choice bits).
        shared_point = point_multiply(self.sender_public_point, self.private_scalar)

        # The label must match the one the sender used for the chosen message.
        if self.choice_bit == 0:
            k = hash_point_to_key(shared_point, b"k0")
            ciphertext = c0
        else:
            k = hash_point_to_key(shared_point, b"k1")
            ciphertext = c1

        decrypted = xor_bytes(ciphertext, k[: len(ciphertext)])
        return unpad_message(decrypted)
def oblivious_transfer_demo():
    """
    Run one OT between a fresh sender and receiver, printing every step.

    The receiver's choice bit is fixed at 1, so it should recover message 1.

    Returns:
        bool: ``True`` when the decrypted text equals the chosen message.
    """
    print("Chou-Orlandi Oblivious Transfer Demo")
    print("=" * 40)

    message0 = "Secret message 0: Alice's private key"
    message1 = "Secret message 1: Bob's private key"
    choice_bit = 1  # the receiver asks for message 1

    for line in (
        "Sender has messages:",
        f" m0 = '{message0}'",
        f" m1 = '{message1}'",
        f"Receiver choice bit: {choice_bit}",
        "",
    ):
        print(line)

    sender, receiver = OTSender(), OTReceiver(choice_bit)

    print("Step 1: Sender generates point A")
    A_bytes = sender.step1_generate_public_point()
    print(f"A = {A_bytes.hex()}\n")

    print("Step 2: Receiver computes point B based on choice bit")
    B_bytes = receiver.step2_compute_B(A_bytes)
    print(f"B = {B_bytes.hex()}\n")

    print("Step 3: Sender encrypts both messages")
    c0, c1 = sender.step3_encrypt_messages(B_bytes, message0, message1)
    print(f"c0 = {c0.hex()}")
    print(f"c1 = {c1.hex()}\n")

    print("Step 4: Receiver decrypts chosen message")
    recovered = receiver.step4_decrypt_message(c0, c1).decode("utf-8")
    print(f"Decrypted message: '{recovered}'\n")

    expected = message0 if not choice_bit else message1
    success = recovered == expected
    print(f"Protocol success: {success}")
    return success
def batch_oblivious_transfer(msg_pairs, choices):
    """
    Run one independent OT per (message pair, choice bit), printing progress.

    Every transfer uses a fresh sender and receiver, so no keys are shared
    between transfers.

    Args:
        msg_pairs (list): ``(message0, message1)`` tuples.
        choices (list): Choice bits, one per pair.

    Returns:
        list[str]: The UTF-8 decoded message the receiver obtained in each OT.

    Raises:
        ValueError: If ``msg_pairs`` and ``choices`` differ in length.
    """
    if len(choices) != len(msg_pairs):
        raise ValueError("Number of message pairs must match number of choice bits")

    recovered = []
    for index, (pair, bit) in enumerate(zip(msg_pairs, choices), start=1):
        m0, m1 = pair
        print(f"\nOT {index}:")
        print(f" Messages: ('{m0}', '{m1}')")
        print(f" Choice: {bit}")

        sender, receiver = OTSender(), OTReceiver(bit)
        point_a = sender.step1_generate_public_point()
        point_b = receiver.step2_compute_B(point_a)
        cipher0, cipher1 = sender.step3_encrypt_messages(point_b, m0, m1)
        text = receiver.step4_decrypt_message(cipher0, cipher1).decode("utf-8")

        recovered.append(text)
        print(f" Result: '{text}'")
    return recovered
if __name__ == "__main__":
    # Run the single-transfer demo first; show the batch demo only if it passed.
    if oblivious_transfer_demo():
        banner = "=" * 50
        print("\n" + banner)
        print("Batch OT Demo")
        print(banner)

        demo_pairs = [
            ("Alice's secret", "Bob's secret"),
            ("Key fragment 1", "Key fragment 2"),
            ("Left branch", "Right branch"),
        ]
        demo_choices = [0, 1, 0]
        ot_results = batch_oblivious_transfer(demo_pairs, demo_choices)
        print(f"\nBatch OT completed. Results: {ot_results}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment