kmdouglass/ceb6aba85852820831c2f5680cbd73a2
Created August 26, 2022 18:49
Round a list of floats to integers so that the sum of the integers equals a given total, such as the integer portion of the sum of the floats.
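To show why the sum needs protecting, here is a minimal NumPy sketch of what plain element-wise rounding does. The values are illustrative and not from the gist. They are chosen to be exact in binary floating point, so their sum is exactly 7.0.

```python
import numpy as np

# Illustrative values (not from the gist); each is exact in binary, so the sum is exactly 7.0.
values = np.array([0.25, 1.25, 2.25, 3.25])

# Rounding each element on its own drops every fractional part of 0.25.
naive = np.rint(values)

print(naive)              # [0. 1. 2. 3.]
print(naive.sum())        # 6.0, one short of the target
print(int(values.sum()))  # 7
```

The four dropped fractions add up to a whole unit, so the rounded values sum to one less than the integer portion of the original sum.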
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import numpy.typing as npt


def safe_round(array: npt.ArrayLike, total: int) -> np.ndarray:
    """Rounds an array of floats, maintaining their integer sum."""
    array = np.asanyarray(array)
    # Round the array to the nearest integer (np.rint rounds halves to even)
    rounded_array: np.ndarray = np.rint(array)
    error = total - np.sum(rounded_array)
    if error == 0:
        return rounded_array
    # The number of elements to adjust. Rounding moves each element by at most 0.5,
    # so no element needs an adjustment larger than 1. This assumes that
    # abs(error) <= array.size.
    n = int(np.abs(error))
    # np.argsort() returns the indices that would sort an array. Scaling the rounding
    # errors by -sign(error) puts first the elements that were rounded down the most
    # (when error > 0) or rounded up the most (when error < 0).
    sorted_index_array = np.argsort(
        -np.sign(error) * (array - rounded_array), axis=None
    )
    # Add +/- 1 to the n elements of rounded_array whose rounding errors are largest
    # in the direction of the error
    safe_rounded_array = rounded_array.flatten()
    safe_rounded_array[sorted_index_array[0:n]] += np.copysign(1, error)
    return safe_rounded_array.reshape(array.shape)
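For comparison, here is a minimal pure-Python sketch of a closely related technique, the largest remainder method. It takes the floor of every value and then adds 1 to the elements with the largest fractional parts until the target is reached. This is not the gist's code. The name `largest_remainder_round` is hypothetical, and the sketch assumes `total` lies between the sum of the floors and that sum plus the number of elements. That condition holds for `total = int(sum(values))`.

```python
import math
from typing import List, Sequence


def largest_remainder_round(values: Sequence[float], total: int) -> List[int]:
    """Round values to ints whose sum is exactly `total` (largest remainder method).

    Hypothetical helper for illustration; not part of the original gist.
    """
    floors = [math.floor(v) for v in values]
    # Number of elements that must be bumped up by one to reach the target.
    deficit = total - sum(floors)
    if not 0 <= deficit <= len(values):
        raise ValueError("total is not reachable by rounding each value up or down")
    # Indices ordered by fractional part, largest first. sorted() stays stable
    # with reverse=True, so tied elements keep their original order.
    order = sorted(range(len(values)), key=lambda i: values[i] - floors[i], reverse=True)
    result = list(floors)
    for i in order[:deficit]:
        result[i] += 1
    return result


print(largest_remainder_round([0.1, 0.45, 0.45], 1))      # [0, 1, 0]
print(largest_remainder_round([0.25, 1.25, 2.25, 3.25], 7))  # [1, 1, 2, 3]
```

Every value starts from its floor, so adjustments only ever go up and a single sort direction is enough. `safe_round` instead starts from nearest-integer rounding, which is why it has to adjust in whichever direction the error points.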