Hackbash 2024 Finals Solutions - Jules

Xor Key

The challenge encrypts a very long plaintext with a short (30 byte) key by simply XOR-ing the key with the plaintext. The key thing to note here is that the plaintext isn't random, and we can use this to recover the key.

Challenge Files

import string
from secrets import choice, os, randbelow
from itertools import cycle

flag = "flag{XXXXXXXXXXXXXXXXXXXXXXX}" # I've replaced the flag chars with `X`

allowed_chars = string.printable
plaintext = [choice(allowed_chars) for _ in range(10000)]
pos = randbelow(len(plaintext) - len(flag))
plaintext[pos:pos+len(flag)] = [*flag]
plaintext = "".join(plaintext).encode()

key = os.urandom(30)
ciphertext = bytes([x^y for x,y in zip(cycle(key), plaintext)])
open("enc", "wb").write(ciphertext)

Solution

Because the key is so short, each key byte is XOR-ed with many different plaintext bytes. Fix one key byte $k$, and let $p_1, ..., p_m$ be the plaintext bytes it encrypts and $c_1, ..., c_m$ the resulting ciphertext bytes. We can guess the value of $k$ (call the guess $k'$) and compute the plaintext bytes $p_j'$ that this guess implies from $c_j$:

$$ p_j' = c_j \oplus k', \quad 1 \le j \le m $$

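Now, we know that all of the $p_j'$ must be inside string.printable, so any guess $k'$ that produces even one non-printable byte can be discarded. With 10000 plaintext bytes and a 30-byte key, each key byte is checked against roughly 333 ciphertext bytes, and it turns out that exactly $1$ candidate survives for each byte of the key. Hence we can recover the whole key, decrypt the plaintext and read the flag.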
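As a minimal sketch of this filtering step for a single key byte: the toy column below is an assumption chosen to keep the example deterministic (every character of string.printable once, encrypted under a made-up key byte 0x5A). In the real challenge the column would be every 30th byte of the ciphertext.

```python
import string

allowed = set(string.printable.encode())

# Hypothetical toy data: one "column" of plaintext bytes that all share
# the same key byte. Here it is simply every printable character once.
true_key_byte = 0x5A
column_ct = bytes(p ^ true_key_byte for p in string.printable.encode())

# Keep only the guesses k' for which every implied plaintext byte
# p_j' = c_j ^ k' is printable.
candidates = [
    k for k in range(256)
    if all((c ^ k) in allowed for c in column_ct)
]
print(candidates)
```

Only the true key byte survives here because the column covers the whole printable set; with random plaintext the same happens with high probability once the column is long enough.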

Implementation

ciphertext = open("dist/enc", "rb").read()

import string
from functools import reduce
from itertools import cycle
import re

allowed_chars = set((string.printable).encode())
klen = 30

# Maps each ciphertext byte `ct` to the set of possible key 
# bytes that could have encrypted a plaintext byte into `ct`
keydict = {
    ct : set([c^ct for c in allowed_chars]) 
    for ct in range(0x100)
}

# Get the possible values of the bytes of the key
# by correlating across the whole ciphertext given
keypos = [
    reduce(
        lambda x,y: x & y, # Set intersection
        (keydict[c] for c in ciphertext[kidx::klen])
    ) 
    for kidx in range(klen)
]

# Check that we have a unique key
assert all(len(x)==1 for x in keypos)

# Recover the key!
key = [next(iter(k)) for k in keypos]

# Decrypt the plaintext and read the flag
plaintext = bytes([x^y for x,y in zip(cycle(key), ciphertext)]).decode()
flag = re.findall(r"flag{[^}]+}", plaintext)[0]
print(flag)

AES-CBC

The server allows the player to input a bunch of data (userinput), and the server will create a string data=<userinput>,flag=<flag> and encrypt it with AES-CBC. The IV used per encryption is predictable: The next IV is the increment of the current one. The player is expected to use userinput to leak the flag, which is the suffix of the string.

Challenge Files

from Crypto.Cipher import AES
from Crypto.Util.number import bytes_to_long, long_to_bytes
from Crypto.Util.Padding import pad
import os

allowed_chars = b'0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ '

flag = b"flag{XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX}" # Omitted
assert len(set(flag) - set(allowed_chars)) == 0
key = os.urandom(16)
iv = bytes_to_long(os.urandom(16))

def encrypt_data(data:str):
    global iv; iv = (iv + 1) & ((1<<128) - 1)
    data = b"data=" + bytes.fromhex(data) + b",flag=" + flag
    data = pad(data, 16)
    biv = long_to_bytes(iv)
    cipher = AES.new(key, AES.MODE_CBC, iv=biv)
    return (biv + cipher.encrypt(data)).hex()

while True:
    data = input("Input (hex): ")
    print("Output (hex): ", encrypt_data(data))

Solution

Let's first simplify the problem by assuming we are using AES-ECB instead of CBC. Let's say we want to leak the first byte of the flag. We can do so by making two requests to the server:

Request 1: You send AAAAAAAAAAAAAAA,flag=flag{<guess>, where <guess> is a guess of the first character of the flag. This results in the server generating the following plaintext (split into 16 bytes per line):

"data=AAAAAAAAAAA" +
"AAAA,flag=flag{"  + guess_char,
",flag=flag{?????" + # Actual flag!
"????????????????"

Request 2: You send AAAAAAAAAAAAAAA, resulting in:

"data=AAAAAAAAAAA" +
"AAAA,flag=flag{?" + # Actual flag!
"????????????????"

If <guess> is correct, the 2nd blocks of the two ciphertexts will be the same, allowing us to recover the first byte of the flag. Let's say we've found that the first byte is d. To recover the 2nd byte, we can shift everything back by 1, and send the following two requests:

Request 1: You send AAAAAAAAAAAAAA,flag=flag{d<guess>

"data=AAAAAAAAAAA" +
"AAA,flag=flag{d"  + guess_char,
",flag=flag{?????" + # Actual flag!
"????????????????"

Request 2: You send AAAAAAAAAAAAAA, resulting in:

"data=AAAAAAAAAAA" +
"AAA,flag=flag{d?" + # Actual flag!
"????????????????"

We can add more As so that we always have sufficient space to "shift back" to recover the next byte. For instance, by adding 16 more As to each request, we might recover the 6th flag character like:

Request 1: You send AAAAAAAAAAAAAAAAAAAAAAAAAA,flag=flag{dont_<guess>

"data=AAAAAAAAAAA" +
"AAAAAAAAAAAAAAA," +
"flag=flag{dont_"  + guess_char,
",flag=flag{?????" + # Actual flag!
"????????????????"

Request 2: You send AAAAAAAAAAAAAAAAAAAAAAAAAA, resulting in:

"data=AAAAAAAAAAA" +
"AAAAAAAAAAAAAAA," +
"flag=flag{dont_?" + # Actual flag!
"????????????????"

Note that now we need to compare the 3rd block of both ciphertexts, not the 2nd.
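The byte-at-a-time procedure above can be sketched end to end against a local toy oracle. Everything here is an assumption made for illustration: the flag and key are made up, the padding is simplified, and `toy_block_encrypt` is a hash-based stand-in for AES (the attack only relies on equal plaintext blocks producing equal ciphertext blocks).

```python
import hashlib

# Everything below is a local toy, not the challenge server.
SECRET_FLAG = b"flag{dont_panic}"   # hypothetical flag
KEY = b"hypothetical-key"           # hypothetical key

def toy_block_encrypt(block: bytes) -> bytes:
    # Stand-in for AES on one block (NOT real AES): the attack only needs
    # equal plaintext blocks to give equal ciphertext blocks.
    return hashlib.sha256(KEY + block).digest()[:16]

def ecb_oracle(userinput: bytes) -> list:
    data = b"data=" + userinput + b",flag=" + SECRET_FLAG
    data += b"\x00" * (-len(data) % 16)   # simplified padding
    return [toy_block_encrypt(data[i:i + 16]) for i in range(0, len(data), 16)]

ALPHABET = bytes(range(32, 127))
WINDOW = 64               # compare the block that ends at byte 64
BLOCK = WINDOW // 16 - 1  # index of that block (the 4th one)

recovered = b""
while not recovered.endswith(b"}"):
    # Enough As that the next unknown flag byte is the last byte of the window.
    n_as = WINDOW - len(b"data=") - len(b",flag=") - len(recovered) - 1
    filler = b"A" * n_as
    target = ecb_oracle(filler)[BLOCK]
    for g in ALPHABET:
        guess = recovered + bytes([g])
        if ecb_oracle(filler + b",flag=" + guess)[BLOCK] == target:
            recovered = guess
            break
    else:
        raise ValueError("no guess matched")

print(recovered.decode())
```

Using a fixed window with extra As means the compared block index never changes; only the number of As shrinks as more of the flag is known.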

Unfortunately, this strategy doesn't work as-is for the actual challenge, which uses AES-CBC, because each encryption uses a different iv. The iv is XOR-ed with the first plaintext block before encryption, so the first ciphertext block differs between requests. That block is in turn XOR-ed into the second plaintext block, and so on, so every ciphertext block differs even when the two plaintext blocks are equal.

However, the main flaw of this challenge is that the change in iv is predictable. This means we can "negate" the change in the iv between the two requests by modifying the first block of one of the plaintexts, after which the above strategy works as normal.

Explicitly, suppose we're trying to guess the first flag byte. We send the plaintext AAAAAAAAAAAAAAA,flag=flag{<guess> and collect the iv.

"data=AAAAAAAAAAA" +
"AAAA,flag=flag{"  + guess_char,
",flag=flag{?????" + # Actual flag!
"????????????????"

We then predict the next iv, say iv'. We then compute diff = iv ^ iv'. We want to send a plaintext such that the server ends up encrypting:

xor("data=AAAAAAAAAAA", diff) +
"AAAA,flag=flag{?" + # Actual flag!
"????????????????"

This ensures that if <guess> is correct, the 2nd blocks of both ciphertexts are still equal. However, we can't control the prefix data=. Luckily, since iv and iv' are extremely likely to differ only in their lowest bits, xor("data=AAAAAAAAAAA", diff) is likely to keep the prefix data=. Hence, we send

xor("data=AAAAAAAAAAA", diff)[len("data="):] + "AAAA"

as our second plaintext.
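A small local check of this IV-compensation trick, as a sketch under stated assumptions: the key, flag and IV values are made up, and `toy_block_encrypt` is a hash-based stand-in for the AES block operation (not real AES), which is enough to show when the 2nd CBC blocks coincide.

```python
import hashlib

KEY = b"hypothetical-key"   # hypothetical key
FLAG = b"flag{dummy}"       # hypothetical flag

def toy_block_encrypt(block: bytes) -> bytes:
    # Stand-in for AES on one block (NOT real AES).
    return hashlib.sha256(KEY + block).digest()[:16]

def xor(a: bytes, b: bytes) -> bytes:
    return bytes(x ^ y for x, y in zip(a, b))

def cbc_second_block(iv: bytes, userinput: bytes) -> bytes:
    # CBC: each plaintext block is XOR-ed with the previous ciphertext
    # block (or the IV) before being encrypted.
    data = b"data=" + userinput + b",flag=" + FLAG
    c1 = toy_block_encrypt(xor(iv, data[:16]))
    return toy_block_encrypt(xor(c1, data[16:32]))

iv = (0x1234).to_bytes(16, "big")        # IV of the first request
iv_next = (0x1235).to_bytes(16, "big")   # predicted IV of the second request
diff = xor(iv, iv_next)
assert diff[:5] == bytes(5)   # the fixed "data=" prefix must be unaffected

def first_request(guess: bytes) -> bytes:
    return b"A" * 15 + b",flag=flag{" + guess

second_request = xor(b"data=" + b"A" * 11, diff)[len(b"data="):] + b"AAAA"

target = cbc_second_block(iv_next, second_request)
print(cbc_second_block(iv, first_request(b"d")) == target)   # correct guess
print(cbc_second_block(iv, first_request(b"x")) == target)   # wrong guess
```

The assertion on `diff[:5]` is the "likely to have the same prefix" condition from the text: if an increment ever carried into the top five bytes, that particular pair of requests would have to be retried.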

Implementation

from nclib import Netcat
from functools import reduce
from Crypto.Util.number import long_to_bytes, bytes_to_long

nc = Netcat(("challs.nusgreyhats.org", 55001))
nqueries = 0
def encrypt_data(data:str) -> str:
    global nqueries; nqueries += 1
    nc.recv_until(b": ")
    nc.sendline(data.encode())
    nc.recv_until(b": ")
    ct = nc.recv_line().decode().strip()
    return ct

allowed_chars = b'0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ '
str1 = b"data="
str2 = b",flag="
# xor(xs,ys,zs,...) XORs any number of bytearrays together.
xor = lambda *x: reduce(lambda a,b: bytes(x^y for x,y in zip(a,b)), x)

rec_flag = b""
while b"}" not in rec_flag:
    
    data = (
        b"A"*(16 - len(str1))
        + b"A"*(160-len(str2)-len(rec_flag)-1)
    )
    ct = encrypt_data(data.hex())
    biv = bytes_to_long(bytes.fromhex(ct[:32]))
    blk = ct[352:384]
    
    for i,g in enumerate(allowed_chars):
        g = bytes([g])
        data = (
            xor(
                b"A"*16,
                long_to_bytes(biv), 
                long_to_bytes(biv + i + 1)
            )[-16 + len(str1):]
            + b"A"*(160-len(str2)-len(rec_flag)-1)
            + str2 + rec_flag + g
        )
        ct = encrypt_data(data.hex())
        if blk == ct[352:384]:
            break
    else: raise Exception()
    rec_flag += g
    print(rec_flag.decode())

print("\n\n========= FINAL FLAG =========")
print(rec_flag.decode())
print("========= FINAL FLAG =========")
print("Number of queries:", nqueries)

Baby RSA Bleichenbacher's Attack

The server generates an RSA public key $(n, e)$, encrypts the flag $m$ into the ciphertext $c$ and gives them to the player. The player can then send any ciphertext to the server. The server would decrypt and return True if the padding format of the plaintext is correct, and False otherwise. The correct padding format only requires the first byte of the plaintext to be \x00.

The player is then expected to leak the flag $m$ from the ciphertext $c$ via the True and False responses the server returns. I.e., this is a padding oracle.

Challenge Files

server.py

from Crypto.Util.number import long_to_bytes, bytes_to_long, getPrime
from secrets import randbelow

def pad(msg:bytes):
    assert len(msg) < 128
    if len(msg) == 127:
        return b"\0" + msg
    padding_len = 128 - len(msg) - 2
    padding = bytes([randbelow(254) + 1 for _ in range(padding_len)])
    return b"\0" + padding + b"\0" + msg

def unpad(msg:bytes):
    assert msg[:1] == b"\0"
    if b"\0" in msg[1:]:
        return msg[2+msg[1:].index(b"\0"):]
    return msg[1:]

flag = b"flag{XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX}" # omitted
m = bytes_to_long(pad(flag))

p,q = getPrime(512), getPrime(512)
n = p*q
e = 0x10001
d = pow(e, -1, (p-1)*(q-1))
encrypted_flag = pow(m, e, n)
print("encrypted_flag =", encrypted_flag)
print("n =", n)

def decrypt(c):
    try:
        unpad(long_to_bytes(pow(c, d, n), 128))
        return True
    except:
        return False

while True:
    data = int(input("Input (int): "))
    output = decrypt(data)
    print("Output (bool): ", output, flush=True)

Hint

Let $m$ be the message (flag), $n$ be the public modulus and $c$ be the encrypted_flag. We are going to perform a Binary Search to compute the value of $t = \frac{n}{m}$ to sufficient precision to recover $m$.

First, let's define a function is_error(x:int)->bool. is_error(x) will ask the server to decrypt $c x^e \text{ mod } n$, and return True if the server encounters an invalid padding. Recall that in RSA, using $c x^e \text{ mod } n$ as the ciphertext will decrypt into the plaintext $m x \text{ mod } n$. Note that the padding check only checks if the first byte of the plaintext is \x00. I.e., if the plaintext is smaller than some fixed threshold, then it passes the padding check.
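The multiplicative property used here can be checked with textbook-sized numbers. The parameters below are assumptions chosen only for illustration and are far too small for real use.

```python
# Textbook-sized RSA parameters, chosen here only for illustration.
p, q = 61, 53
n = p * q
e = 17
d = pow(e, -1, (p - 1) * (q - 1))   # modular inverse (Python 3.8+)

m = 65                 # hypothetical plaintext
c = pow(m, e, n)       # its ciphertext

x = 7
blinded_c = (c * pow(x, e, n)) % n   # c * x^e mod n
decrypted = pow(blinded_c, d, n)     # what the server would see

# Both values should be equal: the server decrypts m * x mod n.
print(decrypted, (m * x) % n)
```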

The hint to this challenge will be the answers to the following questions:

  • What do you expect is_error(1) to return?
  • What do you expect is_error(x) to return when x is a small integer close to 1?
  • As you continue to increase the value of x, you'd expect is_error(x) to start returning True. Eventually, as x continues to increase, it will start returning False again! (Why?) Let $x_0$ be the value of x where this transition occurs, i.e., is_error(x0) == True and is_error(x0 + 1) == False. What can you infer about $t$ from $x_0$? (Answer: $t \in (x_0, x_0 + 1]$)
  • Now let's make the hypothesis that $t \in (x_0, x_0 + 0.5]$. What would you expect is_error(2*x0 + 1) to return? What if $t \in (x_0 + 0.5, x_0 + 1]$? Once you've answered this question, you've halved the possibilities for $t$. This is the start of the Binary Search.
  • Suppose from the above question, you've inferred that $t \in (x_0, x_0 + 0.5]$. How would you continue the binary search? I.e., how would you tell if $t \in (x_0, x_0 + 0.25]$ or if $t \in (x_0 + 0.25, x_0 + 0.5]$?
  • When do you stop the binary search?

If you're stuck on these questions, no joke, model drawing from primary school can help you visualise the size of each variable.

Solution

Answers to the Hints:

Because the padding check only tests whether the first byte is \x00, the server reports an error exactly when $(m x \text{ mod } n) \ge 2^{127 \times 8}$. Let $L = 2^{127 \times 8}$.


What do you expect is_error(1) to return?

Answer: False, because we are simply decrypting the given ciphertext of the flag.


What do you expect is_error(x) to return when x is a small integer close to 1?

Answer: False, as long as $m x < L$.


As you continue to increase the value of x, you'd expect is_error(x) to start returning True. Eventually, as x continues to increase, it will start returning False again! (Why?) Let $x_0$ be the value of x where this transition occurs, i.e., is_error(x0) == True and is_error(x0 + 1) == False. What can you infer about $t$ from $x_0$? (Answer: $t \in (x_0, x_0 + 1]$)

Answer:

is_error(x) returns True the moment $m x \ge L$, but eventually $mx \ge n$. At that point

$$ (mx \text{ mod } n) = m x - n < m < L $$

again, so is_error(x) returns False again.

Let $x_0$ be the last value of x before that happens. This means that $m x_0 < n$ and $m (x_0 + 1) \ge n$.

So, letting $t = \frac{n}{m}$,

$$ x_0 < t \le x_0 + 1 \implies t \in (x_0, x_0 + 1] $$


Now let's make the hypothesis that $t \in (x_0, x_0 + 0.5]$. What would you expect is_error(2*x0 + 1) to return? What if $t \in (x_0 + 0.5, x_0 + 1]$? Once you've answered this question, you've halved the possibilities for $t$. This is the start of the Binary Search.

Answer:

If $t \in (x_0, x_0 + 0.5]$, then since $t \le x_0 + 0.5$,

$$ \begin{aligned} (x_0 + 0.5) m &\ge t m = n \\ (2 x_0 + 1) m &\ge 2n \\ (2 x_0 + 1) m - 2 n &\ge 0 \end{aligned} $$

Further, since $t > x_0$,

$$ \begin{aligned} x_0 m &< tm = n \\ (2 x_0 + 1) m &< 2n + m \\ (2 x_0 + 1) m - 2 n &< m < L < n \end{aligned} $$

Therefore,

$$ \begin{aligned} &0 \le (2 x_0 + 1) m - 2 n < L < n \\ &\implies ((2 x_0 + 1) m \text{ mod }n) = (2 x_0 + 1) m - 2 n < L \end{aligned} $$

Therefore, is_error(2*x0 + 1) == False if $t \in (x_0, x_0 + 0.5]$.

Similarly, suppose $t \in (x_0 + 0.5, x_0 + 1]$. Note that $L < n - m$, since $n > 2^{1022}$ while $m < L = 2^{1016}$. Then

$$ \begin{aligned} (x_0 + 0.5) m &< tm = n &&\le (x_0 + 1) m \\ (2x_0 + 1) m &< 2 n &&\le (2x_0 + 2) m \\ (2x_0 + 1) m - n &< n, \text{ and } && L < n - m \le (2x_0 + 1) m - n \end{aligned} $$

This implies

$$ \begin{aligned} &L < (2x_0 + 1) m - n < n \\ & \implies ((2 x_0 + 1) m \text{ mod } n) = (2 x_0 + 1) m - n > L \end{aligned} $$

Therefore, is_error(2*x0 + 1) == True if $t \in (x_0 + 0.5, x_0 + 1]$.


Suppose from the above question, you've inferred that $t \in (x_0, x_0 + 0.5]$. How would you continue the binary search? I.e., how would you tell if $t \in (x_0, x_0 + 0.25]$ or if $t \in (x_0 + 0.25, x_0 + 0.5]$?

Answer:

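Apply the same logic as above to $m \leftarrow \frac{m}{2}$, resulting in $t \leftarrow 2 t$. The interval $t \in (x_0, x_0 + 0.5]$ becomes $2t \in (2 x_0, 2 x_0 + 1]$, so $2 x_0$ now plays the role of $x_0$ and the next query is is_error(2*(2*x0) + 1). Then one can see that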

  • If $t \in (x_0, x_0 + 0.25]$, then is_error(4*x0 + 1) == False
  • If $t \in (x_0 + 0.25, x_0 + 0.5]$, then is_error(4*x0 + 1) == True
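For completeness, here is a sketch of why these two cases hold, following the same steps as the previous answer and reusing $L < n - m$ from it:

$$ \begin{aligned} t \in (x_0, x_0 + 0.25] &\implies 4 x_0 m < 4n \le (4 x_0 + 1) m \\ &\implies 0 \le (4 x_0 + 1) m - 4n < m < L \\ t \in (x_0 + 0.25, x_0 + 0.5] &\implies (4 x_0 + 1) m < 4n \le (4 x_0 + 2) m \\ &\implies L < n - m \le (4 x_0 + 1) m - 3n < n \end{aligned} $$

In the first case $((4 x_0 + 1) m \text{ mod } n) = (4 x_0 + 1) m - 4n < L$, so there is no error. In the second case $((4 x_0 + 1) m \text{ mod } n) = (4 x_0 + 1) m - 3n > L$, so the server errors.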

When do you stop the binary search?

Answer:

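Let $t_{\text{upper}}$ and $t_{\text{lower}}$ be the upper and lower bounds of $t$ after some number of binary search iterations. Since $m = \frac{n}{t}$ lies between $\frac{n}{t_{\text{upper}}}$ and $\frac{n}{t_{\text{lower}}}$, stop when both bounds round to the same integer: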

$$ \left \lceil \frac{n}{t_{\text{upper}}} \right \rfloor = \left \lceil \frac{n}{t_{\text{lower}}}\right\rfloor = m $$
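The whole search can be tried locally on small numbers. The sketch below is an illustration under assumptions, not the challenge itself: `n`, `L` and `secret_m` are made-up toy values satisfying $m < L < n - m$, and `is_error` evaluates the padding condition on the plaintext directly instead of going through RSA. The stopping rule uses exact integer bounds on $m$ rather than rounding.

```python
from itertools import count

# Hypothetical toy parameters satisfying m < L < n - m.
n = 2**24 + 43
L = 2**16
secret_m = 54321
assert secret_m < L < n - secret_m

queries = 0
def is_error(x: int) -> bool:
    # Local stand-in for the server: works on the plaintext directly.
    global queries
    queries += 1
    return (secret_m * x) % n >= L

# Find x0: the last x before m*x wraps around n.
t = 0
while not is_error(t + 1):
    t += 1
while is_error(t + 1):
    t += 1

# Invariant: the true n/m lies in (t / 2^i, (t + 1) / 2^i].
for i in count():
    scaled_n = n << i
    lower = -(-scaled_n // (t + 1))   # smallest integer >= scaled_n / (t + 1)
    upper = (scaled_n - 1) // t       # largest integer  <  scaled_n / t
    if lower == upper:
        break
    t = 2 * t + is_error(2 * t + 1)

print("recovered m =", lower, "after", queries, "queries")
```

Most of the queries in this toy run go to the linear scan for $x_0$; the doubling phase needs only a handful of steps, one oracle query each.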

Implementation

from Crypto.Util.number import long_to_bytes
from nclib import Netcat

nc = Netcat(("challs.nusgreyhats.org", 55002))
e = 0x10001
encrypted_flag = int(nc.recvline().split(b" = ")[1].strip().decode())
n = int(nc.recvline().split(b" = ")[1].strip().decode())

def unpad(msg:bytes):
    assert msg[:1] == b"\0"
    if b"\0" in msg[1:]:
        return msg[2+msg[1:].index(b"\0"):]
    return msg[1:]

nqueries = 0
def decrypt(data:int) -> bool:
    global nqueries; nqueries += 1
    print(f"nqueries = {nqueries}   \r", end="")
    nc.recv_until(b": ")
    nc.sendline(str(data).encode())
    nc.recv_until(b": ")
    ct = nc.recv_line().decode().strip() == "True"
    return ct

def is_error(k):
    return not decrypt((pow(k,e,n) * encrypted_flag) % n)

# We could start at t = 0, but we can save
# some queries by starting from
# n//(1<<(1024-8)) instead, since it is a
# lower bound for t (as m < 2^(1024-8)).
t = n//(1<<(1024-8))
while not is_error(t+1): t += 1
while is_error(t+1): t += 1
for i in range(1024):
    upper = (n << i) // t
    lower = (n << i) // (t+1) + 1
    if upper==lower: break
    t = 2*t + is_error(2*t+1)
    
print("Flag =", unpad(long_to_bytes(lower, 128)).decode())
print("Queries =", nqueries)