Created
June 14, 2023 00:18
-
-
Save elisohl-ncc/0ffced62c5fb4afcdcd05bb925f73d5d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from Crypto.Cipher import AES | |
from Crypto.Util.Padding import pad, unpad | |
from random import random, seed, randrange | |
from math import log2 | |
# Fixed demo secrets: an all-zero 16-byte (AES-128) key and IV 00 01 .. 0f.
KEY = b"\x00" * 16
IV = bytes(range(0x10))
PT = b'Plaintext!'
# CBC-encrypt the PKCS#7-padded plaintext (10 bytes pad out to one 16-byte block).
CT = AES.new(key=KEY, iv=IV, mode=AES.MODE_CBC).encrypt(pad(PT, 16))
# Simulated oracle noise: probability of a false negative / false positive.
FN_RATE = 0.4
FP_RATE = FN_RATE
target_confidence_threshold = 0.9999
# Uncomment for reproducible runs:
#seed(b'bayes'*17)
# Total query count of every finished ByteSearch.search() call (appended there).
num_queries = []
class ByteSearch:
    """Bayesian search for the one correct candidate value behind a noisy oracle.

    Exactly one of ``num_candidates`` values (default 256, one per byte value)
    is assumed to be correct. ``oracle(i)`` is modelled as returning True with
    probability ``1 - FN_RATE`` when ``i`` is the correct value and with
    probability ``FP_RATE`` otherwise (module-level rates). Each oracle answer
    updates the posterior ``confidences`` with Bayes' rule; ``search`` keeps
    querying until some candidate's posterior reaches ``confidence_threshold``.
    """

    def __init__(self, oracle, confidence_threshold=0.9, quiet=True, num_candidates=256):
        """
        Args:
            oracle: callable taking a candidate index and returning a bool.
            confidence_threshold: posterior probability at which search() stops.
            quiet: if False, print a progress dot every 256 queries.
            num_candidates: size of the candidate space (256 for a byte).
        """
        self._counter = 0  # total number of oracle queries issued
        self.oracle = oracle
        self.queries = [[] for _ in range(num_candidates)]  # per-candidate result history
        self.confidences = [1 / num_candidates] * num_candidates  # uniform prior
        self.confidence_threshold = confidence_threshold
        self.quiet = quiet

    def update_confidences(self, index, result):
        """Given an oracle result for a given byte, update the confidences for each byte."""
        self.confidences = self.get_updated_confidences(self.confidences, index, result)

    def pick_exhaustive(self):
        """Round-robin strategy: cycle through the candidates in order."""
        return self._counter % len(self.confidences)

    def pick_by_confidence(self):
        """Pick the candidate with the highest posterior (lowest index on ties)."""
        return max(range(len(self.confidences)), key=lambda i: self.confidences[i])

    def pick_by_entropy(self):
        """Pick the candidate whose query minimises the expected posterior entropy.

        The expectation runs over the oracle's two possible answers, each
        weighted by its predictive probability under the current posterior
        (``get_result_probability``). It is NOT weighted by the prior that the
        candidate is correct: a wrong candidate still answers True with
        probability FP_RATE.
        """
        # NOTE: VERY SLOW (quadratic in the number of candidates per pick).
        # For a demo, construct with num_candidates=16 and draw targets with randrange(16).
        confidences = self.confidences
        expected_entropies = []
        for i in range(len(confidences)):
            p_t = self.get_result_probability(confidences, i, True)
            e_if_t = self.get_entropy(self.get_updated_confidences(confidences, i, True))
            e_if_f = self.get_entropy(self.get_updated_confidences(confidences, i, False))
            expected_entropies.append(p_t * e_if_t + (1 - p_t) * e_if_f)
        return min(range(len(expected_entropies)), key=lambda i: expected_entropies[i])

    def query_byte(self, index):
        """Query the oracle for a given byte, record the result and update the posterior."""
        self._counter += 1
        result = self.oracle(index)
        self.queries[index].append(result)
        self.update_confidences(index, result)
        if not self.quiet and (self._counter & 0xFF) == 0:
            print(end=".", flush=True)  # progress dot every 256 queries
        return result

    def search(self, strategy):
        """Query the oracle until one candidate is confident enough, and return it.

        ``strategy`` is a zero-argument callable returning the next index to
        query (e.g. ``self.pick_by_entropy``). The total number of queries is
        also appended to the module-level ``num_queries`` list.
        """
        threshold = self.confidence_threshold
        while max(self.confidences) < threshold:
            self.query_byte(strategy())
        num_queries.append(sum(len(q) for q in self.queries))
        return self.pick_by_confidence()

    @staticmethod
    def bayes(h, e_given_h, e_given_not_h):
        """Update the posterior probability of h given e.
        e: evidence
        h: hypothesis
        e_given_h: probability of e given h
        e_given_not_h: probability of e given not h
        """
        return e_given_h * h / (e_given_h * h + e_given_not_h * (1 - h))

    @staticmethod
    def get_result_probability(confidences, index, result):
        """Return P(oracle(index) == result) under the posterior ``confidences``."""
        p_i = confidences[index]
        if result:
            return p_i * (1 - FN_RATE) + (1 - p_i) * FP_RATE
        return p_i * FN_RATE + (1 - p_i) * (1 - FP_RATE)

    @staticmethod
    def get_updated_confidences(confidences, index, result):
        """Return the posterior after observing ``oracle(index) == result``.

        Bayes' rule: P(h_j | e) = P(e | h_j) * P(h_j) / P(e). If candidate j
        is the correct one, the queried ``index`` is a hit when j == index and
        a miss otherwise, so P(e | h_j) depends only on that comparison. P(e)
        is the same normaliser for every j. ``confidences`` is not modified.
        """
        if result:
            p_e_if_index_correct = 1 - FN_RATE
            p_e_if_other_correct = FP_RATE
        else:
            p_e_if_index_correct = FN_RATE
            p_e_if_other_correct = 1 - FP_RATE
        p_e = ByteSearch.get_result_probability(confidences, index, result)
        return [
            p_h * (p_e_if_index_correct if j == index else p_e_if_other_correct) / p_e
            for j, p_h in enumerate(confidences)
        ]

    @staticmethod
    def get_entropy(dist):
        """Shannon entropy of ``dist`` in bits (zero-probability entries skipped)."""
        return -sum(p * log2(p) for p in dist if p)
#### BASIC SINGLE BYTE SEARCH TEST
def test_single_byte_search():
    """Demo: repeatedly recover a random byte through a simulated noisy oracle.

    Each round draws a fresh TARGET_BYTE, runs an entropy-guided ByteSearch
    against an oracle with the module's FP_RATE / FN_RATE, and prints the
    running accuracy and average query count. Loops until interrupted.
    """
    def oracle(index):
        # True hit with probability 1 - FN_RATE; false alarm with probability FP_RATE.
        if index == TARGET_BYTE:
            return random() > FN_RATE
        return not (random() > FP_RATE)

    def attack():
        search = ByteSearch(oracle, confidence_threshold=target_confidence_threshold, quiet=False)
        result = search.search(search.pick_by_entropy)
        # Record this round's query count in the function-local list below.
        # ByteSearch.search appends to the *module-level* num_queries, which
        # this function's local name shadows, so it must be recorded here.
        num_queries.append(sum(len(q) for q in search.queries))
        print()
        print(*sorted([len(q) for q in search.queries], reverse=True))
        return result

    correct = total = 0
    num_queries = []  # per-round query counts, filled by attack()
    while True:
        TARGET_BYTE = randrange(256)
        result = attack()
        total += 1
        if result == TARGET_BYTE:
            correct += 1
        accuracy = correct / total
        avg_queries = sum(num_queries) / len(num_queries)
        print(f"{accuracy=:.5f}\t{avg_queries=:.1f}\t{num_queries[-1]=}\n")
# Full Bayesian padding oracle attack
def test_padding_oracle_attack():
    """Demo: recover PT from (IV, CT) through a noisy CBC padding oracle, forever.

    Each round recovers the intermediate state D_k(CT) one byte at a time
    (last byte first) with a ByteSearch per byte, XORs it with IV to print
    the recovered padded plaintext, and prints per-round and average oracle
    query counts. Loops until interrupted.
    """
    def oracle(iv, ct):
        """Exact padding oracle: True iff decrypting ct under iv yields valid padding."""
        aes = AES.new(key=KEY, iv=iv, mode=AES.MODE_CBC)
        try:
            unpad(aes.decrypt(ct), 16)
        except ValueError:
            return False
        return True

    ORACLE_QUERIES = [0]  # mutable cell so the nested oracle can count calls

    def worse_oracle(iv, ct):
        """Noisy wrapper: a valid answer survives with prob 1 - FN_RATE,
        an invalid one turns into a false positive with prob FP_RATE."""
        ORACLE_QUERIES[0] += 1
        result = oracle(iv, ct)
        if result:
            return random() > FN_RATE
        return not (random() > FP_RATE)

    def attack(block, quiet=False):
        """Return the 16 intermediate bytes D_k(block), found last byte first."""
        D_k = [0] * 16
        for pad_len in range(1, 17):
            prefix = [0] * (16 - pad_len)
            # Make the already-recovered trailing bytes decrypt to pad_len.
            postfix = [pad_len ^ i for i in D_k[17 - pad_len:]]

            # scan through candidate bytes
            def wrapped_oracle(ind, double_up=False):  # TODO double_up is dumb, just replace this with a pad_len == 1 test and make sure that works
                iv = bytes(prefix + [ind] + postfix)
                if not worse_oracle(iv, block):
                    return False
                # Presumably meant to rule out longer accidental paddings by
                # re-querying with the preceding byte altered. Build a fresh
                # IV rather than mutating the shared `prefix` in place, and skip
                # the check when there is no preceding byte (pad_len == 16).
                # NOTE(review): enabling this changes the effective FP/FN rates
                # that ByteSearch's model assumes.
                if double_up and prefix:
                    iv_2 = bytes(prefix[:-1] + [prefix[-1] ^ 1] + [ind] + postfix)
                    if not worse_oracle(iv_2, block):
                        return False
                return True

            search = ByteSearch(wrapped_oracle, confidence_threshold=0.999)
            result = search.search(search.pick_by_entropy)
            # A valid-padding hit means D_k[-pad_len] ^ result == pad_len.
            D_k[-pad_len] = result ^ pad_len
            if not quiet:
                print(end=f"{pad_len} ", flush=True)
            # TODO add support for backtracking? or don't?
        if not quiet:
            print()
        return D_k

    query_counts = []
    while True:
        ORACLE_QUERIES[0] = 0
        # CT is a single block here (10-byte PT padded to 16), so P = D_k(CT) ^ IV.
        print(bytes(a ^ b for a, b in zip(IV, attack(CT))))
        print(f"{ORACLE_QUERIES[0]=}")
        query_counts.append(ORACLE_QUERIES[0])
        AVG_ORACLE_QUERIES = sum(query_counts) / len(query_counts)
        print(f"{AVG_ORACLE_QUERIES=}")
        print()
# Script entry point: run the end-to-end noisy padding-oracle demo
# (it loops until interrupted; test_single_byte_search is the lighter demo).
if __name__ == "__main__":
    test_padding_oracle_attack()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment