Created April 9, 2017 16:24
-
-
Save tobywf/b0f67e9e414e41813c1590ed1ea39e2c to your computer and use it in GitHub Desktop.
Decimal rounding to the nearest multiple (plus unit tests)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from decimal import Decimal, ROUND_CEILING, ROUND_FLOOR | |
# Decimal constant one (sign 0, coefficient 1, exponent 0).
ONE = Decimal('1')
def round_up(number, quantum):
    """Round a decimal number up to the nearest multiple of quantum.

    Returns ``ceil(number / quantum) * quantum``.

    The integer quotient comes from ``divmod``, which Decimal computes
    exactly. If that quotient would need more digits than the context
    precision, ``divmod`` signals InvalidOperation instead of rounding. So an
    input with more significant digits than the precision still lands on the
    correct side of a multiple; a rounded true division would not guarantee
    that. Only the final multiplication is subject to the context precision.

    With a negative quantum the ceiling is taken of the quotient, so the
    result is the nearest multiple at or below ``number``. Plain ``int``
    operands are accepted as well. A zero quantum raises, as division by
    zero does.
    """
    quotient, remainder = divmod(number, quantum)
    # number == quotient * quantum + remainder with |remainder| < |quantum|,
    # so the true quotient is quotient + remainder / quantum. divmod truncates
    # toward zero for Decimal and floors for int. In both cases `quotient` is
    # below the ceiling exactly when remainder / quantum is positive, i.e. the
    # remainder is non-zero and has the same sign as quantum. Only the
    # remainder's sign and zero-ness are used, never its (possibly rounded)
    # value.
    if remainder and (remainder < 0) == (quantum < 0):
        quotient += 1
    return quotient * quantum
def round_down(number, quantum):
    """Round a decimal number down to the nearest multiple of quantum.

    Returns ``floor(number / quantum) * quantum``.

    The integer quotient comes from ``divmod``, which Decimal computes
    exactly. If that quotient would need more digits than the context
    precision, ``divmod`` signals InvalidOperation instead of rounding. So an
    input with more significant digits than the precision still lands on the
    correct side of a multiple; a rounded true division would not guarantee
    that. Only the final multiplication is subject to the context precision.

    With a negative quantum the floor is taken of the quotient, so the result
    is the nearest multiple at or above ``number``. Plain ``int`` operands are
    accepted as well. A zero quantum raises, as division by zero does.
    """
    quotient, remainder = divmod(number, quantum)
    # number == quotient * quantum + remainder with |remainder| < |quantum|,
    # so the true quotient is quotient + remainder / quantum. divmod truncates
    # toward zero for Decimal and floors for int. In both cases `quotient` is
    # above the floor exactly when remainder / quantum is negative, i.e. the
    # remainder is non-zero and its sign differs from quantum's. Only the
    # remainder's sign and zero-ness are used, never its (possibly rounded)
    # value.
    if remainder and (remainder < 0) != (quantum < 0):
        quotient -= 1
    return quotient * quantum
import unittest | |
from decimal import Decimal, BasicContext, localcontext | |
# Decimal constant one (sign 0, coefficient 1, exponent 0) for the tests.
ONE = Decimal('1')
class RoundTestCase:
    """Shared checks for a round-to-multiple function.

    Mix this class into a ``unittest.TestCase`` subclass that provides:

    * ``method`` -- the rounding function under test, called as
      ``method(number, quantum)``;
    * ``under(multiple)`` -- the expected quotient for a value just below
      ``multiple * quantum``;
    * ``over(multiple)`` -- the expected quotient for a value just above
      ``multiple * quantum``.

    It deliberately does not subclass ``TestCase`` itself, so unittest only
    collects the concrete subclasses that supply those attributes.
    """

    def test_unity(self):
        """Rounding 1 to a multiple of 1 returns 1."""
        self.assertEqual(self.method(ONE, ONE), ONE)

    def test_exact_multiple(self):
        """A number that already is a multiple of quantum comes back unchanged."""
        for quantum in (Decimal('0.25'), ONE, Decimal(3), Decimal('5E+3')):
            for multiple in range(-9, 10):
                number = quantum * multiple
                with self.subTest(quantum=quantum, multiple=multiple):
                    self.assertEqual(self.method(number, quantum), number)

    def test_epsilon(self):
        """Values a tiny epsilon either side of a multiple round correctly."""
        def _range_inclusive(start, end, step=1):
            # range() that includes `end`.
            return range(start, end + step, step)

        with localcontext(BasicContext) as context:
            for precision in range(1, 20):
                # |multiple * digit| <= 81 occupies the digits at `exponent + 1`
                # and `exponent`, and epsilon sits `precision` places below
                # `exponent`. So precision + 2 digits represent
                # `multiple * quantum +/- epsilon` exactly, and the function
                # under test runs at that same precision.
                context.prec = precision + 2
                for exponent in _range_inclusive(-10, 10):
                    epsilon = Decimal((0, (1, ), exponent - precision))
                    for digit in _range_inclusive(1, 9):
                        quantum = Decimal((0, (digit, ), exponent))
                        for multiple in _range_inclusive(-9, 9):
                            with self.subTest(
                                    precision=precision,
                                    exponent=exponent,
                                    digit=digit,
                                    multiple=multiple):
                                # epsilon below multiple * quantum
                                under = quantum.fma(multiple, -epsilon)
                                self.assertEqual(self.method(under, quantum), self.under(multiple) * quantum)
                                # epsilon above multiple * quantum
                                over = quantum.fma(multiple, epsilon)
                                self.assertEqual(self.method(over, quantum), self.over(multiple) * quantum)
class RoundUpTestCase(unittest.TestCase, RoundTestCase):
    """Run the shared rounding checks against ``round_up``."""

    method = staticmethod(round_up)

    @staticmethod
    def over(multiple):
        """A value just above ``multiple * quantum`` rounds up to the next multiple."""
        return 1 + multiple

    @staticmethod
    def under(multiple):
        """A value just below ``multiple * quantum`` rounds up to ``multiple`` itself."""
        return multiple
class RoundDownTestCase(unittest.TestCase, RoundTestCase):
    """Run the shared rounding checks against ``round_down``."""

    method = staticmethod(round_down)

    @staticmethod
    def over(multiple):
        """A value just above ``multiple * quantum`` rounds down to ``multiple`` itself."""
        return multiple

    @staticmethod
    def under(multiple):
        """A value just below ``multiple * quantum`` rounds down to the previous multiple."""
        return -1 + multiple
if __name__ == '__main__':
    # Only when executed as a script: run unittest's command-line runner over
    # the TestCase classes defined in this module.
    unittest.main(module='__main__')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.