Skip to content

Instantly share code, notes, and snippets.

@maple3142
Created November 24, 2024 08:31
Show Gist options
  • Save maple3142/0d8975dd53a28f6e9ce2e13e355041f1 to your computer and use it in GitHub Desktop.
Save maple3142/0d8975dd53a28f6e9ce2e13e355041f1 to your computer and use it in GitHub Desktop.
SECCON CTF 13 (2024) Quals - seqr
from sage.all import *
from lll_cvp import *
import os
import signal
import time
import random
from secrets import randbelow
from pwn import process, remote
from hashlib import sha1
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad
from gmpy2 import legendre
from fastecdsa.curve import secp256k1
from fastecdsa.keys import gen_keypair
from fastecdsa.point import Point
from tqdm import tqdm
from functools import partial
from concurrent.futures import ProcessPoolExecutor
class PRNG:
    """Bit generator driven by the Legendre symbol of a running counter.

    Legendre PRF is believed to be secure
    ex. https://link.springer.com/chapter/10.1007/0-387-34799-2_13
    """

    def __init__(self, initial_state: int, p: int) -> None:
        # _state is the counter fed to legendre(); p is the modulus it is
        # reduced by after every step.
        self._state = initial_state
        self.p = p

    def __call__(self, n_bit: int) -> int:
        """Return the next ``n_bit`` output bits packed into one integer.

        The first bit produced ends up as the most significant bit of the
        result. A symbol of 0 is emitted as bit 1; otherwise the symbol
        (presumably -1 or 1) is mapped through ``(1 + symbol) // 2``.
        """
        out = 0
        for _ in range(n_bit):
            symbol = legendre(self._state, self.p)
            bit = 1 if symbol == 0 else (1 + symbol) // 2
            out = (out << 1) | bit
            self._state = (self._state + 1) % self.p
        return out
# The 43 odd primes from 3 through 193. Their product p is what this script
# sends (in hex) in answer to the server's first prompt.
# fmt: off
ps = [3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193]
# fmt: on
p = prod(ps)
# Period of the predictable bit pattern used by the attack: the product of the
# three smallest factors of p.
reclen = 3 * 5 * 7
# known_seq[i] is 1 exactly when i shares a factor with reclen, else 0.
# compute_ks_known_bits treats a 1 here as "this output bit is known to be 1"
# (presumably because the symbol is 0 for such states, which PRNG maps to 1).
known_seq = [int(gcd(i, reclen) > 1) for i in range(0, reclen)]
def compute_ks_known_bits(n, a_mod_rec):
    """Describe which bits of ``n`` consecutive 256-bit outputs are known.

    Walks the periodic ``known_seq`` table starting at index ``a_mod_rec``
    (wrapping to 0 on reaching ``reclen``) and consumes 256 entries per
    output. Each output becomes a 256-character string in which "1" marks a
    bit known to be 1 and "?" marks an unknown bit.

    The characters are stored in reverse visiting order, so index 0 of each
    string corresponds to the last entry visited (the least significant bit).
    """
    rows = []
    pos = a_mod_rec
    for _ in range(n):
        marks = []
        for _ in range(256):
            marks.append("1" if known_seq[pos] == 1 else "?")
            pos += 1
            if pos == reclen:
                pos = 0
        rows.append("".join(marks[::-1]))  # lsb first
    return rows
def compute_ks_info(ks_known_bits, min_chunk_size):
    """Extract the runs of known bits from each bit-pattern string.

    Args:
        ks_known_bits: iterable of strings in which "?" marks an unknown
            position; a run starts at a "1" and extends up to (not
            including) the next "?" or the end of the string.
        min_chunk_size: runs shorter than this are dropped.

    Returns:
        One list per input string, holding ``(start, length, value)`` tuples
        in order of appearance, where ``value`` is the run parsed as a
        base-2 integer. A string without any "1" (including the empty
        string) yields an empty list instead of raising ``ValueError``.
    """
    ks_info = []
    for ksb in ks_known_bits:
        info = []
        start = ksb.find("1")
        while start != -1:
            end = ksb.find("?", start)
            if end == -1:
                # The run reaches the end of the string.
                end = len(ksb)
            length = end - start
            if length >= min_chunk_size:
                info.append((start, length, int(ksb[start:end], 2)))
            # Resume at the next known bit after the "?" (or stop at the end).
            start = ksb.find("1", end)
        ks_info.append(info)
    return ks_info
# q as exposed by fastecdsa's secp256k1 curve object (the group order);
# the signature equations in attack() are built over Zmod(q).
q = secp256k1.q
# Polynomial ring over Zmod(q) with 400 generators named with prefix "w".
# attack() pops its unknowns from this generator list, so at most 400
# unknowns are available per call.
PR = PolynomialRing(Zmod(q), "w", 400)
def attack(ms, rs, ss, guess_a_mod_rec):
    """Try to solve for the signing key under one guess of the stream offset.

    Args:
        ms, rs, ss: equally long sequences of message hashes and signature
            components, one entry per signature.
        guess_a_mod_rec: guessed starting index into the periodic known-bit
            table, passed to ``compute_ks_known_bits``.

    Returns:
        ``int(bound_sol[0])`` when the lattice solution respects every bound
        (the caller treats this as a candidate private key and verifies it
        against the public key), otherwise ``None``.
    """
    known_bits = compute_ks_known_bits(len(ms), guess_a_mod_rec)
    # keep only prefix and suffix to keep the problem in HNP (not EHNP)
    edge_chunks = [
        [(pos, size, val) for pos, size, val in row if pos == 0 or pos + size == 256]
        for row in compute_ks_info(known_bits, 1)
    ]
    # and keep only the signatures with at least 6 known bits
    usable = [
        (m, r, s, chunks)
        for m, r, s, chunks in zip(ms, rs, ss, edge_chunks, strict=True)
        if sum(size for _, size, _ in chunks) >= 6
    ]

    # Express every nonce as known chunks plus one fresh bounded symbol per
    # unknown gap; symbols are taken from the end of the generator list.
    free_syms = list(PR.gens())
    bounds = {}
    nonce_polys = []
    for *_, chunks in usable:
        nonce = 0
        filled = 0  # number of low bits already accounted for
        for pos, size, val in chunks:
            if pos > filled:
                gap = free_syms.pop()
                bounds[gap] = 2 ** (pos - filled)
                nonce += gap * 2**filled
                filled = pos
            nonce += val * 2**pos
            filled += size
        if filled < 256:
            top = free_syms.pop()
            bounds[top] = 2 ** (256 - filled)
            nonce += top * 2**filled
        nonce_polys.append(nonce)

    # One more symbol for the private key; each signature gives s*k = m + d*r.
    d = free_syms.pop()
    eqs = [
        s * nonce - (m + d * r)
        for (m, r, s, _), nonce in zip(usable, nonce_polys)
    ]
    bounds[d] = q

    M, monos = polynomials_to_matrix(eqs)
    assert monos[-1] == 1
    unknowns = monos[:-1]
    A = M[:, :-1].dense_matrix()
    b = -vector(M[:, -1])
    lb = [0] * len(unknowns)
    ub = [bounds[mono] for mono in unknowns]
    # now we want to find x that A*x=b and lb<=x<=ub
    # first construct the solution space:
    s0 = A.solve_right(b)
    ker = A.right_kernel_matrix().echelon_form()
    # find lb<=s0+?*ker<=ub (mod q)
    # -> lb-s0<=?*ker<=ub-s0 (mod q)
    # HNP-like lattice
    # [1,?]
    # [0,q]
    L = ker.change_ring(ZZ).stack(matrix.identity(ker.ncols())[ker.nrows() :] * q)
    s0_int = s0.change_ring(ZZ)
    lbx = (vector(lb) - s0_int).list()
    ubx = (vector(ub) - s0_int).list()
    sol = solve_inequality(
        L,
        lbx,
        ubx,
        cvp=partial(kannan_cvp, reduction=LLL),
    )
    if not all(lo <= x <= hi for lo, x, hi in zip(lbx, sol, ubx)):
        return None
    bound_sol = s0 + ZZ(sol[0]) * ker[0]
    assert A * bound_sol == b
    return int(bound_sol[0])
# Connect to the challenge service. The commented-out line below runs a local
# server script instead.
# io = process(["python3", "server2.py"])
io = remote("seqr.seccon.games", 13337)
# Answer the first prompt with the composite modulus p, written in hex.
io.sendlineafter(b">", f"{p:x}".encode())
# Menu option 2: the reply contains "pubkey = <hex>".
io.sendlineafter(b">", b"2")
io.recvuntil(b"pubkey = ")
pk_bytes = bytes.fromhex(io.recvline().strip().decode())
# The first 32 bytes are read as the x coordinate and the rest as y,
# both big-endian.
Px = int.from_bytes(pk_bytes[:32], "big")
Py = int.from_bytes(pk_bytes[32:], "big")
pk = Point(Px, Py, curve=secp256k1)
print("pk")
print(pk)
# Menu option 3: the reply contains "enc = <hex>", the ciphertext that is
# decrypted at the end of the script.
io.sendlineafter(b">", b"3")
io.recvuntil(b"enc = ")
flag_enc = bytes.fromhex(io.recvline().strip().decode())
def sign(msg: bytes):
    """Request one signature for ``msg`` and return ``(z, r, s)``.

    Sends menu option 1 followed by the hex-encoded message, waiting for a
    prompt before each line, then reads the hex signature. ``z`` is the
    SHA-1 digest of ``msg`` as a big-endian integer; ``r`` and ``s`` are the
    first 32 bytes and the remainder of the signature, also big-endian.
    """
    io.sendlineafter(b">", b"1")
    io.sendlineafter(b">", msg.hex().encode())
    io.recvuntil(b"signature = ")
    raw = bytes.fromhex(io.recvline().strip().decode())
    digest = sha1(msg).digest()
    return (
        int.from_bytes(digest, "big"),
        int.from_bytes(raw[:32], "big"),
        int.from_bytes(raw[32:], "big"),
    )
def sign_batch(msgs: list[bytes]):
    """Yield ``(z, r, s)`` for every message, pipelining the requests.

    All requests are written first without waiting for prompts; the replies
    are then read back in the same order. Being a generator, nothing is sent
    until the first item is requested. ``msgs`` is iterated twice.
    """
    # Phase 1: queue every signing request.
    for queued in msgs:
        io.sendline(b"1")
        io.sendline(queued.hex().encode())
    # Phase 2: collect the replies, pairing each with its message hash.
    for answered in msgs:
        io.recvuntil(b"signature = ")
        raw = bytes.fromhex(io.recvline().strip().decode())
        r_bytes, s_bytes = raw[:32], raw[32:]
        yield (
            int.from_bytes(sha1(answered).digest(), "big"),
            int.from_bytes(r_bytes, "big"),
            int.from_bytes(s_bytes, "big"),
        )
# Number of signatures to collect; the messages are the decimal strings
# "0" .. str(n - 1).
n = 3000
msgs = [b"%d" % index for index in range(n)]
# Pipelined collection; the one-request-at-a-time alternative would be
# [sign(msg) for msg in tqdm(msgs)].
sigs = list(tqdm(sign_batch(msgs)))
# Split the (z, r, s) triples into three parallel lists.
ms, rs, ss = ([sig[column] for sig in sigs] for column in range(3))
print("got signatures")
def get_mod(bits, p):
    """Find the residue class modulo ``p`` whose positions are all 1.

    Tries offsets 0 .. p-1 in order and picks the first offset for which
    every element of ``bits`` at positions offset, offset + p, offset + 2p,
    ... equals 1 (an offset with no positions inside ``bits`` counts as a
    match). Returns ``-offset % p`` for that offset, or -1 when no offset
    qualifies.
    """
    total = len(bits)
    for offset in range(p):
        if all(bits[pos] == 1 for pos in range(offset, total, p)):
            return -offset % p
    return -1
# Run attack() once per candidate offset in parallel and stop at the first
# candidate key that reproduces the server's public key.
with ProcessPoolExecutor(max_workers=8) as executor:
    start = time.time()
    futures = [
        (i, executor.submit(attack, ms, rs, ss, i)) for i in range(0, reclen)
    ]
    for i, future in futures:
        sk = future.result()
        print("brute", i, sk)
        print("current", time.time() - start)
        if sk and sk * secp256k1.G == pk:
            print("FOUND SK !!")
            break
    else:
        # Every guess was tried and none matched; do not carry the last
        # (wrong or None) candidate into the decryption below.
        sk = None
    # Ask the pool to drop guesses queued after the one that succeeded.
    for _, future in futures[i + 1 :]:
        future.cancel()
if sk is None:
    raise RuntimeError("no guess of the PRNG offset produced the private key")
# Recover every nonce from its signature: k=(z+rd)/s
ks = [(z + r * sk) * pow(s, -1, q) % q for z, r, s in sigs]
# Concatenate the nonces, 256 bits each and most significant bit first, into
# one flat bit list.
bits = list(map(int, "".join(f"{k:0256b}" for k in ks)))
# For each prime factor, locate the residue class whose bits are all 1.
rems = [get_mod(bits, prime) for prime in ps]
if -1 in rems:
    # get_mod signals "not found" with -1, which must not be fed to crt().
    raise RuntimeError("could not determine the PRNG state modulo every prime")
a_rec = int(crt(rems, ps))
print("a_rec", a_rec)
# The AES key is the private key XORed with the recovered state, as 32
# big-endian bytes; the ciphertext is decrypted in ECB mode.
key = (sk ^ a_rec).to_bytes(32, "big")
flag = AES.new(key, AES.MODE_ECB).decrypt(flag_enc)
print("flag", flag)
# SECCON{17_15_un4cc3p74bl3_7h47_l363ndr3_c4n_u53_c0mp05173_numb3r5}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment