Skip to content

Instantly share code, notes, and snippets.

@ppoffice
Last active December 19, 2023 18:13
Show Gist options
  • Save ppoffice/e10e0a418d5dafdd5efe9495e962d3d2 to your computer and use it in GitHub Desktop.
Save ppoffice/e10e0a418d5dafdd5efe9495e962d3d2 to your computer and use it in GitHub Desktop.
Textbook/RAW RSA & RSA with OAEP+SHA1+MGF1 Python Implementation
from typing import Tuple
import pyasn1.codec.der.encoder
import pyasn1.type.univ
import base64
import rsa
def _der_to_pem(label: str, values: list) -> str:
    '''DER-encode `values` as an ASN.1 SEQUENCE of INTEGERs and wrap the
    result in a PEM envelope "-----BEGIN <label>-----" / "-----END <label>-----".'''
    seq = pyasn1.type.univ.Sequence()
    for x in values:
        seq.setComponentByPosition(len(seq), pyasn1.type.univ.Integer(x))
    der = pyasn1.codec.der.encoder.encode(seq)
    # base64.encodestring was removed in Python 3.9; encode in one piece and
    # wrap at 64 columns as RFC 7468 specifies for PEM bodies.
    b64 = base64.b64encode(der).decode('ascii')
    body = ''.join(b64[i:i + 64] + '\n' for i in range(0, len(b64), 64))
    return '-----BEGIN {0}-----\n{1}-----END {0}-----\n'.format(label, body)


def private_key_pem(n: int, e: int, d: int, p: int, q: int, dP: int, dQ: int, qInv: int) -> str:
    '''Create a private key PEM file.

    The DER payload is the integer sequence
    (version=0, n, e, d, p, q, dP, dQ, qInv) under an "RSA PRIVATE KEY" label.
    https://0day.work/how-i-recovered-your-private-key-or-why-small-keys-are-bad/'''
    return _der_to_pem('RSA PRIVATE KEY', [0, n, e, d, p, q, dP, dQ, qInv])


def public_key_pem(n: int, e: int) -> str:
    '''Create a public key PEM file: the integer sequence (n, e) under an
    "RSA PUBLIC KEY" label.'''
    return _der_to_pem('RSA PUBLIC KEY', [n, e])
def create_key_pair(p: int, q: int, e: int = None) -> Tuple[str, str]:
    '''Build an RSA key pair from the primes p and q and return it as PEM text.

    Returns (public_pem, private_pem). When e is None, rsa.keygen chooses the
    public exponent. The CRT values dP = e^-1 mod (p-1), dQ = e^-1 mod (q-1)
    and qInv = q^-1 mod p are derived here for the private-key PEM.
    '''
    (e, n), (d, _) = rsa.keygen(p, q, e)
    crt_params = (
        rsa.modinv(e, p - 1),
        rsa.modinv(e, q - 1),
        rsa.modinv(q, p),
    )
    public_pem = public_key_pem(n, e)
    private_pem = private_key_pem(n, e, d, p, q, *crt_params)
    return public_pem, private_pem
import base64
import rsa
import asn1
if __name__ == '__main__':
    # Demo: build a key pair from fixed 512-bit primes, then round-trip a
    # message through RSA-OAEP and through unpadded ("raw") RSA.
    e = 0x010001
    p = 0xCAA8F25E146F81FB0C31FB9FC98C5A4EDB25829EAA97B1B07C0761FE4E185D9EB886A8EC478A4BCCBF43A2AB3300A972074B1BACF1BEB731C1C096F9573A02D9
    q = 0xB0A1DD2EAB28ED07BE20658BF6D0FAA0CF395352746AB256A251F95AAA558E0C575866719821815A64F4DE0BE62E89D2F5E99805AFB1596C2755CCB96C4D9C63
    message = 'hello world!'.encode('ascii')
    pub_key, prv_key = rsa.keygen(p, q, e)
    pub_pem, prv_pem = asn1.create_key_pair(p, q, e)
    print('Public key is:')
    print(pub_pem)
    print('Private key is:')
    print(prv_pem)
    cipher_text = rsa.encrypt_oaep(message, pub_key)
    print('RSA-OAEP Encrypted text is:')
    # base64.encodestring was removed in Python 3.9; encodebytes replaces it.
    print(base64.encodebytes(cipher_text).decode('ascii'))
    print('RSA-OAEP Decrypted text is:')
    plain_text = rsa.decrypt_oaep(cipher_text, prv_key)
    print(plain_text.decode('ascii'))
    print()
    cipher_text = rsa.encrypt_raw(message, pub_key)
    print('RSA-RAW Encrypted text is:')
    print(base64.encodebytes(cipher_text).decode('ascii'))
    print('RSA-RAW Decrypted text is:')
    plain_text = rsa.decrypt_raw(cipher_text, prv_key)
    # decrypt_raw returns a full key-length block left-padded with zero
    # octets; strip them so the printed text is just the message.
    print(plain_text.lstrip(b'\x00').decode('ascii'))
from math import sqrt, ceil
import os
import copy
import hashlib
import random
from typing import Tuple, Callable
# An RSA key as an (exponent, modulus) pair: keygen() returns (e, n) for the
# public key and (d, n) for the private key.
Key = Tuple[int, int]
def euclid(a: int, b: int) -> int:
    '''Return the greatest common divisor of a and b via Euclid's algorithm.'''
    x, y = a, b
    while y:
        x, y = y, x % y
    return x
def extend_euclid(a: int, b: int) -> Tuple[int, int, int]:
    '''Use Euclid's extended algorithm to calculate integers x, y satisfying
    a * x + b * y = euclid(a, b).

    Returns the triple (x, y, g) where g = euclid(a, b).

    Implemented iteratively: the former recursive version recursed once per
    Euclid step, which exceeds Python's default recursion limit for large
    operands (e.g. a random 2048-bit exponent). The iteration runs the same
    quotient/remainder sequence, so it yields the same (x, y, g).
    '''
    old_r, r = a, b
    old_x, x = 1, 0
    old_y, y = 0, 1
    while r != 0:
        quotient = old_r // r
        old_r, r = r, old_r - quotient * r
        old_x, x = x, old_x - quotient * x
        old_y, y = y, old_y - quotient * y
    return old_x, old_y, old_r
def modinv(a: int, b: int) -> int:
    '''Return the modular inverse of a modulo b, or None if gcd(a, b) != 1.

    Solves a * x + b * k = 1 (e.g. d * e + k * phi = 1 when deriving the
    private exponent) and returns x reduced modulo b.
    '''
    x, _, g = extend_euclid(a, b)
    return x % b if g == 1 else None
def is_prime_trial_division(n: int) -> bool:
    '''Test if a given integer n is a prime number using trial division.

    Odd divisors are tried up to and including floor(sqrt(n)). The bound is
    checked with exact integer arithmetic (i * i <= n). The previous
    range(3, ceil(sqrt(n)), 2) excluded sqrt(n) itself, so squares of odd
    primes (9, 25, 49, ...) were wrongly reported as prime.
    '''
    if n == 2:
        return True
    if n < 2 or n % 2 == 0:
        return False
    i = 3
    while i * i <= n:
        if n % i == 0:
            return False
        i += 2
    return True
# Numbers below 1000 that is_prime_trial_division() classifies as prime; used
# for fast membership checks and as Miller-Rabin witnesses.
known_primes = [x for x in range(2, 1000) if is_prime_trial_division(x)]
def is_prime_miller_rabin(n: int, precision: int) -> bool:
    '''Test if a given integer n is a prime number using the Miller-Rabin test.

    Candidates are first trial-divided by `known_primes`. This settles small
    n directly and guarantees that every witness used afterwards is smaller
    than n. The witness sets below 341550071728321 follow the reference
    linked below. Above that bound, the first `precision` entries of
    `known_primes` are used.
    https://rosettacode.org/wiki/Miller%E2%80%93Rabin_primality_test#Python:_Probably_correct_answers
    '''
    def try_composite(a, d, s):
        # True if `a` witnesses that n is composite, i.e. neither a^d == 1
        # nor a^(2^i * d) == n - 1 for any 0 <= i < s (all mod n).
        x = pow(a, d, n)
        if x == 1 or x == n - 1:
            return False
        for _ in range(s - 1):
            x = x * x % n  # a^(2^i * d) from the previous power by squaring
            if x == n - 1:
                return False
        return True

    if n < 2:
        return False
    # Small-prime trial division. Without it, n = 1 looped forever below,
    # n = 2 was rejected as even, and a witness a >= n (e.g. n = a = 3)
    # produced a wrong "composite" verdict.
    for p in known_primes:
        if p * p > n:
            return True
        if n % p == 0:
            return n == p
    d, s = n - 1, 0
    while d % 2 == 0:
        d, s = d >> 1, s + 1
    # Returns exact according to http://primes.utm.edu/prove/prove2_3.html
    # (assumes known_primes begins 2, 3, 5, 7, 11, 13, 17).
    if n < 1373653:
        return not any(try_composite(a, d, s) for a in known_primes[:2])
    if n < 25326001:
        return not any(try_composite(a, d, s) for a in known_primes[:3])
    if n < 118670087467:
        if n == 3215031751:
            return False
        return not any(try_composite(a, d, s) for a in known_primes[:4])
    if n < 2152302898747:
        return not any(try_composite(a, d, s) for a in known_primes[:5])
    if n < 3474749660383:
        return not any(try_composite(a, d, s) for a in known_primes[:6])
    if n < 341550071728321:
        return not any(try_composite(a, d, s) for a in known_primes[:7])
    return not any(try_composite(a, d, s) for a in known_primes[:precision])
def is_prime(n: int, precision: int = 16) -> bool:
    '''Return True if the positive integer n is prime.

    Members of known_primes are accepted by lookup, other values below 100000
    go through trial division, and larger values go to Miller-Rabin with
    `precision` passed through.
    '''
    assert n > 0
    if n in known_primes:
        return True
    if n < 100000:
        return is_prime_trial_division(n)
    return is_prime_miller_rabin(n, precision)
def keygen(p: int, q: int, e: int = None) -> Tuple[Key, Key]:
    '''Create public key (exponent e, modulus n) and private key
    (exponent d, modulus n) from two distinct primes p and q.

    If e is None, a random exponent coprime to phi = (p - 1) * (q - 1) is
    drawn from [3, phi). e = 1 is excluded because it makes encryption the
    identity map, and e = 2 is never coprime to the even phi.

    Asserts that p and q are prime and distinct, and that a caller-supplied
    e is coprime to phi. Raises ValueError if phi is too small for any
    exponent e > 1 to exist.
    '''
    assert is_prime(p) and is_prime(q)
    assert p != q
    n = p * q
    phi = (p - 1) * (q - 1)
    if e is not None:
        assert euclid(phi, e) == 1
    else:
        if phi <= 3:
            raise ValueError('phi is too small to choose a public exponent e > 1')
        while True:
            e = random.randrange(3, phi)
            if euclid(e, phi) == 1:
                break
    d = modinv(e, phi)
    return ((e, n), (d, n))
def get_key_len(key: Key) -> int:
    '''Get the number of octets k of the public/private key modulus n.

    k = ceil(bit_length(n) / 8) (RFC 8017, section 2). Rounding down made k
    one octet too small whenever the bit length is not a multiple of 8, so
    i2osp(c, k) could overflow for ciphertexts close to n.
    '''
    _, n = key
    return (n.bit_length() + 7) // 8
def os2ip(x: bytes) -> int:
    '''OS2IP: read an octet string as a big-endian nonnegative integer.'''
    return int.from_bytes(x, 'big')
def i2osp(x: int, xlen: int) -> bytes:
    '''I2OSP: write a nonnegative integer as a big-endian octet string of
    exactly xlen octets (OverflowError if it does not fit).'''
    return int.to_bytes(x, xlen, 'big')
def sha1(m: bytes) -> bytes:
    '''Return the 20-octet SHA-1 digest of m.'''
    return hashlib.sha1(m).digest()
def mgf1(seed: bytes, mlen: int, f_hash: Callable = sha1) -> bytes:
    '''MGF1 mask generation function (SHA-1 unless f_hash is given).

    Concatenates f_hash(seed || C) for 4-octet big-endian counters
    C = 0, 1, ... until at least mlen octets exist, then truncates to mlen.
    '''
    hlen = len(f_hash(b''))
    blocks = [f_hash(seed + i2osp(counter, 4))
              for counter in range(ceil(mlen / hlen))]
    return b''.join(blocks)[:mlen]
def xor(data: bytes, mask: bytes) -> bytes:
    '''Byte-by-byte XOR of two byte arrays.

    The result always has len(data) octets. Positions covered by `mask` are
    XORed, any octets of `data` beyond the end of `mask` are copied unchanged,
    and surplus `mask` octets are ignored. Builds the result in one pass
    instead of the former quadratic per-byte `bytes +=` concatenation.
    '''
    overlap = min(len(data), len(mask))
    head = bytes(a ^ b for a, b in zip(data, mask))
    return head + bytes(data[overlap:])
def oaep_encode(m: bytes, k: int, label: bytes = b'',
                f_hash: Callable = sha1, f_mgf: Callable = mgf1) -> bytes:
    '''EME-OAEP encoding.

    Builds EM = 0x00 || maskedSeed || maskedDB of k octets, where
    DB = lHash || PS || 0x01 || m and PS is the zero padding that fills DB
    to k - hLen - 1 octets. The seed comes from os.urandom.

    Raises ValueError('message too long') when len(m) > k - 2 * hLen - 2.
    Previously PS silently became empty and EM came out longer than k octets.
    '''
    lhash = f_hash(label)
    hlen = len(lhash)
    ps_len = k - len(m) - 2 * hlen - 2
    if ps_len < 0:
        raise ValueError('message too long')
    db = lhash + b'\x00' * ps_len + b'\x01' + m
    seed = os.urandom(hlen)
    masked_db = xor(db, f_mgf(seed, k - hlen - 1, f_hash))
    masked_seed = xor(seed, f_mgf(masked_db, hlen, f_hash))
    return b'\x00' + masked_seed + masked_db
def oaep_decode(c: bytes, k: int, label: bytes = b'',
                f_hash: Callable = sha1, f_mgf: Callable = mgf1) -> bytes:
    '''EME-OAEP decoding of a k-octet encoded message c.

    Reverses oaep_encode and returns the message m.

    Raises ValueError('decryption error') for every malformed input:
    - wrong length;
    - non-zero leading octet;
    - label hash mismatch;
    - missing 0x01 separator.
    All failures share one message so callers cannot tell them apart. This
    pure-Python code is not constant-time, though.
    '''
    lhash = f_hash(label)
    hlen = len(lhash)
    if len(c) != k or k < 2 * hlen + 2:
        raise ValueError('decryption error')
    y, masked_seed, masked_db = c[0], c[1:1 + hlen], c[1 + hlen:]
    seed_mask = f_mgf(masked_db, hlen, f_hash)
    seed = xor(masked_seed, seed_mask)
    db_mask = f_mgf(seed, k - hlen - 1, f_hash)
    db = xor(masked_db, db_mask)
    _lhash, rest = db[:hlen], db[hlen:]
    # rest = PS (zero octets) || 0x01 || m; find where the zeros end.
    sep = len(rest) - len(rest.lstrip(b'\x00'))
    if y != 0 or _lhash != lhash or sep == len(rest) or rest[sep] != 1:
        raise ValueError('decryption error')
    return rest[sep + 1:]
def encrypt(m: int, public_key: Key) -> int:
    '''Encrypt an integer using an RSA public key (RSAEP: c = m^e mod n).

    Raises ValueError if m is outside [0, n). Such a value would otherwise
    be silently reduced modulo n and could never be recovered by decrypt().
    '''
    e, n = public_key
    if not 0 <= m < n:
        raise ValueError('message representative out of range')
    return pow(m, e, n)
def encrypt_raw(m: bytes, public_key: Key) -> bytes:
    '''Encrypt a byte array without padding (textbook RSA).

    m is read as a big-endian integer, and the ciphertext is returned as a
    big-endian octet string of get_key_len(public_key) octets.
    '''
    k = get_key_len(public_key)
    return i2osp(encrypt(os2ip(m), public_key), k)
def encrypt_oaep(m: bytes, public_key: Key) -> bytes:
    '''Encrypt a byte array with OAEP padding (SHA-1, MGF1, empty label).

    Raises ValueError('message too long') when len(m) > k - 2 * hLen - 2,
    the RFC 8017 limit. The previous bound k - hLen - 2 admitted messages
    too long for OAEP to encode correctly.
    '''
    hlen = len(sha1(b''))  # SHA-1 hash length (20), the hash oaep_encode uses by default
    k = get_key_len(public_key)
    if len(m) > k - 2 * hlen - 2:
        raise ValueError('message too long')
    return encrypt_raw(oaep_encode(m, k), public_key)
def decrypt(c: int, private_key: Key) -> int:
    '''Decrypt an integer using an RSA private key (RSADP: m = c^d mod n).

    Raises ValueError if c is outside [0, n); no valid ciphertext can be.
    '''
    d, n = private_key
    if not 0 <= c < n:
        raise ValueError('ciphertext representative out of range')
    return pow(c, d, n)
def decrypt_raw(c: bytes, private_key: Key) -> bytes:
    '''Decrypt a cipher byte array without removing any padding.

    The recovered integer is returned as a big-endian block of
    get_key_len(private_key) octets, so short plaintexts come back
    left-padded with zero octets.
    '''
    k = get_key_len(private_key)
    plain_int = decrypt(os2ip(c), private_key)
    return i2osp(plain_int, k)
def decrypt_oaep(c: bytes, private_key: Key) -> bytes:
    '''Decrypt a cipher byte array with OAEP padding (SHA-1, MGF1, empty label).

    Raises ValueError('decryption error') if c is not exactly k octets or
    the modulus is too small for OAEP. These checks were previously asserts,
    which are stripped under `python -O`.
    '''
    k = get_key_len(private_key)
    hlen = len(sha1(b''))  # SHA-1 hash length (20), the hash oaep_decode uses by default
    if len(c) != k or k < 2 * hlen + 2:
        raise ValueError('decryption error')
    return oaep_decode(decrypt_raw(c, private_key), k)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment