Created August 7, 2024 12:51
-
-
Save cre-mer/7ae562caff81009d4d3377e7572881b8 to your computer and use it in GitHub Desktop.
Circuit to prove that a list is sorted in ascending order using only addition, multiplication, and equality checks.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Sample inputs for the self-checks at the bottom of the file.
sorted_list = [3, 9, 10, 404, 1337]
unsorted_list = [3, 9, 10, 1337, 404]

# Default bit width for `decimal_to_binary`; `define_midpoint` also derives
# its default (MAX + 1 zeros) from it.
MAX = 4
def decimal_to_binary(dec, max=MAX):
    """
    Convert a non-negative integer `dec` to its binary digit string.

    The result has no '0b' prefix and no leading zeros (0 becomes '0'),
    e.g. ``decimal_to_binary(10, 4) == '1010'``.

    :param dec: non-negative integer to convert.
    :param max: maximum number of binary digits the result may have. The
        name shadows the builtin ``max`` but is kept so keyword callers keep
        working.
    :returns: the binary digits of `dec` as a string of '0'/'1' characters.
    :raises AssertionError: if `dec` is negative or needs more than `max` bits.
    """
    # To keep the circuit simple, only non-negative values are supported.
    assert dec >= 0, "dec value MUST be non-negative"
    # bin() yields e.g. '0b1010'; drop the '0b' prefix.
    dec_as_bin = bin(dec)[2:]
    assert len(dec_as_bin) <= max, (
        f'dec value MUST fit in `max` bits. '
        f'Expected: at most {max} bits, got {len(dec_as_bin)} bits for {dec}'
    )
    return dec_as_bin
def validate_binaries(binaries, expected_value):
    """
    Check that `binaries` is a valid bit decomposition of `expected_value`.

    Bits are read most significant first. Each bit is constrained to {0, 1}
    with the polynomial check ``b * (b - 1) == 0``. The bits are then
    recomposed with Horner's rule (``value = 2 * value + b``), so the whole
    check uses only addition, multiplication and equality.

    :param binaries: iterable of bits, most significant first; each element
        must be accepted by ``int()`` (e.g. the string returned by
        `decimal_to_binary`). An empty iterable recomposes to 0.
    :param expected_value: integer the bits must recompose to.
    :raises AssertionError: if a bit is not 0/1 or the recomposed value
        differs from `expected_value`.
    """
    value = 0
    for binary in binaries:
        bit = int(binary)
        # enforce that each binary is either 1 or 0
        assert bit * (bit - 1) == 0, 'invalid binary'
        # Horner step: shift the accumulated value left by one bit, add this bit.
        value = value * 2 + bit
    assert expected_value == value, f'values mismatch, expected {expected_value}, got {value}'
def define_midpoint(num_zeros=MAX + 1):
    """
    Return the midpoint ``2 ** num_zeros``.

    Its binary representation is a single 1 (the MSB) followed by
    `num_zeros` zeros, e.g. ``define_midpoint(3) == 0b1000 == 8``.

    :param num_zeros: number of 0 bits below the MSB; must be >= 0. The
        default is evaluated once, when the function is defined.
    :returns: the integer ``2 ** num_zeros``.
    :raises AssertionError: if `num_zeros` is negative.
    """
    # Reject negative counts; they would otherwise collapse to 1 instead of failing.
    assert num_zeros >= 0, f'num_zeros MUST be non-negative, got {num_zeros}'
    return 1 << num_zeros
def compute_diff_relative_to_midpoint(midpoint, u, v):
    """
    Return the MSB of ``midpoint + (u - v)``: 1 if u >= v, 0 if u < v.

    `midpoint` must be a power of two, 2**k, whose binary representation
    uses at least one bit more than both `u` and `v`. Then |u - v| < 2**k,
    so ``midpoint + (u - v)`` lies strictly between 0 and 2**(k+1), and bit
    k (the midpoint's MSB position) is set exactly when ``u - v >= 0``.
    This avoids any range error.

    :param midpoint: a power of two, e.g. a value from `define_midpoint`.
    :param u: non-negative integer.
    :param v: non-negative integer.
    :returns: 1 if u >= v, else 0.
    :raises AssertionError: if any of the preconditions above is violated.
    """
    # Negative inputs would make bin()/bit-length reasoning meaningless.
    assert u >= 0 and v >= 0, 'u and v MUST be non-negative'
    # A non-power-of-two midpoint breaks the MSB argument: e.g. midpoint=24
    # (0b11000), u=3, v=9 gives 18 (0b10010), which keeps its top bit set.
    assert midpoint > 0 and midpoint & (midpoint - 1) == 0, 'midpoint MUST be a power of two'
    midpoint_bits = midpoint.bit_length()
    assert midpoint_bits > u.bit_length(), "midpoint's binary representation must use 1 bit more than u"
    assert midpoint_bits > v.bit_length(), "midpoint's binary representation must use 1 bit more than v"
    delta = u - v
    # midpoint + delta < 2 * midpoint, so shifting right by the MSB position
    # leaves exactly that bit: 1 or 0.
    return (midpoint + delta) >> (midpoint_bits - 1)
def is_list_sorted(list, *, max_bits=11):
    """
    Check, circuit-style, that `list` is in strictly ascending order.

    Each element is decomposed into at most `max_bits` bits, and the bits
    are validated against the element. Each adjacent pair (prev, value) is
    then compared through the MSB of ``midpoint + (prev - value)``; an MSB
    of 0 means prev < value. A result message is printed before returning.

    NOTE(review): adjacent equal values produce an MSB of 1, so a list with
    duplicates (e.g. [1, 1, 2]) is reported as NOT sorted. Confirm that this
    strict ordering is intended.

    :param list: iterable of non-negative integers. The name shadows the
        builtin ``list`` but is kept so keyword callers keep working.
    :param max_bits: maximum bit width of each element. The default of 11
        allows values up to 2047.
    :returns: True if the elements are strictly ascending (also for empty
        or single-element input), otherwise False.
    :raises AssertionError: if an element is negative or does not fit in
        `max_bits` bits.
    """
    # The midpoint must use more bits than any element; 2**(max_bits + 1)
    # gives two bits of headroom over a max_bits-wide value.
    midpoint = define_midpoint(max_bits + 1)
    prev_value = None
    for value in list:
        # 1. convert decimal to binary
        value_as_bin = decimal_to_binary(value, max_bits)
        # 2. make sure the binaries are valid
        validate_binaries(value_as_bin, value)
        # 3. compare with the previous element (the first element has none)
        if prev_value is None:
            prev_value = value
            continue
        msb = compute_diff_relative_to_midpoint(midpoint, prev_value, value)
        prev_value = value
        if msb != 0:
            print(f'list: {list} is not sorted\n')
            return False
    print(f'list: {list} is sorted\n')
    return True
# Self-checks, run when the module executes: the sample sorted list must be
# accepted and the sample unsorted list must be rejected.
assert is_list_sorted(sorted_list), 'list should be sorted'
assert not is_list_sorted(unsorted_list), 'list should not be sorted'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.