Last active
July 25, 2023 06:53
-
-
Save RobinLinus/91e9b18b0bbd69e19429fe4544b0dcdb to your computer and use it in GitHub Desktop.
Emulate a Uint32 number type in a multiplicative subgroup of Cairo's `felt` type
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# 32-Bit Arithmetic Native to Cairo's Finite Field
#
# A collection of operations for 32-bit arithmetic performed in the
# exponents of elements of Cairo's finite field.
#
# The field's modulus is p = 2**251 + 17 * 2**192 + 1, so the order of
# the field's multiplicative group is
#   p - 1 == 2**192 * 5 * 7 * 98714381 * 166848103.
# Hence the multiplicative group contains subgroups of sizes
# 2**192, 2**191, 2**190, ..., 2, 5, 7, 98714381, 166848103, ...
#
# We use the subgroup with 2**32 elements: the element g**x represents the
# 32-bit integer x, so arithmetic on the exponents is reduced mod 2**32.
#
#
#
# Cairo's finite field
#
# The field's modulus p
p = 2**251 + 17 * 2**192 + 1
# A generator of the field's multiplicative group F_p^* (of order p - 1)
# See: https://github.com/starkware-libs/cairo-lang/blob/54d7e92a703b3b5a1e07e9389608178129946efc/src/starkware/cairo/stark_verifier/core/utils.cairo#L5
g_field = 3
#
# The Uint32 type: the multiplicative subgroup of F_p^* of order 2**32
#
# Exponents of elements of F_p^* live in the additive group Z/(p - 1), which
# is not a field. Since 2**32 divides p - 1, the element (p - 1) // 2**32
# generates the subgroup of Z/(p - 1) of order 2**32.
g_add = (p - 1) // 2**32
# Hence g = g_field**g_add generates the multiplicative subgroup of F_p^* of
# order 2**32: g**x depends only on x mod 2**32, which is what lets exponents
# emulate 32-bit integers.
g = pow(g_field, g_add, p)
def to_uint32(z):
    """Encode the integer ``z`` as the Uint32 element ``g**z mod p``.

    Because ``g`` generates the subgroup of order 2**32, the encoding
    depends only on ``z mod 2**32``.
    """
    encoded = pow(g, z, p)
    return encoded
def uint32_add(a, b):
    """Add two encoded Uint32 numbers (mod 2**32).

    Adding exponents is multiplying field elements:
    g**x * g**y == g**(x + y).
    """
    product = a * b
    return product % p
def uint32_mul_scalar(a, b):
    """Multiply the encoded Uint32 ``a`` by the plain integer ``b`` (mod 2**32).

    (g**x)**b == g**(x * b), so scaling the exponent is a field exponentiation.
    """
    scaled = pow(a, b, p)
    return scaled
def uint32_neg(z):
    """Return the encoding of ``-x mod 2**32`` given ``z = g**x``.

    Multiplying the exponent by 2**32 - 1, which is -1 mod 2**32, negates it.
    """
    minus_one = 2**32 - 1
    return uint32_mul_scalar(z, minus_one)
def uint32_sub(a, b):
    """Subtract the encoded Uint32 ``b`` from ``a`` (mod 2**32)."""
    negated = uint32_neg(b)
    return uint32_add(a, negated)
def uint32_bitwise_not(a):
    """Flip all 32 bits of the encoded Uint32 ``a``.

    For a 32-bit x, ~x == (2**32 - 1) - x, so this subtracts ``a`` from the
    encoded all-ones word.
    """
    return uint32_sub(to_uint32(2**32 - 1), a)
def uint32_left_shift(a, t):
    """Logically shift the encoded Uint32 ``a`` left by ``t`` bits.

    Bits moved past bit 31 are dropped and zeros fill in on the right.
    Each squaring doubles the exponent, i.e. shifts it left by one bit
    mod 2**32; ``t`` squarings therefore shift by ``t`` bits.
    """
    shifted = a
    for _step in range(t):
        shifted = pow(shifted, 2, p)
    return shifted
# Exponent that projects the order-2**32 subgroup onto its order-2 subgroup
# {1, p - 1}: (g**x)**(2**31) == g**(x * 2**31), which is 1 iff x is even.
g_2 = 2**32 // 2


def uint32_mod2(x):
    """Return the parity of the encoded Uint32 ``x`` as a field element.

    Raising ``x = g**k`` to 2**31 yields 1 if ``k`` is even. If ``k`` is odd it
    yields g**(2**31), which is p - 1 (i.e. -1): g has order 2**32, so
    g**(2**31) is a non-trivial square root of 1 in F_p. The result is
    therefore 1 or p - 1, NOT the integer 0 or 1.
    """
    return pow(x, g_2, p)
def from_uint32(z):
    """Recover the scalar ``x`` in [0, 2**32) from its encoding ``z = g**x``.

    Computes the discrete log one bit at a time, lowest bit first. Once the
    exponent's bits below position k are all zero, z**(2**(31 - k)) is 1
    exactly when bit k is 0. A set bit is added to the result and then
    cleared by dividing z by g**(2**k), producing one more trailing zero.
    Only exponentiations by constants are needed.

    Assumes ``z`` is a power of ``g``; other field elements give no
    meaningful result.
    """
    scalar = 0
    for k in range(32):
        bit = 1 << k
        if pow(z, g_2 >> k, p) != 1:
            scalar += bit
            z = uint32_sub(z, to_uint32(bit))
    return scalar
def uint32_mul(a, b):
    """Multiply two encoded Uint32 numbers (mod 2**32).

    Decodes ``b`` to its scalar and scales ``a``'s exponent by it.
    """
    b_scalar = from_uint32(b)
    return uint32_mul_scalar(a, b_scalar)
def is_uint32(a):
    """Check that the integer ``a`` lies in the 32-bit range 0 <= a <= 2**32 - 1.

    Encoding reduces ``a`` mod 2**32, and from_uint32 always returns a value
    in [0, 2**32). The round trip therefore reproduces ``a`` exactly when
    ``a`` is already in that range. Out-of-range values such as
    ``a + 2**32`` alias an in-range value and are rejected.

    Negative ``a`` relies on ``pow`` accepting a negative exponent with a
    modulus, which requires Python 3.8+.
    """
    return a == from_uint32(to_uint32(a))
def uint32_right_shift_zeros(z, t):
    """Shift the encoded Uint32 ``z`` right by ``t`` bits.

    Only valid when the ``t`` lowest bits of the encoded scalar are zero: it
    re-runs the bitwise discrete log of from_uint32, starting at bit ``t``.
    Each recovered bit is placed ``t`` positions lower in the result.
    """
    base = 2 ** t
    exponent = g_2 // base
    result = 1
    for k in range(32 - t):
        if pow(z, exponent, p) != 1:
            result = uint32_add(result, to_uint32(2 ** k))
            z = uint32_sub(z, to_uint32(2 ** k * base))
        exponent //= 2
    return result
def uint32_rotate_right(z, t):
    """Rotate the bits of the encoded Uint32 ``z`` right by ``t`` positions.

    Bits shifted out on the right re-enter on the left. Rotation is periodic
    in the 32-bit word, so ``t`` is first reduced mod 32. This gives the same
    results as before for 0 <= t <= 32, and also accepts any other integer.
    Previously such values crashed on a float exponent; negative ``t`` now
    rotates left.
    """
    t %= 32
    # Move the low t bits to the top: exponent * 2**(32 - t) mod 2**32.
    mod_t_padded = uint32_mul_scalar(z, 2**(32 - t))
    # The same low t bits, moved back down to the bottom.
    mod_t = uint32_right_shift_zeros(mod_t_padded, 32 - t)
    # Clear the low t bits, then shift the remaining high bits down by t.
    div_t = uint32_right_shift_zeros(uint32_sub(z, mod_t), t)
    return uint32_add(mod_t_padded, div_t)
#
# Tests and sanity checks
#
# Every check prints the computed value next to the expected value.
#

# Addition: 42 + (2**32 - 23) wraps around to 19.
print('\nAddition')
a = to_uint32(42)
b = to_uint32(2**32 - 23)
c = uint32_add(a, b)
c_expected = to_uint32((42 - 23) % 2**32)
print(c, c_expected)

# Subtraction: 23 - 42 wraps around to 2**32 - 19.
print('\nSubtraction')
a = to_uint32(23)
b = to_uint32(42)
c = uint32_sub(a, b)
c_expected = to_uint32((23 - 42) % 2**32)
print(c, c_expected)

# Scaling an encoded value by a plain integer.
print('\nMultiplication by a scalar')
a = to_uint32(23)
b = 42
c = uint32_mul_scalar(a, b)
c_expected = to_uint32(23 * 42)
print(c, c_expected)

# Bitwise NOT of 0b10101.
print('\nBitwise not')
a = to_uint32(0b00010101)
c = uint32_bitwise_not(a)
c_expected = to_uint32(0b11111111111111111111111111101010)
print(c, c_expected)

# Left shift by 3; the top three bits fall off.
print('\nShift all bits to the left')
a = to_uint32(0b11111111111111111111111111101010)
t = 3
c = uint32_left_shift(a, t)
c_expected = to_uint32(0b11111111111111111111111101010000)
print(c, c_expected)

# Parity: 1 for even scalars, anything else (p - 1) for odd ones.
print('\nCompute modulo 2')
a = to_uint32(0b11111111111111111111111111101010)
c = uint32_mod2(a)
print('even', c == 1)
a = to_uint32(0b11111111111111111111111111101011)
c = uint32_mod2(a)
print('odd ', c != 1)
# Discrete-log round trip: from_uint32 must invert to_uint32.
print('\nEfficient Discrete Logarithm')
for i in range(1, 21):
    z = to_uint32(i + 100000)
    s = from_uint32(z)
    # Show the recovered scalar next to the scalar that was actually
    # encoded, so the first two columns must be equal.
    print(s, i + 100000, z)
# Full Uint32 x Uint32 multiplication.
print('\nMultiplication')
a = to_uint32(23)
b = to_uint32(42)
c = uint32_mul(a, b)
c_expected = to_uint32(23 * 42)
print(c, c_expected)

# Range check: an in-range value passes, its 2**32-aliased twin fails.
print('\nProve a 32-bit range')
x = 424242
x_evil = x + 2**32
print(is_uint32(x), is_uint32(x_evil))

# Right shifts that only discard trailing zeros.
print('\nShift trailing zeros to the right')
a = to_uint32(0b10001101000001100000000000000000)
c = uint32_right_shift_zeros(a, 14)
print("{0:b}".format(from_uint32(c)))
a = to_uint32(0b10000000000000000000000000000000)
c = uint32_right_shift_zeros(a, 29)
print("{0:b}".format(from_uint32(c)))

# Rotate right by 7; the two printed bit strings should match.
print('\nRotate right')
a = to_uint32(0b11111111000000001111011100110001)
c = uint32_rotate_right(a, 7)
c_expected = 0b01100011111111100000000111101110
print("{0:b}".format(from_uint32(c)))
print("{0:b}".format(c_expected))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
https://gist.github.com/RobinLinus/91e9b18b0bbd69e19429fe4544b0dcdb#file-uint32_in_exponent-py-L89
This dlog requires only exponentiations with constants. In combination with hints, we can use it to emulate exponentiations with variables.