The challenge encrypts a very long plaintext with a short (30-byte) key by simply XOR-ing the repeating key with the plaintext. The key thing to note here is that the plaintext bytes aren't arbitrary: every one of them is drawn from `string.printable`, and we can use this to recover the key.
import string
from secrets import choice, os, randbelow
from itertools import cycle
flag = "flag{XXXXXXXXXXXXXXXXXXXXXXX}" # I've replaced the flag chars with `X`
allowed_chars = string.printable
plaintext = [choice(allowed_chars) for _ in range(10000)]
pos = randbelow(len(plaintext) - len(flag))
plaintext[pos:pos+len(flag)] = [*flag]
plaintext = "".join(plaintext).encode()
key = os.urandom(30)
ciphertext = bytes([x^y for x,y in zip(cycle(key), plaintext)])
open("enc", "wb").write(ciphertext)
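Because the key is so short, the same byte of the key encrypts every 30th byte of the plaintext, so each key byte is reused roughly 333 times across the 10000-byte ciphertext.
Now, we know that all of the plaintext bytes are in `string.printable`, so a candidate key byte can be ruled out as soon as it decrypts any ciphertext byte in its column to something outside `string.printable`. It turns out that for each byte of the key, there can only be one value that survives this check across the whole ciphertext, which gives us the full key.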
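To get a feel for how quickly the candidates for a single key byte are eliminated, here is a minimal sketch on made-up data. The key byte `0xA7`, the seed, and the helper name `candidates_for_key_byte` are assumptions for illustration and are not part of the challenge.

```python
import random
import string

PRINTABLE = set(string.printable.encode())

def candidates_for_key_byte(ct_column: bytes) -> set:
    """Return every key byte k for which c ^ k is printable for all c in the column."""
    return {k for k in range(256) if all((c ^ k) in PRINTABLE for c in ct_column)}

# Hypothetical data: one key byte and the ~333 plaintext bytes it encrypts.
rng = random.Random(1234)
true_key_byte = 0xA7
plain_column = bytes(rng.choices(sorted(PRINTABLE), k=333))
ct_column = bytes(p ^ true_key_byte for p in plain_column)

# The candidate set can only shrink as more ciphertext bytes are taken into account.
for seen in (1, 5, 20, 100, 333):
    remaining = candidates_for_key_byte(ct_column[:seen])
    print(f"{seen:3d} ciphertext bytes seen -> {len(remaining)} candidate key bytes")
```

The true key byte always survives, and every wrong candidate is eliminated as soon as one plaintext byte in the column maps outside `string.printable` under it.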
ciphertext = open("dist/enc", "rb").read()
import string
from functools import reduce
from itertools import cycle
import re
allowed_chars = set((string.printable).encode())
klen = 30
# Maps each ciphertext byte `ct` to the set of possible key
# bytes that could have encrypted a plaintext byte into `ct`
keydict = {
    ct : set([c^ct for c in allowed_chars])
    for ct in range(0x100)
}
# Get the possible values of the bytes of the key
# by correlating across the whole ciphertext given
keypos = [
    reduce(
        lambda x,y: x & y, # Set intersection
        (keydict[c] for c in ciphertext[kidx::klen])
    )
    for kidx in range(klen)
]
# Check that we have a unique key
assert all(len(x)==1 for x in keypos)
# Recover the key!
key = [next(iter(k)) for k in keypos]
# Decrypt the plaintext and read the flag
plaintext = bytes([x^y for x,y in zip(cycle(key), ciphertext)]).decode()
flag = re.findall(r"flag{[^}]+}", plaintext)[0]
print(flag)
The server allows the player to input a bunch of data (`userinput`), and the server will create a string `data=<userinput>,flag=<flag>` and encrypt it with AES-CBC. The IV used per encryption is predictable: the next IV is the current one incremented by 1. The player is expected to use `userinput` to leak the `flag`, which is the suffix of the string.
from Crypto.Cipher import AES
from Crypto.Util.number import bytes_to_long, long_to_bytes
from Crypto.Util.Padding import pad
import os
allowed_chars = b'0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ '
flag = b"flag{XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX}" # Omitted
assert len(set(flag) - set(allowed_chars)) == 0
key = os.urandom(16)
iv = bytes_to_long(os.urandom(16))
def encrypt_data(data:str):
    global iv; iv = (iv + 1) & ((1<<128) - 1)
    data = b"data=" + bytes.fromhex(data) + b",flag=" + flag
    data = pad(data, 16)
    biv = long_to_bytes(iv)
    cipher = AES.new(key, AES.MODE_CBC, iv=biv)
    return (biv + cipher.encrypt(data)).hex()
while True:
    data = input("Input (hex): ")
    print("Output (hex): ", encrypt_data(data))
Let's first simplify the problem by assuming we are using AES-ECB instead of CBC. Let's say we want to leak the first byte of the flag. We can do so by making two requests to the server:
Request 1: You send `AAAAAAAAAAAAAAA,flag=flag{<guess>`, where `<guess>` is a guess of the first character of the flag. This results in the server generating the following plaintext (split into 16 bytes per line):
"data=AAAAAAAAAAA" +
"AAAA,flag=flag{" + guess_char +
",flag=flag{?????" + # Actual flag!
"????????????????"
Request 2: You send `AAAAAAAAAAAAAAA`, resulting in:
"data=AAAAAAAAAAA" +
"AAAA,flag=flag{?" + # Actual flag!
"????????????????"
If `<guess>` is correct, the 2nd block of both ciphertexts will be the same, allowing us to recover the first byte of the flag. Let's say we've found that the first byte is `d`. To recover the 2nd byte, we can shift everything back by 1, and send the following two requests:
Request 1: You send `AAAAAAAAAAAAAA,flag=flag{d<guess>`, resulting in:
"data=AAAAAAAAAAA" +
"AAA,flag=flag{d" + guess_char +
",flag=flag{?????" + # Actual flag!
"????????????????"
Request 2: You send `AAAAAAAAAAAAAA`, resulting in:
"data=AAAAAAAAAAA" +
"AAA,flag=flag{d?" + # Actual flag!
"????????????????"
We can add more `A`s so that we always have sufficient space to "shift back" to recover the next byte. For instance, by adding 16 more `A`s to each request, we might recover the 6th flag character like:
Request 1: You send `AAAAAAAAAAAAAAAAAAAAAAAAAA,flag=flag{dont_<guess>`, resulting in:
"data=AAAAAAAAAAA" +
"AAAAAAAAAAAAAAA," +
"flag=flag{dont_" + guess_char +
",flag=flag{?????" + # Actual flag!
"????????????????"
Request 2: You send `AAAAAAAAAAAAAAAAAAAAAAAAAA`, resulting in:
"data=AAAAAAAAAAA" +
"AAAAAAAAAAAAAAA," +
"flag=flag{dont_?" + # Actual flag!
"????????????????"
Note that now we need to compare the 3rd block of both ciphertexts, not the 2nd.
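The ECB strategy above can be tried locally with a short sketch. To keep it dependency-free, the block cipher here is a stand-in (a keyed SHA-256 truncated to 16 bytes), which is an assumption and not AES; it is enough because the attack only relies on equal plaintext blocks producing equal ciphertext blocks. The flag, `ecb_oracle`, `recover_flag` and the 64-byte `window` are all hypothetical names and values.

```python
import hashlib
import os

KEY = os.urandom(16)                 # hypothetical server secret
FLAG = b"flag{dont_use_ecb}"         # hypothetical flag, not the real one
CHARSET = bytes(range(32, 127))      # printable ASCII guesses

def toy_block(block: bytes) -> bytes:
    # Stand-in for AES (assumption): deterministic keyed function on 16-byte blocks.
    return hashlib.sha256(KEY + block).digest()[:16]

def ecb_oracle(user: bytes) -> bytes:
    # Mirrors the server's plaintext layout, but in ECB mode.
    data = b"data=" + user + b",flag=" + FLAG
    pad_len = 16 - len(data) % 16
    data += bytes([pad_len]) * pad_len
    return b"".join(toy_block(data[i:i + 16]) for i in range(0, len(data), 16))

def recover_flag(oracle, window: int = 64) -> bytes:
    known = b""
    while not known.endswith(b"}"):
        # Choose the filler so the next unknown flag byte is the last byte of the window.
        n_fill = window - len(b"data=") - len(b",flag=") - len(known) - 1
        fill = b"A" * n_fill
        target = oracle(fill)[window - 16:window]
        for g in CHARSET:
            guess = fill + b",flag=" + known + bytes([g])
            if oracle(guess)[window - 16:window] == target:
                known += bytes([g])
                break
        else:
            raise RuntimeError("no candidate matched")
    return known

print(recover_flag(ecb_oracle).decode())
```

Using a fixed window that ends on a block boundary means the same ciphertext block is compared on every iteration, so there is no need to track which block index to look at as the filler shrinks.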
Unfortunately, this strategy can't work without modification for the actual challenge, which uses AES-CBC. This is because each encryption uses a different `iv`. The `iv` gets XOR-ed with the first plaintext block before encryption, so a different `iv` yields a different first ciphertext block, which in turn gets XOR-ed into the next plaintext block, and so on. As a result, all ciphertext blocks will differ even when the two plaintexts share identical blocks.
However, the main flaw of this challenge is that the change in `iv` is predictable. This means we can "negate" the change in the `iv` between the above ciphertext pairs by modifying the first block of one of the plaintexts, which allows the above strategy to work as normal.
Explicitly, suppose we're trying to guess the first flag byte. We send the plaintext `AAAAAAAAAAAAAAA,flag=flag{<guess>` and collect the `iv`.
"data=AAAAAAAAAAA" +
"AAAA,flag=flag{" + guess_char +
",flag=flag{?????" + # Actual flag!
"????????????????"
We then predict the next `iv`, say `iv'`, and compute `diff = iv ^ iv'`. We want to send a plaintext such that the server ends up encrypting:
xor("data=AAAAAAAAAAA", diff) +
"AAAA,flag=flag{?" + # Actual flag!
"????????????????"
This ensures that if `<guess>` is correct, the 2nd block of both ciphertexts is still equal. However, we can't control the prefix `data=`. Luckily, since `iv` and `iv'` are extremely likely to differ only in their lower bits, `xor("data=AAAAAAAAAAA", diff)` is likely to have the same prefix `data=`. Hence, we send
xor("data=AAAAAAAAAAA", diff)[len("data="):] + "AAAA"
as our second plaintext.
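The IV-compensation trick can be checked locally with a small sketch. As an assumption, the block cipher is replaced by a keyed SHA-256 truncated to 16 bytes rather than AES, which suffices because CBC encryption only needs the forward direction and the trick only relies on equal block inputs giving equal outputs. The key, the IV value and the plaintext are hypothetical.

```python
import hashlib

KEY = b"hypothetical-key"

def toy_block(block: bytes) -> bytes:
    # Stand-in for AES (assumption): deterministic keyed function on 16-byte blocks.
    return hashlib.sha256(KEY + block).digest()[:16]

def xor(a: bytes, b: bytes) -> bytes:
    return bytes(x ^ y for x, y in zip(a, b))

def cbc_encrypt(iv: bytes, plaintext: bytes) -> list:
    # CBC chaining: each block is XOR-ed with the previous ciphertext block (or the IV).
    prev, out = iv, []
    for i in range(0, len(plaintext), 16):
        prev = toy_block(xor(plaintext[i:i + 16], prev))
        out.append(prev)
    return out

iv = bytes(range(16))                                        # hypothetical current IV
iv_next = (int.from_bytes(iv, "big") + 1).to_bytes(16, "big")  # predictable next IV
diff = xor(iv, iv_next)

first = b"data=AAAAAAAAAAA" + b"AAAA,flag=flag{d"
second = xor(first[:16], diff) + first[16:]                 # compensate for the IV change

ct1 = cbc_encrypt(iv, first)
ct2 = cbc_encrypt(iv_next, second)
print(second[:16])
print(ct1 == ct2)
```

Because `diff` is non-zero only in the low-order bytes here, the compensated first block still starts with `data=`, so it is something the server will actually build from attacker-controlled input.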
from nclib import Netcat
from functools import reduce
from Crypto.Util.number import long_to_bytes, bytes_to_long
nc = Netcat(("challs.nusgreyhats.org", 55001))
nqueries = 0
def encrypt_data(data:str) -> str:
    global nqueries; nqueries += 1
    nc.recv_until(b": ")
    nc.sendline(data.encode())
    nc.recv_until(b": ")
    ct = nc.recv_line().decode().strip()
    return ct
allowed_chars = b'0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ '
str1 = b"data="
str2 = b",flag="
# xor(xs,ys,zs,...) XORs any number of bytearrays together.
xor = lambda *x: reduce(lambda a,b: bytes(x^y for x,y in zip(a,b)), x)
rec_flag = b""
while b"}" not in rec_flag:
    # Reference query: the next unknown flag byte ends the 11th ciphertext block
    data = (
        b"A"*(16 - len(str1))
        + b"A"*(160-len(str2)-len(rec_flag)-1)
    )
    ct = encrypt_data(data.hex())
    biv = bytes_to_long(bytes.fromhex(ct[:32]))
    blk = ct[352:384]
    for i,g in enumerate(allowed_chars):
        g = bytes([g])
        # Guess query: cancel the predicted IV change in the first block
        data = (
            xor(
                b"A"*16,
                long_to_bytes(biv),
                long_to_bytes(biv + i + 1)
            )[-16 + len(str1):]
            + b"A"*(160-len(str2)-len(rec_flag)-1)
            + str2 + rec_flag + g
        )
        ct = encrypt_data(data.hex())
        if blk == ct[352:384]:
            break
    else: raise Exception()
    rec_flag += g
    print(rec_flag.decode())
print("\n\n========= FINAL FLAG =========")
print(rec_flag.decode())
print("========= FINAL FLAG =========")
print("Number of queries:", nqueries)
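The server generates an RSA public key and prints the encrypted flag together with `n`. The player can then submit ciphertexts for the server to decrypt, and the server returns `True` if the padding format of the plaintext is correct, and `False` otherwise. The correct padding format only requires the first byte of the plaintext to be `\x00`.
The player is then expected to leak the flag using only the `True` and `False` values the server returns. I.e., this is a padding oracle.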
server.py
from Crypto.Util.number import long_to_bytes, bytes_to_long, getPrime
from secrets import randbelow
def pad(msg:bytes):
    assert len(msg) < 128
    if len(msg) == 127:
        return b"\0" + msg
    padding_len = 128 - len(msg) - 2
    padding = bytes([randbelow(254) + 1 for _ in range(padding_len)])
    return b"\0" + padding + b"\0" + msg
def unpad(msg:bytes):
    assert msg[:1] == b"\0"
    if b"\0" in msg[1:]:
        return msg[2+msg[1:].index(b"\0"):]
    return msg[1:]
flag = b"flag{XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX}" # omitted
m = bytes_to_long(pad(flag))
p,q = getPrime(512), getPrime(512)
n = p*q
e = 0x10001
d = pow(e, -1, (p-1)*(q-1))
encrypted_flag = pow(m, e, n)
print("encrypted_flag =", encrypted_flag)
print("n =", n)
def decrypt(c):
    try:
        unpad(long_to_bytes(pow(c, d, n), 128))
        return True
    except:
        return False
while True:
    data = int(input("Input (int): "))
    output = decrypt(data)
    print("Output (bool): ", output, flush=True)
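Let $c$ be `encrypted_flag` and $m$ be its plaintext, i.e. the padded flag, so that $c = m^e \bmod n$. We are going to perform a Binary Search to compute the value of $t = n/m$, from which $m$ follows.
First, let's define a function `is_error(x:int)->bool`. `is_error(x)` will ask the server to decrypt $x^e c \bmod n$ and return `True` if the server encounters an invalid padding. Recall that in RSA, using the ciphertext $x^e c \bmod n$ makes the server decrypt to $xm \bmod n$, and the padding check only requires the first of the 128 plaintext bytes to be `\x00`. I.e., if the plaintext is smaller than some fixed threshold, here $2^{1016}$, then it passes the padding check.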
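The multiplicative property that `is_error` relies on can be checked with textbook-sized numbers. This is a minimal sketch; the primes `61` and `53`, the exponent `17` and the message `65` are toy values chosen for illustration and have nothing to do with the challenge's key.

```python
# Toy RSA parameters (assumptions for illustration, far too small to be secure).
p, q, e = 61, 53, 17
n = p * q
d = pow(e, -1, (p - 1) * (q - 1))   # modular inverse, Python 3.8+

m = 65                               # hypothetical plaintext
c = pow(m, e, n)                     # its ciphertext

def decrypt(ct: int) -> int:
    return pow(ct, d, n)

# Multiplying the ciphertext by x^e multiplies the hidden plaintext by x (mod n).
for x in (1, 2, 7, 60):
    blinded = pow(x, e, n) * c % n
    print(x, decrypt(blinded), x * m % n)
```

The last two columns agree on every row, which is why querying the oracle with $x^e c \bmod n$ reveals whether $xm \bmod n$ falls below the padding threshold.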
The hint to this challenge will be the answers to the following questions:
- What do you expect `is_error(1)` to return?
- What do you expect `is_error(x)` to return when `x` is a small integer close to `1`?
- As you continue to increase the value of `x`, you'd expect `is_error(x)` to start returning `True`. Eventually, as `x` continues to increase, it will start returning `False` again! (Why?) Let $x_0$ be the value of `x` where this transition occurs, i.e., `is_error(x0) == True` and `is_error(x0 + 1) == False`. What can you infer about $t$ from $x_0$? (Answer: $t \in (x_0, x_0 + 1]$)
- Now let's make the hypothesis that $t \in (x_0, x_0 + 0.5]$. What would you expect `is_error(2*x0 + 1)` to return? What if $t \in (x_0 + 0.5, x_0 + 1]$? Once you've answered this question, you've halved the possibilities for $t$. This is the start of the Binary Search.
- Suppose from the above question, you've inferred that $t \in (x_0, x_0 + 0.5]$. How would you continue the binary search? I.e., how would you tell if $t \in (x_0, x_0 + 0.25]$ or if $t \in (x_0 + 0.25, x_0 + 0.5]$?
- When do you stop the binary search?
If you're stuck on these questions, no joke, model drawing from primary school can help you visualise the size of each variable.
Because the padding check only tests whether the first byte is `\x00`, `is_error(x)` will be `True` only if $(xm \bmod n) \ge 2^{1016}$, where $m$ is the padded flag.
What do you expect `is_error(1)` to return?
Answer: `False`, because we are simply decrypting the given ciphertext of the flag.
What do you expect `is_error(x)` to return when `x` is a small integer close to `1`?
Answer: `False`, as long as $xm < 2^{1016}$.
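As you continue to increase the value of `x`, you'd expect `is_error(x)` to start returning `True`. Eventually, as `x` continues to increase, it will start returning `False` again! (Why?) Let $x_0$ be the value of `x` where this transition occurs, i.e., `is_error(x0) == True` and `is_error(x0 + 1) == False`. What can you infer about $t$ from $x_0$?
Answer:
`is_error(x)` returns `True` the moment $xm \ge 2^{1016}$. As $x$ keeps increasing, $xm$ eventually reaches $n$; the reduction modulo $n$ then wraps it around to $xm - n < m < 2^{1016}$, which is small again, so `is_error(x)` returns `False` again.
Let $x_0$ be the last value of `x` before this wrap-around, so $x_0 m < n \le (x_0 + 1) m$.
So, letting $t = n/m$ and dividing through by $m$, we get $x_0 < t \le x_0 + 1$, i.e. $t \in (x_0, x_0 + 1]$.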
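Now let's make the hypothesis that $t \in (x_0, x_0 + 0.5]$. What would you expect `is_error(2*x0 + 1)` to return? What if $t \in (x_0 + 0.5, x_0 + 1]$?
Answer:
If $t \in (x_0, x_0 + 0.5]$, then $2t \le 2x_0 + 1$, so $2n \le (2x_0 + 1)m$.
Further, since $t > x_0$, we have $(2x_0 + 1)m = 2x_0 m + m < 2n + m$.
Therefore, $(2x_0 + 1)m \bmod n = (2x_0 + 1)m - 2n < m < 2^{1016}$.
Therefore, `is_error(2*x0 + 1) == False` if $t \in (x_0, x_0 + 0.5]$.
Similarly, suppose $t \in (x_0 + 0.5, x_0 + 1]$.
This implies $2n - m \le (2x_0 + 1)m < 2n$, so $(2x_0 + 1)m \bmod n = (2x_0 + 1)m - n \ge n - m \ge 2^{1016}$.
Therefore, `is_error(2*x0 + 1) == True` if $t \in (x_0 + 0.5, x_0 + 1]$.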
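Suppose from the above question, you've inferred that $t \in (x_0, x_0 + 0.5]$. How would you continue the binary search? I.e., how would you tell if $t \in (x_0, x_0 + 0.25]$ or if $t \in (x_0 + 0.25, x_0 + 0.5]$?
Answer:
Apply the same logic as above to $4t \in (4x_0, 4x_0 + 2]$:
- If $t \in (x_0, x_0 + 0.25]$, then `is_error(4*x0 + 1) == False`
- If $t \in (x_0 + 0.25, x_0 + 0.5]$, then `is_error(4*x0 + 1) == True`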
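When do you stop the binary search?
Answer:
Let $t_i$ be the integer such that $2^i t \in (t_i, t_i + 1]$ after $i$ halving steps, so $t_0 = x_0$. Since $m = n/t$, this gives $2^i n/(t_i + 1) \le m < 2^i n/t_i$. Stop as soon as this interval contains only one integer, which must be $m$.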
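The whole search can be rehearsed offline against a local oracle. The sketch below is an illustration under stated assumptions: the modulus is built from two small Mersenne primes, the plaintext is 12 bytes instead of 128, and the padded message is made up. It also scans from `x = 1` and uses exact ceiling arithmetic for the interval bounds, which is a slightly different bookkeeping from the solve script that follows.

```python
# Toy parameters (assumptions for illustration only, far too small to be secure).
p, q = 2**61 - 1, 2**31 - 1           # two known Mersenne primes
n, e = p * q, 0x10001
d = pow(e, -1, (p - 1) * (q - 1))
SIZE = 12                              # toy plaintext length in bytes
B = 1 << (8 * (SIZE - 1))              # plaintexts below B start with \x00

m = int.from_bytes(b"\x00\x37\xa1\x00flag{ok}", "big")  # hypothetical padded message
c = pow(m, e, n)

def oracle(ct: int) -> bool:
    # Local stand-in for the server: True iff the decrypted first byte is \x00.
    return pow(ct, d, n) < B

queries = 0
def is_error(x: int) -> bool:
    global queries
    queries += 1
    return not oracle(pow(x, e, n) * c % n)

def recover_plaintext() -> int:
    x = 1
    while not is_error(x):             # skip the region where x*m < B
        x += 1
    while is_error(x):                 # walk until x*m wraps around n
        x += 1
    t = x - 1                          # now n/m lies in (t, t+1]
    i = 0
    while True:
        a = n << i                     # invariant: 2^i * n/m lies in (t, t+1]
        lower = (a + t) // (t + 1)     # ceil(a / (t+1))
        upper = (a - 1) // t           # largest integer strictly below a / t
        if lower == upper:
            return lower
        t = 2 * t + is_error(2 * t + 1)
        i += 1

recovered = recover_plaintext()
print(recovered.to_bytes(SIZE, "big"), "in", queries, "queries")
```

The linear scan costs about $n/m$ queries and each halving step costs one more, so the total stays small as long as the padded message is not much shorter than the modulus.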
from Crypto.Util.number import long_to_bytes
from nclib import Netcat
nc = Netcat(("challs.nusgreyhats.org", 55002))
e = 0x10001
encrypted_flag = int(nc.recvline().split(b" = ")[1].strip().decode())
n = int(nc.recvline().split(b" = ")[1].strip().decode())
def unpad(msg:bytes):
    assert msg[:1] == b"\0"
    if b"\0" in msg[1:]:
        return msg[2+msg[1:].index(b"\0"):]
    return msg[1:]
nqueries = 0
def decrypt(data:int) -> int:
    global nqueries; nqueries += 1
    print(f"nqueries = {nqueries} \r", end="")
    nc.recv_until(b": ")
    nc.sendline(str(data).encode())
    nc.recv_until(b": ")
    ct = nc.recv_line().decode().strip() == "True"
    return ct
def is_error(k):
    return not decrypt((pow(k,e,n) * encrypted_flag) % n)
# I can start at t = 0,
# But i can save some queries by starting from
# n//(1<<(1024-8)) instead, since it is a
# lower bound for t. (as m < 2^(1024-8))
t = n//(1<<(1024-8))
while not is_error(t+1): t += 1
while is_error(t+1): t += 1
for i in range(1024):
    upper = (n << i) // t
    lower = (n << i) // (t+1) + 1
    if upper==lower: break
    t = 2*t + is_error(2*t+1)
print("Flag =", unpad(long_to_bytes(lower, 128)).decode())
print("Queries =", nqueries)