Last active
July 13, 2025 20:14
-
-
Save tarassh/43aec62ad64e73f574d131cef2e92b10 to your computer and use it in GitHub Desktop.
Oblivious Transfer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
""" | |
P-256 Elliptic Curve Cryptography Utilities | |
This module provides a centralized implementation of P-256 elliptic curve | |
operations, constants, and utilities used across the MPC protocol implementations. | |
It eliminates code duplication and provides a consistent interface for all | |
ECC-related operations. | |
This module includes: | |
- P-256 curve parameters and constants | |
- Point representation and validation | |
- Elliptic curve point arithmetic (addition, doubling, scalar multiplication) | |
- Modular arithmetic utilities | |
- Point generation and validation functions | |
Used by: | |
- mta_protocol.py: Field modulus and modular arithmetic | |
- ectf_protocol.py: Complete elliptic curve operations and point validation | |
- oblivious_transfer.py: Curve parameters and legacy point operations | |
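The curve constants are the standard NIST P-256 / secp256r1 domain parameters.
The arithmetic itself is a plain affine, variable-time implementation (scalar
multiplication branches on the bits of the scalar), so it is not hardened
against timing side channels and is best suited to prototyping and tests.
This centralized module was created by refactoring duplicated ECC code from
multiple protocol implementations to improve maintainability and consistency.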
""" | |
import random | |
# NIST P-256 (secp256r1) domain parameters.
# Short Weierstrass form: y^2 = x^3 + a*x + b (mod p), with a = -3.

# Prime modulus of the underlying field F_p; point coordinates live in [0, p).
# Equivalently p = 2^256 - 2^224 + 2^192 + 2^96 - 1.
P256_P = 0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF

# Curve coefficients. a = -3 is congruent to p - 3 modulo p; the negative form
# is fine as long as results that use it are reduced modulo P256_P.
P256_A = -3
P256_B = 0x5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B

# Affine coordinates of the standard base point G.
P256_GX = 0x6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296
P256_GY = 0x4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5

# Order n of the base point G (also the number of points on the curve, since
# P-256 has cofactor 1). Note n != p: private scalars belong in [1, n - 1],
# while coordinates are reduced modulo p.
P256_N = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551

# Backward-compatible alias used by mod_q() and older callers.
# NOTE(review): this is the *field* prime, not the group order P256_N.
# Protocols that share secrets/scalars (e.g. MtA for threshold ECDSA) normally
# work modulo n - confirm which modulus callers intend before relying on it.
MODULUS_Q = P256_P
def mod_q(x):
    """Return ``x`` reduced into the canonical range ``[0, MODULUS_Q)``.

    Kept as a thin shared helper so callers such as the MtA protocol use one
    reduction routine instead of repeating ``% MODULUS_Q`` themselves. In this
    module ``MODULUS_Q`` is the P-256 field prime.

    Args:
        x (int): Any integer, possibly negative or larger than the modulus.

    Returns:
        int: The non-negative residue of ``x`` modulo ``MODULUS_Q``.
    """
    # Python's % already yields a non-negative result for a positive modulus.
    reduced = x % MODULUS_Q
    return reduced
def modinv(a, m=None):
    """Return the multiplicative inverse of ``a`` modulo ``m``.

    The result ``r`` satisfies ``(a * r) % m == 1``. This delegates to the
    built-in three-argument ``pow`` with exponent ``-1`` (Python 3.8+).

    Args:
        a (int): Value to invert.
        m (int, optional): Modulus; ``P256_P`` is used when omitted.

    Returns:
        int: The inverse of ``a`` modulo ``m``.

    Raises:
        ValueError: If ``a`` has no inverse modulo ``m`` (``gcd(a, m) != 1``).
            For the prime ``P256_P`` this happens only when ``a`` is a
            multiple of it.
    """
    modulus = P256_P if m is None else m
    return pow(a, -1, modulus)
class ECPoint:
    """A P-256 point in affine coordinates, or the point at infinity.

    Coordinates are reduced modulo ``P256_P`` when the point is built. The
    constructor performs no curve-membership check. Call :meth:`is_on_curve`
    explicitly for coordinates that come from untrusted input.

    Instances compare by value and are hashable, so they can be used in sets
    and as dictionary keys. Treat them as immutable once hashed.

    Attributes:
        x (int or None): Affine x-coordinate (``None`` for the point at infinity).
        y (int or None): Affine y-coordinate (``None`` for the point at infinity).
        is_infinity (bool): ``True`` if this is the identity element.
    """

    def __init__(self, x=None, y=None):
        """Create a point from affine coordinates.

        Args:
            x (int, optional): x-coordinate; ``None`` means the point at infinity.
            y (int, optional): y-coordinate; ``None`` means the point at infinity.

        Note:
            If *either* coordinate is ``None`` the result is the point at
            infinity, with both ``x`` and ``y`` set to ``None``.
        """
        if x is None or y is None:
            self.x = None
            self.y = None
            self.is_infinity = True
        else:
            self.x = x % P256_P
            self.y = y % P256_P
            self.is_infinity = False

    def __eq__(self, other):
        """Return True if ``other`` is an ECPoint with the same coordinates."""
        if not isinstance(other, ECPoint):
            return False
        return (
            self.x == other.x
            and self.y == other.y
            and self.is_infinity == other.is_infinity
        )

    def __hash__(self):
        """Hash the same fields ``__eq__`` compares, so equal points hash equally."""
        # Overriding __eq__ alone implicitly sets __hash__ = None (unhashable).
        return hash((self.x, self.y, self.is_infinity))

    def __repr__(self):
        """Return ``ECPoint(x, y)``, or ``ECPoint(∞)`` for the identity."""
        if self.is_infinity:
            return "ECPoint(∞)"
        return f"ECPoint({self.x}, {self.y})"

    def is_on_curve(self):
        """Return True if this point satisfies the P-256 curve equation.

        The point at infinity is always considered to be on the curve. Finite
        points are checked with the module-level ``is_on_curve(x, y)``.
        """
        if self.is_infinity:
            return True
        return is_on_curve(self.x, self.y)
def is_on_curve(x, y):
    """Return True if affine ``(x, y)`` satisfies the P-256 curve equation.

    Checks ``y^2 == x^3 - 3x + b (mod p)``. Both coordinates are reduced
    modulo ``P256_P`` first, so out-of-range integers are judged by their
    residues.

    Args:
        x (int): x-coordinate.
        y (int): y-coordinate.

    Returns:
        bool: Whether the point lies on the curve.
    """
    modulus = P256_P
    px, py = x % modulus, y % modulus
    lhs = (py * py) % modulus
    rhs = ((px * px * px) % modulus - (3 * px) % modulus + P256_B) % modulus
    return lhs == rhs
def _coerce_to_ecpoint(value, name):
    """Return ``value`` as an ECPoint, accepting any object with ``x``/``y``.

    Raises:
        ValueError: If ``value`` is neither an ECPoint nor exposes ``x`` and ``y``.
    """
    if isinstance(value, ECPoint):
        return value
    if hasattr(value, "x") and hasattr(value, "y"):
        return ECPoint(value.x, value.y)
    raise ValueError(f"{name} must be an ECPoint or have x,y attributes")


def ec_point_add(p1, p2):
    """Add two points on the P-256 curve and return ``p1 + p2``.

    Both operands are validated first (the point at infinity always passes).
    An off-curve operand is therefore rejected even when the other operand is
    the identity. The sum is then computed with the affine chord-and-tangent
    formulas:

    * ``O + P = P`` and ``P + O = P``;
    * ``P + (-P) = O``;
    * doubling (``P == Q``) uses the tangent slope ``(3x^2 + a) / (2y)``;
    * otherwise the chord slope ``(y2 - y1) / (x2 - x1)`` is used.

    Args:
        p1 (ECPoint): First operand; any object with ``x`` and ``y``
            attributes is converted to an ECPoint.
        p2 (ECPoint): Second operand; same conversion rules as ``p1``.

    Returns:
        ECPoint: The sum, possibly the point at infinity.

    Raises:
        ValueError: If an operand is not point-like or is not on the curve.
    """
    p1 = _coerce_to_ecpoint(p1, "p1")
    p2 = _coerce_to_ecpoint(p2, "p2")

    # Validate before the identity shortcuts so O + (invalid point) is rejected
    # as documented instead of returning the invalid point unchecked.
    if not p1.is_on_curve():
        raise ValueError(f"Point P1({p1.x}, {p1.y}) is not on the P-256 curve")
    if not p2.is_on_curve():
        raise ValueError(f"Point P2({p2.x}, {p2.y}) is not on the P-256 curve")

    if p1.is_infinity:
        return p2
    if p2.is_infinity:
        return p1

    if p1.x == p2.x:
        # Same x on the curve means y2 = +/- y1. Distinct y values are inverses,
        # and a point with y = 0 has a vertical tangent, so 2P = O. The y = 0
        # check also avoids inverting a zero denominator below.
        if p1.y != p2.y or p1.y == 0:
            return ECPoint()
        # Tangent slope: (3x^2 + a) / (2y), a = -3 for P-256.
        numerator = (3 * p1.x * p1.x + P256_A) % P256_P
        denominator = (2 * p1.y) % P256_P
    else:
        # Chord slope: (y2 - y1) / (x2 - x1).
        numerator = (p2.y - p1.y) % P256_P
        denominator = (p2.x - p1.x) % P256_P

    slope = (numerator * modinv(denominator, P256_P)) % P256_P
    # x3 = slope^2 - x1 - x2 ;  y3 = slope * (x1 - x3) - y1
    x3 = (slope * slope - p1.x - p2.x) % P256_P
    y3 = (slope * (p1.x - x3) - p1.y) % P256_P
    return ECPoint(x3, y3)
def ec_point_double(p):
    """Return ``2 * p`` on the P-256 curve.

    This is a convenience wrapper that simply calls ``ec_point_add(p, p)``,
    which takes its tangent-line (doubling) branch when both operands are
    equal. It is not a separately optimized doubling formula.

    Args:
        p (ECPoint): Point to double. Objects exposing ``x``/``y`` are
            converted by ``ec_point_add``.

    Returns:
        ECPoint: The point ``2P`` (the point at infinity if ``p`` is infinity).

    Raises:
        ValueError: Propagated from ``ec_point_add`` for invalid points.
    """
    return ec_point_add(p, p)
def ec_scalar_mult(k, p):
    """Compute the scalar multiple ``k * P`` on the P-256 curve.

    Uses right-to-left binary double-and-add: the bits of ``k`` are consumed
    from least significant upward. A running power-of-two multiple of ``P``
    is added into the result whenever the current bit is set.

    Args:
        k (int): Scalar multiplier; negative values compute ``|k| * (-P)``.
        p (ECPoint): Point to multiply. Objects exposing ``x``/``y`` are
            converted to an ECPoint first.

    Returns:
        ECPoint: The point ``k * P`` (the point at infinity when ``k == 0``
        or ``p`` is the point at infinity).

    Note:
        This is NOT constant-time. The sequence of additions depends on the
        bits of ``k``, so execution time can leak a secret scalar. Use a
        vetted library for secret scalars in production.
    """
    if not isinstance(p, ECPoint):
        p = ECPoint(p.x, p.y)

    # 0 * P and k * O are both the identity. Handling infinity here also avoids
    # negating a None y-coordinate in the k < 0 branch (previously a TypeError).
    if k == 0 or p.is_infinity:
        return ECPoint()

    if k < 0:
        # k * P = (-k) * (-P); -P = (x, -y mod p).
        k = -k
        p = ECPoint(p.x, (-p.y) % P256_P)

    result = ECPoint()  # identity element
    addend = p  # holds 2^i * P at iteration i
    while k:
        if k & 1:
            result = ec_point_add(result, addend)
        addend = ec_point_double(addend)
        k >>= 1
    return result
def generate_valid_point():
    """Return a pseudo-random affine point on P-256 (testing helper).

    Draws x from [1, p - 1] with the ``random`` module, which is NOT a
    cryptographically secure generator. Draws repeat until x^3 - 3x + b has a
    square root mod p. Because p = 3 (mod 4), a candidate root is
    z^((p + 1) / 4) mod p, so no Tonelli-Shanks is needed. The candidate is
    then verified by squaring it. Only that root is returned, never its
    negation.

    Returns:
        ECPoint: A point on the curve. After 1000 unsuccessful draws the
        standard generator G is returned instead.
    """
    root_exponent = (P256_P + 1) // 4
    attempts_left = 1000
    while attempts_left > 0:
        attempts_left -= 1
        candidate_x = random.randint(1, P256_P - 1)
        rhs = (pow(candidate_x, 3, P256_P) - 3 * candidate_x + P256_B) % P256_P
        root = pow(rhs, root_exponent, P256_P)
        if root * root % P256_P == rhs:
            return ECPoint(candidate_x, root)
    # Practically unreachable: roughly half of all draws yield a residue.
    return ECPoint(P256_GX, P256_GY)
def get_generator():
    """Return the standard P-256 base point G as a new ECPoint.

    A fresh object is built on every call, so callers may mutate the result
    without affecting other callers.
    """
    generator = ECPoint(x=P256_GX, y=P256_GY)
    return generator
# Compatibility functions for existing code | |
def point_add(P, Q):
    """Add two tuple-encoded points (legacy API).

    Adapts the ``(x, y)`` / ``None`` representation used by older callers to
    ``ec_point_add`` and converts the result back.

    Args:
        P (tuple or None): First point as ``(x, y)``, or ``None`` for infinity.
        Q (tuple or None): Second point as ``(x, y)``, or ``None`` for infinity.

    Returns:
        tuple or None: The sum as ``(x, y)``, or ``None`` for infinity.
    """

    def to_ecpoint(pair):
        # None encodes the identity; otherwise take the first two entries.
        return ECPoint() if pair is None else ECPoint(pair[0], pair[1])

    total = ec_point_add(to_ecpoint(P), to_ecpoint(Q))
    return None if total.is_infinity else (total.x, total.y)
def inv_mod(x, m):
    """Legacy alias for ``modinv`` with a mandatory modulus.

    Args:
        x (int): Value to invert.
        m (int): Modulus.

    Returns:
        int: The inverse of ``x`` modulo ``m``.
    """
    return modinv(a=x, m=m)
# Public API: the names exposed by ``from ecc_utils import *``.
__all__ = [
    # Curve constants
    "P256_P", "P256_A", "P256_B", "P256_GX", "P256_GY", "P256_N", "MODULUS_Q",
    # Point type
    "ECPoint",
    # Group operations
    "ec_point_add", "ec_point_double", "ec_scalar_mult",
    # Helpers
    "is_on_curve", "modinv", "mod_q", "generate_valid_point", "get_generator",
    # Tuple-based / legacy wrappers
    "point_add", "inv_mod",
]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
""" | |
Oblivious Transfer implementation using Elliptic Curve Cryptography and XOR. | |
This implements a simplified version of the Chou-Orlandi OT protocol: | |
"The Simplest Protocol for Oblivious Transfer" (2015) | |
https://eprint.iacr.org/2015/267.pdf | |
The protocol allows a sender with two messages (m0, m1) and a receiver with | |
a choice bit b to perform OT such that: | |
- Receiver learns m_b but nothing about m_(1-b) | |
- Sender learns nothing about the choice bit b | |
This implementation uses the P-256 curve parameters from the centralized ecc_utils | |
module for consistency with other cryptographic protocols in this project. | |
The OT protocol has been fully refactored to use centralized ECC operations: | |
- Generator point from get_generator() | |
- Scalar multiplication via ec_scalar_mult() with tuple compatibility wrapper | |
- Point validation using ECPoint.is_on_curve() | |
Dependencies: | |
- ecc_utils: Centralized P-256 curve parameters, point operations, and validation | |
""" | |
import hashlib
import os
import secrets
import sys

# Import centralized ECC utilities for P-256 curve operations
# Uses the refactored ecc_utils module for consistent curve parameters
sys.path.append(os.path.dirname(__file__))
from ecc_utils import P256_P as p, point_add, ec_scalar_mult, get_generator, ECPoint
from ecc_utils import P256_N  # group order: valid range for private scalars
# The OT code works with points as plain (x, y) tuples, so the ecc_utils base
# point is converted once at import time and reused everywhere as G.
_generator_point = get_generator()
G = _generator_point.x, _generator_point.y
# Note: this module never calls modinv directly; modular inverses are computed
# inside ecc_utils' point arithmetic, which keeps all field math in one place.
def point_multiply(point, scalar):
    """Return ``scalar * point`` for tuple-encoded points.

    Adapter between this module's ``(x, y)`` / ``None`` representation and
    ``ecc_utils.ec_scalar_mult``, which operates on ECPoint objects.

    Args:
        point (tuple or None): Point as ``(x, y)``, or ``None`` for infinity.
        scalar (int): Multiplier; negative values multiply the negated point.

    Returns:
        tuple or None: The product as ``(x, y)``, or ``None`` for infinity.
    """
    if point is None or scalar == 0:
        return None
    if scalar == 1:
        # Short-circuit: the input is handed back as-is (same object).
        return point

    x_coord, y_coord = point[0], point[1]
    if scalar < 0:
        # (-k) * P == k * (-P), with -P = (x, -y mod p).
        scalar, y_coord = -scalar, (-y_coord) % p

    product = ec_scalar_mult(scalar, ECPoint(x_coord, y_coord))
    return None if product.is_infinity else (product.x, product.y)
def point_to_bytes(point):
    """Serialize a tuple point to 65 bytes: ``0x04 || X || Y`` (big-endian).

    The point at infinity (``None``) becomes 65 zero bytes. That is not the
    SEC1 encoding of infinity, and ``bytes_to_point`` rejects it because of
    the missing 0x04 prefix. It only serves as fixed-length hash input.
    """
    if point is None:
        return bytes(65)
    x, y = point
    return b"".join((b"\x04", x.to_bytes(32, "big"), y.to_bytes(32, "big")))
def bytes_to_point(point_bytes):
    """Decode a 65-byte uncompressed P-256 point and verify it is on the curve.

    Args:
        point_bytes (bytes): ``0x04 || X || Y`` with 32-byte big-endian
            coordinates.

    Returns:
        tuple: ``(x, y)`` with both coordinates in ``[0, p)``.

    Raises:
        ValueError: If the length or prefix is wrong, a coordinate is not a
            canonical field element (``>= p``), or the point is not on the
            curve.
    """
    if len(point_bytes) != 65 or point_bytes[0] != 0x04:
        raise ValueError("Invalid point encoding")

    x = int.from_bytes(point_bytes[1:33], "big")
    y = int.from_bytes(point_bytes[33:65], "big")

    # Reject non-canonical encodings. ECPoint silently reduces coordinates
    # mod p, so without this check several byte strings would decode to the
    # same point, and the unreduced values would be returned to the caller.
    if x >= p or y >= p:
        raise ValueError("Point coordinate out of field range")

    # The ECPoint constructor does not validate; is_on_curve() does the check.
    if not ECPoint(x, y).is_on_curve():
        raise ValueError("Point not on curve")
    return (x, y)
def hash_point_to_key(point, label=b"", length=64):
    """Derive ``length`` bytes of XOR key material from a curve point.

    Block 0 is ``SHA-256(point_to_bytes(point) || label)``. Each further
    32-byte block is ``SHA-256(previous_block || b"expand")``. The default
    ``length=64`` is exactly two blocks, i.e. block 0 followed by
    ``SHA-256(block0 || b"expand")``. Larger values allow keys for messages
    longer than 64 bytes.

    This is a simple hash chain for a demonstration protocol, not a
    standardized KDF such as HKDF.

    Args:
        point (tuple or None): Point as ``(x, y)``, or ``None`` for infinity.
        label (bytes): Domain-separation label (e.g. ``b"k0"`` / ``b"k1"``).
        length (int): Number of key bytes to return (default 64).

    Returns:
        bytes: Exactly ``length`` bytes.

    Raises:
        ValueError: If ``length`` is negative.
    """
    if length < 0:
        raise ValueError("length must be non-negative")

    block = hashlib.sha256(point_to_bytes(point) + label).digest()
    blocks = [block]
    # Chain further blocks until enough material has been produced.
    while len(blocks) * len(block) < length:
        block = hashlib.sha256(block + b"expand").digest()
        blocks.append(block)
    return b"".join(blocks)[:length]
def xor_bytes(data1, data2):
    """Return the bytewise XOR of two byte sequences.

    The output is as long as the shorter input; extra bytes are ignored.
    """
    combined = bytearray()
    for left, right in zip(data1, data2):
        combined.append(left ^ right)
    return bytes(combined)
def pad_message(message):
    """Normalize a message to bytes before XOR encryption.

    Despite the name, no padding bytes are added. ``str`` input is UTF-8
    encoded, and any other value (e.g. ``bytes``) is returned unchanged, so
    the message length is preserved exactly.
    """
    return message.encode("utf-8") if isinstance(message, str) else message
def unpad_message(padded_message):
    """Return the plaintext bytes recovered from decryption, unchanged.

    ``pad_message`` only encodes ``str`` to UTF-8 and never appends padding,
    so there is nothing to remove. Trailing NUL bytes are deliberately kept
    because they may be part of the message itself (e.g. binary key
    material). Stripping them would corrupt such messages.
    """
    return padded_message
class OTSender:
    """Sender side of the Chou-Orlandi oblivious transfer.

    Usage: call :meth:`step1_generate_public_point` and send the returned bytes
    to the receiver. Then pass the receiver's reply and both messages to
    :meth:`step3_encrypt_messages`.
    """

    def __init__(self):
        self.private_scalar = None  # a, sampled in step 1
        self.public_point = None  # A = a*G as an (x, y) tuple

    def step1_generate_public_point(self):
        """Step 1: sample a in [1, n - 1] and return the encoding of A = a*G.

        Returns:
            bytes: 65-byte uncompressed encoding of A.
        """
        # Scalars live modulo the group order n, not the field prime p.
        self.private_scalar = secrets.randbelow(P256_N - 1) + 1
        self.public_point = point_multiply(G, self.private_scalar)
        return point_to_bytes(self.public_point)

    def step3_encrypt_messages(self, B_bytes, message0, message1):
        """Step 3: derive both keys from the receiver's B and encrypt both messages.

        The keys are k0 = H(a*B) and k1 = H(a*(B - A)). If the receiver chose
        0 it sent B = b*G, so a*B = a*b*G. If it chose 1 it sent B = b*G + A,
        so a*(B - A) = a*b*G. Either way the receiver can compute only the key
        for its chosen message, as b*A = a*b*G.

        Args:
            B_bytes (bytes): Receiver's 65-byte uncompressed point B.
            message0 (str or bytes): Message released for choice bit 0.
            message1 (str or bytes): Message released for choice bit 1.

        Returns:
            tuple[bytes, bytes]: Ciphertexts ``(c0, c1)``.

        Raises:
            ValueError: If step 1 has not run, B is invalid, or a message is
                longer than the derived key.
        """
        if self.private_scalar is None:
            raise ValueError("Must call step1_generate_public_point first")

        B_point = bytes_to_point(B_bytes)

        # k0: the key the receiver can compute when choice=0 (B = b*G).
        shared_point_0 = point_multiply(B_point, self.private_scalar)  # a*B
        k0 = hash_point_to_key(shared_point_0, b"k0")

        # k1: the key the receiver can compute when choice=1 (B = b*G + A),
        # obtained as a*(B - A) = a*b*G.
        neg_A = (self.public_point[0], (-self.public_point[1]) % p)
        B_minus_A = point_add(B_point, neg_A)
        shared_point_1 = point_multiply(B_minus_A, self.private_scalar)
        k1 = hash_point_to_key(shared_point_1, b"k1")

        m0_bytes = pad_message(message0)
        m1_bytes = pad_message(message1)

        # xor_bytes stops at the shorter input, so a message longer than its key
        # would be silently truncated. Refuse it explicitly instead.
        for name, message, key in (("message0", m0_bytes, k0), ("message1", m1_bytes, k1)):
            if len(message) > len(key):
                raise ValueError(
                    f"{name} is {len(message)} bytes; at most {len(key)} bytes are supported"
                )

        c0 = xor_bytes(m0_bytes, k0[: len(m0_bytes)])
        c1 = xor_bytes(m1_bytes, k1[: len(m1_bytes)])
        return c0, c1
class OTReceiver:
    """Receiver side of the Chou-Orlandi oblivious transfer.

    Usage: pass the sender's A to :meth:`step2_compute_B` and send back the
    returned B. Then pass both ciphertexts to :meth:`step4_decrypt_message`.
    """

    def __init__(self, choice_bit):
        self.choice_bit = bool(choice_bit)  # which message to receive
        self.private_scalar = None  # b, sampled in step 2
        self.sender_public_point = None  # A as an (x, y) tuple

    def step2_compute_B(self, A_bytes):
        """Step 2: decode A, sample b in [1, n - 1], and return B for the choice bit.

        B = b*G when choice = 0 and B = b*G + A when choice = 1. In both cases
        the receiver can later compute b*A = a*b*G, which matches the
        sender's key for the chosen message only.

        Args:
            A_bytes (bytes): Sender's 65-byte uncompressed point A.

        Returns:
            bytes: 65-byte uncompressed encoding of B.

        Raises:
            ValueError: If A is not a valid encoded curve point.
        """
        self.sender_public_point = bytes_to_point(A_bytes)

        # Scalars live modulo the group order n, not the field prime p.
        self.private_scalar = secrets.randbelow(P256_N - 1) + 1

        bG = point_multiply(G, self.private_scalar)
        B_point = point_add(bG, self.sender_public_point) if self.choice_bit else bG
        return point_to_bytes(B_point)

    def step4_decrypt_message(self, c0, c1):
        """Step 4: decrypt the chosen ciphertext with the key derived from b*A.

        b*A = a*b*G equals the sender's a*B when choice = 0, and the sender's
        a*(B - A) when choice = 1. The matching label (``b"k0"``/``b"k1"``)
        is used so the derived key agrees with the sender's.

        Args:
            c0 (bytes): Ciphertext of message 0.
            c1 (bytes): Ciphertext of message 1.

        Returns:
            bytes: The chosen plaintext.

        Raises:
            ValueError: If step 2 has not run, or the chosen ciphertext is
                longer than the derived key.
        """
        if self.private_scalar is None or self.sender_public_point is None:
            raise ValueError("Must call step2_compute_B first")

        shared_point = point_multiply(self.sender_public_point, self.private_scalar)

        if self.choice_bit:
            label, ciphertext = b"k1", c1
        else:
            label, ciphertext = b"k0", c0
        k = hash_point_to_key(shared_point, label)

        # xor_bytes would silently drop bytes beyond the key length.
        if len(ciphertext) > len(k):
            raise ValueError(
                f"ciphertext is {len(ciphertext)} bytes; at most {len(k)} bytes are supported"
            )

        decrypted = xor_bytes(ciphertext, k[: len(ciphertext)])
        return unpad_message(decrypted)
def oblivious_transfer_demo():
    """Run one end-to-end OT between a fresh sender and receiver, printing each step.

    Returns:
        bool: True when the receiver recovers exactly the message it chose.
    """
    print("Chou-Orlandi Oblivious Transfer Demo")
    print("=" * 40)

    message0 = "Secret message 0: Alice's private key"
    message1 = "Secret message 1: Bob's private key"
    choice_bit = 1  # the receiver asks for message 1

    print("Sender has messages:")
    print(f" m0 = '{message0}'")
    print(f" m1 = '{message1}'")
    print(f"Receiver choice bit: {choice_bit}")
    print()

    sender, receiver = OTSender(), OTReceiver(choice_bit)

    print("Step 1: Sender generates point A")
    A_bytes = sender.step1_generate_public_point()
    print(f"A = {A_bytes.hex()}")
    print()

    print("Step 2: Receiver computes point B based on choice bit")
    B_bytes = receiver.step2_compute_B(A_bytes)
    print(f"B = {B_bytes.hex()}")
    print()

    print("Step 3: Sender encrypts both messages")
    c0, c1 = sender.step3_encrypt_messages(B_bytes, message0, message1)
    print(f"c0 = {c0.hex()}")
    print(f"c1 = {c1.hex()}")
    print()

    print("Step 4: Receiver decrypts chosen message")
    recovered = receiver.step4_decrypt_message(c0, c1).decode("utf-8")
    print(f"Decrypted message: '{recovered}'")
    print()

    expected = message1 if choice_bit else message0
    protocol_success = recovered == expected
    print(f"Protocol success: {protocol_success}")
    return protocol_success
def batch_oblivious_transfer(msg_pairs, choices):
    """Run one independent OT per message pair, printing progress.

    Each transfer uses a fresh sender/receiver pair.

    Args:
        msg_pairs: Sequence of ``(message0, message1)`` tuples.
        choices: Sequence of choice bits, one per pair.

    Returns:
        list[str]: The UTF-8 decoded message received in each transfer.

    Raises:
        ValueError: If the two sequences differ in length.
    """
    if len(msg_pairs) != len(choices):
        raise ValueError("Number of message pairs must match number of choice bits")

    def run_single(index, m0, m1, choice):
        # Print the setup, execute all four protocol steps, and report.
        print(f"\nOT {index}:")
        print(f" Messages: ('{m0}', '{m1}')")
        print(f" Choice: {choice}")
        sender = OTSender()
        receiver = OTReceiver(choice)
        A_bytes = sender.step1_generate_public_point()
        B_bytes = receiver.step2_compute_B(A_bytes)
        c0, c1 = sender.step3_encrypt_messages(B_bytes, m0, m1)
        plaintext = receiver.step4_decrypt_message(c0, c1).decode("utf-8")
        print(f" Result: '{plaintext}'")
        return plaintext

    outcomes = []
    for index, ((m0, m1), choice) in enumerate(zip(msg_pairs, choices), start=1):
        outcomes.append(run_single(index, m0, m1, choice))
    return outcomes
if __name__ == "__main__":
    # Single transfer first; the batch demo only runs if it succeeded.
    success = oblivious_transfer_demo()
    if success:
        rule = "=" * 50
        print("\n" + rule)
        print("Batch OT Demo")
        print(rule)

        message_pairs = [
            ("Alice's secret", "Bob's secret"),
            ("Key fragment 1", "Key fragment 2"),
            ("Left branch", "Right branch"),
        ]
        choice_bits = [0, 1, 0]
        ot_results = batch_oblivious_transfer(message_pairs, choice_bits)
        print(f"\nBatch OT completed. Results: {ot_results}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment