Skip to content

Instantly share code, notes, and snippets.

@hellman
Last active October 8, 2019 06:02
Show Gist options
  • Save hellman/3faeb41275fb013407b503d69f332207 to your computer and use it in GitHub Desktop.
Save hellman/3faeb41275fb013407b503d69f332207 to your computer and use it in GitHub Desktop.
0CTF 2018 Quals - zeroC4 (Crypto 785)
#-*- coding:utf-8 -*-
"""
In this challenge we have a weakened version of RC4.
It operates on permutations of the values 0..31.
Moreover, i is incremented at the beginning of the key-scheduling loop instead of at the end.
We are given access to a related-key oracle.
We can send any key delta and the server will return us the generated sequence using the key xored with our delta.
There is a well known paper
"Weaknesses in the Key Scheduling Algorithm of RC4."
by Fluhrer, Mantin, Shamir.
In Section 8 they describe a Related Key attack.
And it actually works better if the key schedule is modified exactly as in the challenge.
The main idea is that we can recover the 16-byte key in layers of 16 bits (one bit per key byte), from the LSBs of the key bytes up to the MSBs.
If the LSBs of the key bytes form a special pattern, then the LSBs of the output sequence correlate with a special sequence.
The script stat.py can be used to choose the correlation bounds for filtering out wrong keys.
It is somewhat difficult because we have only 1500 queries of 512 deltas each, that is about 2^19.5 deltas in total.
We can recover the 4 LSBs of each key byte and then brute-force the remaining 16 MSBs (one per key byte) locally.
With a good probability we get the key.
The flag: flag{Haha~~Do_y0u_3nj0y_ouR_stre4m_c1pher?}
"""
import string
import random
from hashlib import sha256
from struct import pack
from itertools import product
from hashlib import sha256
from sock import Sock
from zer0C4 import ksa, prga, N, mask
def prga0(s, n):
    """Swap-free ("weakened") PRGA used as the reference output stream.

    Runs the PRGA index updates over ``s`` but never swaps, so ``s`` is left
    unmodified.  Returns the first ``n`` outputs as a tuple of ints.
    """
    out = bytearray()
    i = j = 0
    while len(out) < n:
        i = (i + 1) & mask
        j = (j + s[i]) & mask
        # No swap here, unlike the real PRGA.
        out.append(s[(s[i] + s[j]) & mask])
    return tuple(out)
def bconserves_t(b, s, t):
    """Return True if position ``t`` of ``s`` is b-conserved, i.e. s[t] == t (mod b)."""
    return (s[t] - t) % b == 0
def bconserving(b, s):
    """Return True if every position of ``s`` is b-conserved (s[t] == t mod b for all t)."""
    return all(bconserves_t(b, s, t) for t in range(len(s)))
def bexact(b, key):
    """Return True if the first 16 key values satisfy key[t] == 1 - t (mod b)."""
    return all(key[t] % b == (1 - t) % b for t in range(16))
def randvec():
    """Return a new list of 16 random ints, each drawn uniformly from 0..31."""
    vec = []
    for _ in range(16):
        vec.append(random.randint(0, 31))
    return vec
# Connect to the challenge server (Sock comes from the external `sock` module).
f = Sock("202.120.7.220 1234")
# pow
# sha256(XXXX+JyZoJLkhS8Jhsoxi) == 6387c59e693c59880ff6458c048a089aac925573cd27d61dcce6e4049dac084d
# Give me XXXX:
# Brute-force the 4-character prefix over the same alphabet the server uses
# for its proof string (letters + digits).
alpha = string.ascii_letters + string.digits
# Capture the known suffix and the target hex digest from the challenge line.
suff, target = f.read_until_re(r"XXXX\+(\w+)\) == (\w{64})").groups()
print suff, target
for p in product(alpha, repeat=4):
    p = "".join(p)
    if sha256(p + suff).hexdigest() == target:
        break
else:
    # for/else: reached only if no prefix matched.
    print "FAIL"
    quit()
print p
f.send(p)
print "ok"
# Oracle bookkeeping used by oracle_chunk: the last delta sent (the server
# XORs every received delta into a key that persists across queries, so
# each new delta is sent as a difference from this one), plus per-layer and
# total query counters.
lastdelta = [0] * 16
NQ = 0
NQALL = 0
def oracle_chunk(deltas):
    """Send one batch of key deltas to the oracle and return the keystreams.

    The server XORs each received 16-byte delta into a running key that
    persists across queries.  To make the i-th keystream come from
    ``original_key ^ deltas[i]``, each delta is sent XORed with the
    previously sent one (tracked in the global ``lastdelta``).

    Returns a list with one 16-element list of output values (each in
    0..31) per delta, in the same order as ``deltas``.
    """
    global NQ, NQALL, lastdelta
    NQ += 1
    NQALL += 1
    f.read_until("2. Guess the key.\n")
    f.send("1\n")
    parts = []
    for delta in deltas:
        # Difference from the previous delta; the server accumulates them.
        parts.append("".join(chr(delta[i] ^ lastdelta[i]) for i in xrange(16)))
        lastdelta = delta[:]
    s = "".join(parts)
    f.send(pack("H", len(s)))
    f.send(s)
    f.read_until("Here is you xor-key: ")
    # The server answers with 16 output values per delta, one byte each.
    res = map(ord, f.read_nbytes(len(s)))
    for c in res:
        assert 0 <= c < 32
    return [res[i:i + 16] for i in xrange(0, len(res), 16)]
def calc_score(stream, ref=None, modulus=None):
    """Weighted count of positions where ``stream`` agrees with ``ref`` mod ``modulus``.

    Position k (0-based) contributes a weight of about 4.2 - 0.2 * k when
    ``stream[k] % modulus == ref[k] % modulus``, so earlier outputs count
    more.  Only ``min(len(stream), len(ref))`` positions are compared.

    ``ref`` defaults to the module-level ``stream0`` and ``modulus`` to the
    module-level ``b`` (the current layer's 2**q), so existing
    ``calc_score(out)`` calls keep working.
    """
    if ref is None:
        ref = stream0
    if modulus is None:
        modulus = b
    step = 0.2
    wt = 1 + step * 16
    score = 0
    for c0, c in zip(ref, stream):
        ok = (c % modulus) == (c0 % modulus)
        score += wt * ok
        wt -= step
    return score
# All 2**16 assignments of one bit per key byte; index = the 16 bits read as
# a big-endian binary number (first tuple element is the most significant).
bits16 = tuple(product((0, 1), repeat=16))
def genchunk(offsets):
    """Build one 16-value key delta per candidate offset for the current layer.

    For each offset, ``bits16[offset]`` supplies bit q-1 of every delta
    byte.  Bits below q-1 are copied from the already recovered
    ``known_delta``, and bits from q upward are random.
    """
    low_mod = b // 2
    shift = q - 1
    chunk = []
    for offset in offsets:
        guess = bits16[offset]
        vec = randvec()
        for pos in xrange(16):
            vec[pos] = (vec[pos] - vec[pos] % b
                        + known_delta[pos] % low_mod
                        + (guess[pos] << shift))
        chunk.append(vec)
    return chunk
def getavg(off):
    """Return the mean score of candidate ``off``: accumulated score / sample count."""
    total, count = data[off]
    return total / float(count)
def add_scores():
    """Query the oracle for every surviving candidate and accumulate its score.

    Candidates go out in batches of at most 512 deltas.  A short batch is
    repeated as many whole times as fit into 512, so the spare capacity of a
    query yields extra samples for the same candidates.  Each sample adds its
    ``calc_score`` to ``data[off]`` and bumps the sample count.
    """
    candidates = list(offs)
    for start in xrange(0, len(candidates), 512):
        batch = candidates[start:start + 512]
        batch = batch * (512 // len(batch))
        outs = oracle_chunk(genchunk(batch))
        for off, out in zip(batch, outs):
            total, count = data[off]
            data[off] = total + calc_score(out), count + 1
# q = 1
# 25 : 13023 81.39375 % good
# 25 : 3992 24.95 % bad
# avg: 21.9 vs 29.6
# 20 : 15396 96.225 %
# 20 : 10481 65.50625 %
known_delta = [0] * 16
for q in xrange(1, 5):
NQ = 0
b = 2**q
testkey = randvec()
for i in xrange(16):
testkey[i] -= testkey[i] % b
testkey[i] += (1 - i) % b
s = ksa(testkey)
stream0 = prga0(s, 16)
assert bexact(b, testkey)
assert bconserving(b, s)
offs = set(range(2**16))
data = {i: (0, 0) for i in offs}
bounds = None
if q == 1: bounds = [23, 23, 24, 25, 25, 25, 25, 25, 25, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 26]
if q == 2: bounds = [13, 14, 17, 17, 19] + [19] * 100
if q == 3: bounds = [8, 9, 10, 12, 14, 16] + [17] * 100
if q == 4: bounds = [8, 8, 9, 10, 11] + [12] * 100
if bounds is None: break
for bound in bounds:
add_scores()
offs = filter(lambda off: getavg(off) >= bound, offs)
print len(offs), "nq", NQ, "total", NQALL
if len(offs) <= 1:
break
assert len(offs) == 1, "fail"
ans = bits16[offs.pop()]
bitmask = 1 << (q - 1)
for i in xrange(16):
known_delta[i] |= ans[i] << (q - 1)
print "known_delta", known_delta, "after", "q =", q
def test_key(key):
    """Return True iff ``key`` reproduces every keystream in the global ``outs``.

    ``deltas``/``outs`` come from one oracle query, in which the i-th output
    was generated from ``original_key ^ deltas[i]``.  A candidate is accepted
    only if it explains all of them.
    """
    for delta, out in zip(deltas, outs):
        test = [x ^ y for x, y in zip(key, delta)]
        s = ksa(test)
        stream = prga(s, 16)
        if tuple(stream) != tuple(out):
            # Any mismatch rules the candidate out.
            return False
    return True
deltas = [randvec() for _ in xrange(512)]
outs = oracle_chunk(deltas)
b = 2**4
key = [0] * 16
for i in xrange(16):
key[i] += (1 - i) % b
key[i] ^= known_delta[i]
print "KEY PART", key
print "SEED", seed
for bits in bits16:
for i in xrange(16):
key[i] ^= key[i] & (1 << 4)
key[i] ^= bits[i] << 4
if test_key(key):
print "WIN", key
break
else:
print 'no match..'
f.read_until("2. Guess the key.\n")
f.send("2\n")
k = "".join(map(chr, key))
f.send(k)
f.interact()
#-*- coding:utf-8 -*-
import sys
import random
from zer0C4 import ksa, prga, N, mask
def prga0(s, n):
    """Reference stream: the PRGA with its swap removed.

    ``s`` is only read, never permuted.  Returns a tuple of the first ``n``
    output values.
    """
    i = 0
    j = 0
    produced = bytearray()
    for _ in range(n):
        i = (i + 1) & mask
        j = (j + s[i]) & mask
        pick = (s[i] + s[j]) & mask
        produced.append(s[pick])
    return tuple(produced)
def randvec():
    """Return a new list of 16 random ints, each drawn uniformly from 0..31."""
    values = []
    for _ in range(16):
        values.append(random.randint(0, 31))
    return values
# NOTE(review): secret_key is not referenced anywhere else in this script;
# the call only advances the `random` generator state.
secret_key = randvec()
def calc_score(stream, stream0, modulus=None):
    """Weighted count of positions where ``stream`` agrees with ``stream0`` mod ``modulus``.

    Position k (0-based) contributes a weight of about 4.2 - 0.2 * k when
    ``stream[k] % modulus == stream0[k] % modulus``, so earlier outputs
    count more.  Only ``min(len(stream), len(stream0))`` positions are
    compared.

    ``modulus`` defaults to the module-level ``b`` (2**q), so existing
    two-argument calls keep working.
    """
    if modulus is None:
        modulus = b
    step = 0.2
    wt = 1 + step * 16
    score = 0
    for c0, c in zip(stream0, stream):
        ok = (c % modulus) == (c0 % modulus)
        score += wt * ok
        wt -= step
    return score
q = int(sys.argv[1])
b = 2**q
l = 16
stream0 = None
lst = []
lstbad = []
for i in xrange(16000):
key = randvec()
for i in xrange(16):
key[i] -= key[i] % b
key[i] += (1 - i) % b
s = ksa(key)
if stream0 is None:
stream0 = prga0(s, 16)
stream = prga(s, 16)
score = calc_score(stream, stream0)
lst.append(score)
key = randvec()
for i in xrange(16):
key[i] -= key[i] % (b/2)
key[i] += (1 - i) % (b/2)
s = ksa(key)
stream = prga(s, 16)
score = calc_score(stream, stream0)
lstbad.append(score)
lst.sort()
lstbad.sort()
print "avg", sum(lst) / float(len(lst))
print "avgbad", sum(lstbad) / float(len(lstbad))
print "med", lst[len(lst)/2]
print "medbad", lstbad[len(lstbad)/2]
for x in xrange(30):
print "%2d" % x, ":",
print "%5d" % sum(1 for v in lst if v >= x),
print "%5d" % sum(1 for v in lstbad if v >= x),
print " | ",
print "%6.2f%%" % (sum(1 for v in lst if v >= x) / float(len(lst)) * 100),
print "%6.2f%%" % (sum(1 for v in lstbad if v >= x) / float(len(lstbad)) * 100),
print
#!/usr/bin/env python
# coding=utf-8
import string
import random
from os import urandom
from hashlib import sha256
from sys import argv
from struct import unpack
from SocketServer import ThreadingTCPServer, BaseRequestHandler, socket
N = 5              # the S-box holds 2**N entries (values 0..2**N - 1)
mask = 2 ** N - 1  # 31: reduces indices/values modulo 2**N
klen = 16          # number of key bytes
def ksa(key):
    """Key-scheduling algorithm for 0ops Cipher 4 (a weakened RC4 KSA).

    Builds a permutation of 0..2**N - 1 from ``key``, a sequence of ints
    used cyclically.  Unlike textbook RC4, the index ``i`` is incremented
    *before* it is used, so positions are visited in the order
    1, 2, ..., 2**N - 1 and finally 0.

    Returns the permutation as a list.
    """
    size = 1 << N
    # list(...) keeps the S-box mutable on Python 3 too, where range() is lazy.
    s = list(range(size))
    j = 0
    for step in range(1, size + 1):
        i = step & mask  # 1, 2, ..., size - 1, then wraps to 0
        j = (j + s[i] + key[i % len(key)]) & mask
        s[i], s[j] = s[j], s[i]
    return s
def prga(s, n):
    """RC4-style output generator for 0ops Cipher 4.

    Permutes ``s`` in place while producing ``n`` output values (each an
    entry of ``s``).  Returns them as a bytearray.
    """
    out = bytearray()
    i = j = 0
    for _ in range(n):
        i = (i + 1) & mask
        j = (j + s[i]) & mask
        s[j], s[i] = s[i], s[j]
        out.append(s[(s[i] + s[j]) & mask])
    return out
class zer0C4Handler(BaseRequestHandler):
def proof_of_work(self):
proof = ''.join([random.choice(string.ascii_letters+string.digits) for _ in xrange(20)])
digest = sha256(proof).hexdigest()
self.request.send("sha256(XXXX+%s) == %s\n" % (proof[4:],digest))
self.request.send('Give me XXXX:')
x = self.request.recv(4)
if len(x) != 4 or sha256(x+proof[4:]).hexdigest() != digest:
return False
return True
def setup(self):
self.ori_key = [ord(i) & mask for i in urandom(klen)]
self.key = self.ori_key[:]
def handle(self):
try:
self.core_handle()
except socket.error as e:
print e
def core_handle(self):
if not self.proof_of_work():
return
self.request.send("Welcome to 0C4 Blackbox Server! Can you guess the key? (We are too lazy, so we provide you xor-key. Please encrypt your message with it by yourself:P)\n")
for _ in xrange(1500):
self.request.send("1. Generate new xor-key.\n2. Guess the key.\n")
cmd = self.request.recv(2)
if not cmd:
break
if cmd[0] == '1':
self.request.send("Feel free to send some bytes:)\n")
size = unpack('H', self.request.recv(2))[0]
if size > 8192 or size % klen != 0:
self.request.send("Invalid size!\n")
break
deltas = bytearray()
while size:
recv_bytes = self.request.recv(size)
deltas += recv_bytes
size -= len(recv_bytes)
deltas = [i & mask for i in deltas]
xor_key = bytearray()
key = self.key[:]
for i in xrange(0, len(deltas), klen):
delta = deltas[i:i+klen]
key = [ii^jj for ii,jj in zip(key, delta)]
sbox = ksa(key)
xor_key += prga(sbox, 16)
self.key = key
self.request.send("Here is you xor-key: ")
self.request.send(xor_key)
elif cmd[0] == '2':
self.request.send("Input the original key: ")
guess_key = self.request.recv(klen)
if map(ord, guess_key) == self.ori_key:
self.request.send("Here is what you want: {}\n".format(flag))
else:
self.request.send("Wrong! You won't get anything:(\n")
break
else:
self.request.send("Invalid command!\n")
self.request.send("Bye!\n")
if __name__ == '__main__':
from flag import flag
ThreadingTCPServer.allow_reuse_address = True
if len(argv) < 3:
print "Usage: {} <IP> <PORT>".format(argv[0])
else:
ip = argv[1]
port = int(argv[2])
s = ThreadingTCPServer((ip, port), zer0C4Handler)
try:
s.serve_forever()
except KeyboardInterrupt:
print "shut down!"
s.shutdown()
s.socket.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment