Created November 24, 2024 08:31
-
-
Save maple3142/0d8975dd53a28f6e9ce2e13e355041f1 to your computer and use it in GitHub Desktop.
SECCON CTF 13 (2024) Quals - seqr
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sage.all import * | |
from lll_cvp import * | |
import os | |
import signal | |
import time | |
import random | |
from secrets import randbelow | |
from pwn import process, remote | |
from hashlib import sha1 | |
from Crypto.Cipher import AES | |
from Crypto.Util.Padding import pad | |
from gmpy2 import legendre | |
from fastecdsa.curve import secp256k1 | |
from fastecdsa.keys import gen_keypair | |
from fastecdsa.point import Point | |
from tqdm import tqdm | |
from functools import partial | |
from concurrent.futures import ProcessPoolExecutor | |
class PRNG:
    """Bit generator built on the Legendre symbol (a "Legendre PRF").

    The PRF is believed to be secure, see e.g.
    https://link.springer.com/chapter/10.1007/0-387-34799-2_13

    Each output bit comes from ``legendre(state, p)``: +1 gives bit 1, -1 gives
    bit 0, and a symbol of 0 is also emitted as bit 1.  After every bit the
    state advances by one modulo ``p``.
    """

    def __init__(self, initial_state: int, p: int) -> None:
        self._state = initial_state
        self.p = p

    def __call__(self, n_bit: int) -> int:
        """Draw ``n_bit`` bits and return them as an int, first bit most significant."""
        result = 0
        for _ in range(n_bit):
            symbol = legendre(self._state, self.p)
            if symbol == 0:
                bit = 1
            else:
                bit = (symbol + 1) // 2
            result = (result << 1) | bit
            self._state = (self._state + 1) % self.p
        return result
# Small odd primes; their product p is the modulus sent to the server below.
# fmt: off
ps = [3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193]
# fmt: on
p = prod(ps)
# Only the three smallest primes are used to predict bits: a PRNG state that is
# divisible by 3, 5 or 7 shares a factor with p, so (presumably) the symbol
# computed by PRNG is 0 and the emitted bit is forced to 1.  That pattern
# repeats with period 3*5*7.
reclen = 3 * 5 * 7
# known_seq[r] == 1 iff a state congruent to r (mod reclen) yields a forced 1 bit.
known_seq = [1 if gcd(residue, reclen) != 1 else 0 for residue in range(reclen)]
def compute_ks_known_bits(n, a_mod_rec):
    """Mark the nonce bits predicted to be 1 for ``n`` consecutive 256-bit draws.

    The first generated bit is assumed to come from a PRNG state congruent to
    ``a_mod_rec`` modulo ``reclen``; each following bit advances that residue
    by one (wrapping at ``reclen``), continuously across nonces.

    Returns one 256-character string per nonce, least-significant bit first:
    "1" where ``known_seq`` forces the bit, "?" where it is unknown.  The PRNG
    emits the most-significant bit first, hence the reversal.
    """
    pos = a_mod_rec
    rows = []
    for _ in range(n):
        msb_first = []
        for _ in range(256):
            msb_first.append("1" if known_seq[pos] == 1 else "?")
            pos += 1
            if pos == reclen:
                pos = 0
        rows.append("".join(reversed(msb_first)))
    return rows
def compute_ks_info(ks_known_bits, min_chunk_size):
    """Split each known-bit pattern into maximal runs of consecutive known bits.

    Args:
        ks_known_bits: one string per nonce, least-significant bit first (the
            layout produced by compute_ks_known_bits).  "?" marks an unknown
            bit; every other character must be a known bit value ("0" or "1"),
            otherwise ``int(..., 2)`` raises ValueError.
        min_chunk_size: runs shorter than this are dropped.

    Returns:
        One list per input string of ``(start, length, value)`` tuples in
        increasing ``start`` order, such that the nonce k satisfies
        ``(k >> start) & ((1 << length) - 1) == value``.  A string with no known
        bits (including the empty string) yields an empty list instead of
        raising.
    """
    ks_info = []
    for ksb in ks_known_bits:
        info = []
        size = len(ksb)
        start = 0
        while start < size:
            if ksb[start] == "?":
                start += 1
                continue
            end = ksb.find("?", start)
            if end == -1:
                end = size
            length = end - start
            if length >= min_chunk_size:
                # The string is LSB-first, but int(..., 2) treats its leftmost
                # digit as most significant, so reverse the run before parsing.
                info.append((start, length, int(ksb[start:end][::-1], 2)))
            start = end
        ks_info.append(info)
    return ks_info
# Order of the secp256k1 base point; attack() builds its ECDSA equations mod q.
q = secp256k1.q
# Multivariate polynomial ring over Z/qZ with a fixed pool of 400 generators.
# attack() takes one generator per unknown nonce chunk plus one for the private
# key d, so a single call can use at most 399 signatures before the pool runs out.
PR = PolynomialRing(Zmod(q), "w", 400)
def attack(ms, rs, ss, guess_a_mod_rec):
    """Try to recover the ECDSA private key for one guess of the PRNG offset.

    The guess ``guess_a_mod_rec`` fixes which nonce bits are forced to 1 (see
    compute_ks_known_bits).  Signatures with enough known bits at the low and/or
    high end of their nonce become equations ``s*k - (m + d*r) = 0 (mod q)``.
    In each equation the unknown part of k is a bounded symbol from ``PR``.  A
    bounded solution is then searched for with a lattice CVP
    (``solve_inequality`` + ``kannan_cvp`` from lll_cvp).

    Args:
        ms: message hashes z, one per signature.
        rs: ECDSA r values, aligned with ``ms``.
        ss: ECDSA s values, aligned with ``ms``.
        guess_a_mod_rec: guessed PRNG state modulo ``reclen`` at the first
            generated nonce bit.

    Returns:
        A candidate private key (int) when the CVP solution lies inside the
        bounds, otherwise None (implicit fall-through).  Callers must still
        verify the candidate against the public key.

    Raises:
        IndexError: if more than 399 signatures survive the filter, because the
            400 generators of ``PR`` run out.
        ValueError: from ``zip(..., strict=True)`` if ``rs``/``ss`` are not the
            same length as ``ms``.
        Errors from the Sage / lll_cvp calls propagate unchanged.
    """
    n = len(ms)
    ks_known_bits = compute_ks_known_bits(n, guess_a_mod_rec)
    ks_info = compute_ks_info(ks_known_bits, 1)
    # keep only prefix and suffix to keep the problem in HNP (not EHNP)
    # (i == 0: run of known low bits; i + l == 256: run of known high bits)
    ks_info = [
        [(i, l, v) for i, l, v in info if i == 0 or (i + l) == 256] for info in ks_info
    ]
    # and keep only the ones with at least 6 bits
    to_keep = [sum(l for _, l, _ in info) >= 6 for info in ks_info]
    ms = [m for m, keep in zip(ms, to_keep, strict=True) if keep]
    rs = [r for r, keep in zip(rs, to_keep, strict=True) if keep]
    ss = [s for s, keep in zip(ss, to_keep, strict=True) if keep]
    ks_info = [info for info, keep in zip(ks_info, to_keep, strict=True) if keep]
    syms_pool = list(PR.gens())
    # bounds[symbol] = exclusive upper bound of that unknown (2**bits, or q for d)
    bounds = {}
    ks_sym = []
    # Build k = known chunks + one symbol per gap.  With only a prefix and/or a
    # suffix kept, every surviving signature has exactly one gap, hence exactly
    # one symbol.
    for info in ks_info:
        k_sym = 0
        ctr = 0  # next bit position not yet covered
        for i, l, v in info:
            if i > ctr:
                # unknown bits [ctr, i) before this known run
                sm = syms_pool.pop()
                bounds[sm] = 2 ** (i - ctr)
                k_sym += sm * 2**ctr
                ctr = i
            k_sym += v * 2**i
            ctr += l
        if ctr < 256:
            # unknown high bits [ctr, 256) after the last known run
            sm = syms_pool.pop()
            bounds[sm] = 2 ** (256 - ctr)
            k_sym += sm * 2**ctr
        ks_sym.append(k_sym)
    # The private key gets the last generator popped (the lowest-index one used).
    d = syms_pool.pop()
    eqs = []
    for m, r, s, k_sym in zip(ms, rs, ss, ks_sym):
        # ECDSA: s*k = m + d*r (mod q)
        eq = s * k_sym - (m + d * r)
        eqs.append(eq)
    bounds[d] = q
    # Linear system: coefficient columns for each symbol, last column = constants.
    M, monos = polynomials_to_matrix(eqs)
    assert monos[-1] == 1
    A = M[:, :-1].dense_matrix()
    b = -vector(M[:, -1])
    lb = [0] * len(monos[:-1])
    ub = [bounds[m] for m in monos[:-1]]
    # now we want to find x that A*x=b and lb<=x<=ub
    # first construct the solution space:
    # One symbol per kept signature plus d means A has one more column than
    # rows, so (if A has full rank) the kernel below is one-dimensional.
    s0 = A.solve_right(b)
    ker = A.right_kernel_matrix().echelon_form()
    # find lb<=s0+?*ker<=ub (mod q)
    # -> lb-s0<=?*ker<=ub-s0 (mod q)
    # HNP-like lattice
    # [1,?]
    # [0,q]
    # NOTE(review): stacking q*e_j only for j >= ker.nrows() assumes the echelon
    # kernel has its pivots in the leading columns — confirm.
    L = ker.change_ring(ZZ).stack(matrix.identity(ker.ncols())[ker.nrows() :] * q)
    lbx = (vector(lb) - s0.change_ring(ZZ)).list()
    ubx = (vector(ub) - s0.change_ring(ZZ)).list()
    sol = solve_inequality(
        L,
        lbx,
        ubx,
        cvp=partial(kannan_cvp, reduction=LLL),
    )
    # `b` in this comprehension is its own loop variable (an upper bound); it
    # does not overwrite the right-hand-side vector `b` checked just below.
    if all([a <= x <= b for a, x, b in zip(lbx, sol, ubx)]):
        # sol[0] is used as the multiplier of the single kernel row —
        # presumably because the echelon kernel row starts with 1.
        bound_sol = s0 + ZZ(sol[0]) * ker[0]
        assert A * bound_sol == b
        # NOTE(review): coordinate 0 is taken to be d; this relies on
        # polynomials_to_matrix ordering d first in monos — verify.
        return int(bound_sol[0])
# io = process(["python3", "server2.py"])
# Connect to the challenge server (swap in the line above to test locally).
io = remote("seqr.seccon.games", 13337)
# First prompt: send p = prod(ps) in hex.  Presumably the server uses it as the
# PRNG modulus, which makes nonce bits predictable (see known_seq).
io.sendlineafter(b">", f"{p:x}".encode())
# Menu option 2: public key, hex-encoded; first 32 bytes are x, the rest is y.
io.sendlineafter(b">", b"2")
io.recvuntil(b"pubkey = ")
pk_bytes = bytes.fromhex(io.recvline().strip().decode())
Px = int.from_bytes(pk_bytes[:32], "big")
Py = int.from_bytes(pk_bytes[32:], "big")
pk = Point(Px, Py, curve=secp256k1)
print("pk")
print(pk)
# Menu option 3: encrypted flag, hex-encoded.
io.sendlineafter(b">", b"3")
io.recvuntil(b"enc = ")
flag_enc = bytes.fromhex(io.recvline().strip().decode())
def sign(msg: bytes):
    """Have the server sign ``msg`` (menu option 1) and return ``(z, r, s)``.

    ``z`` is SHA-1(msg) as a big-endian integer.  ``r`` is the first 32 bytes of
    the returned signature and ``s`` the remaining bytes, both big-endian.
    Waits for each ">" prompt, so this costs one round-trip per signature.
    """
    io.sendlineafter(b">", b"1")
    io.sendlineafter(b">", msg.hex().encode())
    io.recvuntil(b"signature = ")
    raw = bytes.fromhex(io.recvline().strip().decode())
    r, s = (int.from_bytes(half, "big") for half in (raw[:32], raw[32:]))
    digest = sha1(msg).digest()
    return int.from_bytes(digest, "big"), r, s
def sign_batch(msgs: list[bytes]):
    """Pipelined variant of sign(): yield ``(z, r, s)`` for each message in order.

    All requests are written first, without waiting for prompts, and the
    responses are then read back in the same order.  This is a generator, so
    nothing is sent until iteration starts.
    """
    for payload in msgs:
        io.sendline(b"1")
        io.sendline(payload.hex().encode())
    for payload in msgs:
        io.recvuntil(b"signature = ")
        raw = bytes.fromhex(io.recvline().strip().decode())
        yield (
            int.from_bytes(sha1(payload).digest(), "big"),
            int.from_bytes(raw[:32], "big"),
            int.from_bytes(raw[32:], "big"),
        )
# Collect n signatures on distinct messages b"0", b"1", ...  The attack assumes
# their nonces are consecutive 256-bit draws from the server's PRNG.
n = 3000
msgs = [str(i).encode() for i in range(n)]
# Pipelined requests (sign_batch) avoid one round-trip per signature (sign).
# A generator has no len(), so give tqdm the total explicitly for a real
# progress bar and ETA.
sigs = list(tqdm(sign_batch(msgs), total=len(msgs)))
ms = [z for z, _, _ in sigs]
rs = [r for _, r, _ in sigs]
ss = [s for _, _, s in sigs]
print("got signatures")
def get_mod(bits, p):
    """Recover the PRNG start state modulo ``p`` from a stream of nonce bits.

    Finds the first offset in ``range(p)`` whose stride-``p`` subsequence
    ``bits[offset::p]`` consists only of 1s, and returns ``-offset % p``.  That
    is the residue a for which every bit produced at a state a + t divisible by
    ``p`` is 1.  Returns -1 if no offset qualifies.
    """
    for offset in range(p):
        if all(bit == 1 for bit in bits[offset::p]):
            return -offset % p
    return -1
# Brute-force the PRNG offset modulo reclen.  Each guess fixes which nonce bits
# are known, and attack() tries to recover d from them; the first candidate
# that reproduces the public key wins.
with ProcessPoolExecutor(max_workers=8) as executor:
    start = time.time()
    futures = [
        (guess, executor.submit(attack, ms, rs, ss, guess)) for guess in range(reclen)
    ]
    sk = None
    for guess, future in futures:
        try:
            candidate = future.result()
        except Exception as err:
            # A wrong guess can make the lattice step fail outright; report it
            # and keep searching instead of aborting the whole brute force.
            print("brute", guess, "failed:", repr(err))
            continue
        print("brute", guess, candidate)
        print("current", time.time() - start)
        if candidate and candidate * secp256k1.G == pk:
            print("FOUND SK !!")
            sk = candidate
            break
    # Drop guesses that have not started yet; cancel() is a no-op for futures
    # that are already running or finished.
    for _, future in futures:
        future.cancel()

if sk is None:
    raise SystemExit("no guess of the PRNG offset recovered the private key")

# With d known, every nonce follows from k = (z + r*d) / s (mod q).
ks = [(z + r * sk) * pow(s, -1, q) % q for z, r, s in sigs]
# Concatenate the nonces MSB-first.  Assuming they are consecutive PRNG draws,
# bit t came from state a + t, so for each prime the positions of forced 1s
# reveal a modulo that prime.
bits = list(map(int, "".join(f"{k:0256b}" for k in ks)))
rems = [get_mod(bits, prime) for prime in ps]
if -1 in rems:
    raise SystemExit("could not determine the PRNG state modulo every prime")
a_rec = int(crt(rems, ps))
print("a_rec", a_rec)
# Presumably mirrors the server's key derivation: sk XOR the initial PRNG state.
key = (sk ^ a_rec).to_bytes(32, "big")
flag = AES.new(key, AES.MODE_ECB).decrypt(flag_enc)
print("flag", flag)
# SECCON{17_15_un4cc3p74bl3_7h47_l363ndr3_c4n_u53_c0mp05173_numb3r5}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment