Last active
June 29, 2020 00:29
-
-
Save fubuloubu/495ccb9d0ee6681aa11ff180b4b9d33e to your computer and use it in GitHub Desktop.
Implementation of an overflow-safe version of exponentiation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Implementation of an overflow-safe version of exponentiation
# Prototyped for the EVM environment of Vyper
# Algorithm (exponentiation by squaring) from https://en.wikipedia.org/wiki/Exponentiation_by_squaring
import math

import pytest
from hypothesis import given, strategies as st, settings
# Highest number of squaring rounds any single power() call has needed so far.
# power() updates it; test_power() reads it to check the O(log(n)) claim.
max_rounds = 0


def power(a: int, b: int) -> int:
    """Return ``a ** b``, raising ``ValueError`` instead of overflowing.

    Implements exponentiation by squaring, so the main loop runs at most
    ``floor(log2(b))`` times. Results below ``-(2 ** 127)`` or at/above
    ``2 ** 256`` are rejected (presumably the union of the Vyper ``int128``
    and ``uint256`` ranges -- TODO confirm). Intermediate squares and
    products are checked before they are formed.

    Args:
        a: The base. It is assumed to already lie in range; for ``b == 1``
            it is returned unchecked.
        b: The exponent. Negative exponents are accepted only for bases
            ``1`` and ``-1``, whose powers stay integral.

    Returns:
        ``a ** b``.

    Raises:
        ValueError: In any of these cases:
            - ``b`` is negative for any other base (including ``0 ** -k``,
              which is a division by zero).
            - ``b >= 256`` for ``|a| >= 2``, which always overflows.
            - The result or an intermediate value leaves the allowed range.

    Side effects:
        Raises the module-level ``max_rounds`` to this call's loop-round
        count if that count exceeds the previous maximum.
    """
    global max_rounds  # For keeping track of O(log(n)) claim

    # Easy cases
    # TODO: Adjust for EVM oddities
    if a == 0:
        if b < 0:
            # 0 ** -k would divide by zero; it must not silently become 0.
            raise ValueError
        return 1 if b == 0 else 0
    if a == 1 or b == 0:
        return 1
    if b == 1:
        return a
    if a == -1:
        return 1 if b % 2 == 0 else -1
    if b < 0 or b >= 256:  # Sanity check on arg
        raise ValueError

    x = a  # a ** (2 ** k) after k rounds
    n = b  # exponent bits not yet consumed
    y = 1  # product of the factors picked up at odd bits so far
    num_rounds = 0
    # TODO: Adjust for EVM oddities
    while n > 1:
        # Overflow check on x ** 2
        # NOTE: x ** 2 < -(2 ** 127) is impossible
        if x ** 2 >= 2 ** 256:
            raise ValueError
        # Overflow check on x * y
        if x * y < -(2 ** 127) or x * y >= 2 ** 256:
            raise ValueError
        if n % 2 == 1:  # odd exponent: fold one factor of x into y
            y = x * y
        x = x ** 2
        n //= 2  # equals (n - 1) // 2 when n is odd
        # Count this round before recording it, so max_rounds holds the
        # number of rounds performed (not that number minus one).
        num_rounds += 1
        if num_rounds > max_rounds:
            max_rounds = num_rounds
    # Overflow check on x * y
    if x * y < -(2 ** 127) or x * y >= 2 ** 256:
        raise ValueError
    return x * y
# Strategy yielding (base, exponent) pairs for power(): the base range is
# scaled to the drawn exponent so a fair share of pairs land in range.
# NOTE: Still allow some overflow/underflow cases, but make it more balanced
@st.composite
def base_and_power(draw, n=st.integers(min_value=0, max_value=256)):  # noqa: B008
    """Draw an exponent from ``n``, then a base whose range depends on it.

    For exponents above 1 the bounds are twice the exponent-th root of
    ``2 ** 127`` (negative side) and of ``2 ** 256`` (positive side). The
    original estimate is that each bound alone yields an in-range result
    more than 50% of the time. Exponents 0 and 1 use the full base range.

    Returns:
        A ``(base, exponent)`` tuple.
    """
    exponent = draw(n)
    if exponent > 1:
        # exponent ** (log_exponent(2 ** k) / exponent) is the exponent-th root of 2 ** k.
        low = -round(2 * (exponent ** (math.log(2 ** 127, exponent) / exponent)))
        high = round(2 * (exponent ** (math.log(2 ** 256, exponent) / exponent)))
    else:
        low = -(2 ** 127)
        high = 2 ** 256 - 1
    base = draw(st.integers(min_value=low, max_value=high))
    return (base, exponent)
@given(xn=base_and_power())
@settings(max_examples=1000000)
def test_power(xn):
    """Property test: ``power`` equals ``**`` in range and raises outside it.

    Also checks that the running ``max_rounds`` counter stays within the
    logarithmic bound.
    """
    base, exponent = xn
    expected = base ** exponent
    if -(2 ** 127) <= expected < 2 ** 256:
        # TODO: Adjust for EVM oddities
        assert power(base, exponent) == expected
    else:
        with pytest.raises(ValueError):
            power(base, exponent)
    # O(log(n)) claim: log_2(256) = 8
    assert max_rounds <= 8
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.