Created
June 29, 2025 20:42
-
-
Save jakiki6/ddb41194c289261a4cbd153a586213ca to your computer and use it in GitHub Desktop.
A way to compress Kyber/ML-KEM public keys. Each coefficient is less than 3329 but is stored in 12 bits (2^12 = 4096), which wastes space. Encoding all coefficients together as a single base-3329 integer and storing that integer instead shrinks the coefficient portion by about 2.5% (log2(3329) ≈ 11.70 < 12). The whole public key shrinks by about 2.4% (800 → 781, 1184 → 1156 and 1568 → 1530 bytes for ML-KEM-512/768/1024).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def compress_pub(pub): | |
match len(pub): | |
case 800: | |
# 512 | |
count = 512 | |
size = 749 | |
case 1184: | |
# 768 | |
count = 768 | |
size = 1124 | |
case 1568: | |
# 1024 | |
count = 1024 | |
size = 1498 | |
case _: | |
raise ValueError(f"Unknown instance with public key size of {len(pub)} bytes") | |
c = int.from_bytes(pub[:count * 3], "little") | |
coff = 0 | |
for i in range(0, count): | |
coff *= 3329 | |
coff += (c & 0b111111111111) | |
c >>= 12 | |
return coff.to_bytes(size, "little") + pub[count + (count >> 1):] | |
def decompress_pub(pub): | |
match len(pub): | |
case 781: | |
# 512 | |
count = 512 | |
size = 749 | |
case 1156: | |
# 768 | |
count = 768 | |
size = 1124 | |
case 1530: | |
# 1024 | |
count = 1024 | |
size = 1498 | |
case _: | |
raise ValueError(f"Unknown instance with public key size of {len(pub)} bytes") | |
c = int.from_bytes(pub[:size], "little") | |
coff = 0 | |
for i in range(0, count): | |
coff <<= 12 | |
coff |= c % 3329 | |
c //= 3329 | |
return coff.to_bytes(count + (count >> 1), "little") + pub[size:] | |
if __name__ == "__main__":
    # Randomized round-trip self-test. It needs the third-party `kyber_py`
    # package, runs until interrupted, and prints any generated public key
    # that does not survive compress -> decompress byte-for-byte.
    import random
    from kyber_py.ml_kem import ML_KEM_512, ML_KEM_768, ML_KEM_1024

    schemes = [ML_KEM_512, ML_KEM_768, ML_KEM_1024]
    while True:
        kem = random.choice(schemes)
        original, _secret = kem.keygen()
        restored = decompress_pub(compress_pub(original))
        if restored == original:
            continue
        print(f"Failure:\n{original.hex()}\n{restored.hex()}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment