Skip to content

Instantly share code, notes, and snippets.

@st1vms
Last active September 4, 2024 12:18
Show Gist options
  • Save st1vms/847c3fe90ff050b3696c4a2073e1742b to your computer and use it in GitHub Desktop.
Save st1vms/847c3fe90ff050b3696c4a2073e1742b to your computer and use it in GitHub Desktop.
RSA utilities for generating public/private key pairs and encrypting/decrypting messages.
"""RSA Utilities"""
import math
import random
import secrets
from dataclasses import dataclass
@dataclass(frozen=True)
class PubKey:
    """Immutable RSA public key.

    Attributes:
        e: Public exponent.
        n: Modulus.
    """

    e: int
    n: int
@dataclass(frozen=True)
class PrvKey:
    """Immutable RSA private key.

    Attributes:
        d: Private exponent.
        n: Modulus.
    """

    d: int
    n: int
def __is_prime(n: int, k: int = 5) -> bool:
    """Return True if *n* is probably prime, using the Miller-Rabin test.

    Args:
        n: Integer to test. Values <= 1 are never prime.
        k: Number of random witness rounds. A composite survives a single
            round with probability at most 1/4, so the worst-case
            false-positive probability is 4**-k.

    Returns:
        False if *n* is certainly composite, True if it passed every round.

    Raises:
        ValueError: If *k* is less than 1 (zero rounds would accept any
            candidate not divisible by 2 or 3).
    """
    if k < 1:
        raise ValueError("k must be at least 1")
    if n <= 1:
        return False
    if n <= 3:
        return True
    if n % 2 == 0 or n % 3 == 0:
        return False

    # Factor n - 1 as d * 2**r with d odd.
    d = n - 1
    r = 0
    while d % 2 == 0:
        d //= 2
        r += 1

    for _ in range(k):
        # Unpredictable witness, uniform in [2, n - 2]; n >= 5 here so the
        # range is never empty.
        a = 2 + secrets.randbelow(n - 3)
        x = pow(a, d, n)
        if x in (1, n - 1):
            continue
        for _ in range(r - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            # No squaring reached n - 1: `a` proves n is composite.
            return False
    return True
def __generate_prime(bits: int) -> int:
    """Generate a random probable prime that is exactly *bits* bits long.

    Candidates are drawn from the operating system's CSPRNG via ``secrets``
    because the result is secret key material.

    Args:
        bits: Required bit length of the prime; must be at least 2.

    Returns:
        An integer with bit length *bits* that passed ``__is_prime``.

    Raises:
        ValueError: If *bits* is less than 2 (no prime of that size exists).
    """
    if bits < 2:
        raise ValueError("bits must be at least 2")
    # Force the top bit (so the prime really has `bits` bits) and the low
    # bit (so the candidate is odd).
    forced_bits = (1 << (bits - 1)) | 1
    while True:
        candidate = secrets.randbits(bits) | forced_bits
        if __is_prime(candidate):
            return candidate
def __modinv(a: int, m: int) -> int:
    """Return the modular inverse of *a* modulo *m*.

    Args:
        a: Value to invert; any integer coprime to *m*.
        m: Modulus; must be positive.

    Returns:
        The unique x in ``range(m)`` with ``(a * x) % m == 1 % m``.

    Raises:
        ValueError: If *m* is not positive or *a* has no inverse modulo *m*.
    """
    if m <= 0:
        raise ValueError("modulus must be positive")
    # Built-in modular inverse (Python 3.8+); raises ValueError when
    # gcd(a, m) != 1.
    return pow(a, -1, m)
def get_keypair(bits: int = 1024) -> tuple[PubKey, PrvKey]:
    """Generate an RSA public/private key pair.

    Args:
        bits: Bit size requested for EACH of the two primes, so the modulus
            is roughly ``2 * bits`` bits long. Must be at least 3 so that
            two distinct odd primes of that size exist.

    Returns:
        A ``(PubKey, PrvKey)`` tuple sharing the same modulus. The public
        exponent is always 65537.

    Raises:
        ValueError: If *bits* is less than 3.
    """
    if bits < 3:
        raise ValueError("bits must be at least 3")

    e = 65537  # Common choice for e
    while True:
        p = __generate_prime(bits)
        q = __generate_prime(bits)
        if p == q:
            continue
        phi_n = (p - 1) * (q - 1)
        # Draw fresh primes rather than switching to a random exponent when
        # e is not invertible modulo phi(n).
        if math.gcd(e, phi_n) == 1:
            break

    n = p * q
    d = __modinv(e, phi_n)
    return PubKey(e, n), PrvKey(d, n)
def encrypt(plaintext: bytes, pub: PubKey) -> bytes:
    """Encrypt *plaintext* byte-by-byte with the public key.

    Every plaintext byte becomes one fixed-width big-endian block whose
    length equals the byte length of the modulus, which is the block size
    ``decrypt`` reads back.

    NOTE: this is unpadded ("textbook") RSA applied per byte. It is
    deterministic, so equal bytes produce equal blocks; treat it as an
    educational scheme, not as secure encryption.

    Args:
        plaintext: Bytes to encrypt.
        pub: Public key providing exponent ``e`` and modulus ``n``.

    Returns:
        ``len(plaintext)`` concatenated blocks of ``ceil(n.bit_length() / 8)``
        bytes each.

    Raises:
        ValueError: If the modulus is too small to represent every byte value.
    """
    if pub.n < 256:
        raise ValueError("modulus must be greater than 255 to encrypt bytes")
    # Pad each block to full width so that small ciphertext values
    # (e.g. 0 or 1) keep the stream aligned.
    block_len = (pub.n.bit_length() + 7) // 8
    return b''.join(
        pow(byte, pub.e, pub.n).to_bytes(block_len, byteorder='big')
        for byte in plaintext
    )
def decrypt(ciphertext: bytes, prv: PrvKey) -> bytes:
    """Decrypt *ciphertext* made of fixed-width, one-byte-per-block RSA blocks.

    Args:
        ciphertext: Concatenation of big-endian blocks, each as long as the
            byte length of the modulus and each encoding one plaintext byte.
        prv: Private key providing exponent ``d`` and modulus ``n``.

    Returns:
        The recovered plaintext, one byte per block (NUL bytes included).

    Raises:
        ValueError: If the ciphertext length is not a multiple of the block
            size, or a block does not decrypt to a single byte value.
    """
    block_len = (prv.n.bit_length() + 7) // 8  # Size of each encrypted block
    if len(ciphertext) % block_len != 0:
        raise ValueError("ciphertext length is not a multiple of the block size")

    decrypted = bytearray()
    for start in range(0, len(ciphertext), block_len):
        enc_block = int.from_bytes(
            ciphertext[start:start + block_len], byteorder='big'
        )
        dec_byte = pow(enc_block, prv.d, prv.n)
        if dec_byte > 255:
            raise ValueError("block does not decrypt to a single byte")
        # Append the value itself instead of stripping zero padding, so a
        # plaintext byte of 0 is preserved.
        decrypted.append(dec_byte)
    return bytes(decrypted)
def _demo_exchange(
    sender: str,
    receiver: str,
    sender_pronoun: str,
    receiver_pronoun: str,
    receiver_pub: PubKey,
    receiver_prv: PrvKey,
) -> None:
    """Prompt *sender* for a message, encrypt it for *receiver*, decrypt it, and print each step."""
    text = input(f"\n{sender}, type your message for {receiver}:\n>> ").strip()
    plaintext_bytes = text.encode('utf-8')
    enc_msg = encrypt(plaintext_bytes, receiver_pub)
    print(
        f"\n{sender} encrypted {sender_pronoun} message using "
        f"{receiver}'s public key: {enc_msg.hex()}"
    )
    dec_msg = decrypt(enc_msg, receiver_prv)
    print(
        f"\n{receiver} decrypted the message using "
        f"{receiver_pronoun} private key: {dec_msg.decode('utf-8')}"
    )


def main() -> None:
    """Run the interactive Alice/Bob demo: generate both key pairs, then exchange one message each way."""
    # Generate key pairs for Alice and Bob
    print("Generating keys for Alice...")
    alice_pub, alice_prv = get_keypair()
    print(f"Alice's public key: {alice_pub}")
    print(f"Alice's private key: {alice_prv}")
    print("\nGenerating keys for Bob...")
    bob_pub, bob_prv = get_keypair()
    print(f"Bob's public key: {bob_pub}")
    print(f"Bob's private key: {bob_prv}")

    # Alice sends a message to Bob
    _demo_exchange("Alice", "Bob", "her", "his", bob_pub, bob_prv)
    # Bob sends a message to Alice
    _demo_exchange("Bob", "Alice", "his", "her", alice_pub, alice_prv)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment