eric-czech/4c5ad3356324f78e420139f79c8f6cc7
Created July 23, 2020 19:23
Function to partition n elements into g groups of near-equal size (like np.array_split, but without materializing the array being split)
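For comparison, the materialized approach that the description alludes to might look like the sketch below. It allocates an index array of length `n`, splits it with `np.array_split`, and then labels each element by the piece it landed in. `labels_via_array_split` is a hypothetical helper written for illustration and is not part of the gist. For the docstring examples (7 elements into 1, 2, 3 or 7 groups) it yields the same labels as `dividx`, because `np.array_split` also gives the extra elements to the leading pieces.

```python
import numpy as np


def labels_via_array_split(n: int, groups: int) -> np.ndarray:
    # Materialize the elements themselves (here just their indices 0..n-1)
    pieces = np.array_split(np.arange(n), groups)
    # Label each element with the index of the piece it ended up in
    return np.concatenate(
        [np.full(len(piece), label) for label, piece in enumerate(pieces)]
    )


print(labels_via_array_split(7, 3))  # [0 0 0 1 1 2 2]
```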
```python
import numpy as np


def dividx(n, groups):
    """Create group labels that partition `n` elements into `groups` groups.

    Each group receives either ``n // groups`` or ``n // groups + 1``
    elements: the first ``n % groups`` groups get one extra element so
    that the group sizes sum to ``n``. Labels are returned in
    non-decreasing order, so each group occupies a contiguous run.

    Parameters
    ----------
    n : int
        Number of elements to partition.
    groups : int
        Number of groups to partition over.

    Returns
    -------
    ndarray
        Array of size `n` with integer values in [0, `groups`).

    Raises
    ------
    ValueError
        If `n` is negative, `groups` is less than 1, or `groups` > `n`.

    Examples
    --------
    >>> dividx(7, 2)  # Divide 7 elements into 2 groups
    array([0, 0, 0, 0, 1, 1, 1])
    >>> dividx(7, 3)  # Divide 7 elements into 3 groups
    array([0, 0, 0, 1, 1, 2, 2])
    >>> dividx(7, 1)  # Divide 7 elements into 1 group
    array([0, 0, 0, 0, 0, 0, 0])
    >>> dividx(7, 7)  # Divide 7 elements into 7 groups
    array([0, 1, 2, 3, 4, 5, 6])
    """
    # Check n first so a negative n reports the more specific error
    if n < 0:
        raise ValueError(f"Number of elements ({n}) cannot be negative")
    # Without this check, groups <= 0 silently returns an array of the wrong size
    if groups < 1:
        raise ValueError(f"Number of groups ({groups}) must be at least 1")
    if groups > n:
        raise ValueError(
            f"Number of groups ({groups}) cannot be greater than number of elements ({n})"
        )
    # Built-in divmod keeps plain Python ints, so the list repetition below is unambiguous
    n_div, n_mod = divmod(n, groups)
    # The first n_mod groups hold n_div + 1 elements; the rest hold n_div
    repeats = n_mod * [n_div + 1] + (groups - n_mod) * [n_div]
    # Repeat each group label by its group size
    return np.repeat(np.arange(groups), repeats)
```
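Because the group sizes follow a fixed pattern (the first `n % groups` groups hold one extra element), the label of any single element, and the element range of any single group, can also be computed in constant time without building the label array at all. The sketch below assumes the same partition as `dividx`; `group_of` and `group_bounds` are hypothetical helpers for illustration, not part of the gist.

```python
from typing import Tuple


def _check(n: int, groups: int) -> Tuple[int, int]:
    # Same validity rule as dividx: 1 <= groups <= n
    if not 1 <= groups <= n:
        raise ValueError(f"groups ({groups}) must be in [1, n] with n = {n}")
    return divmod(n, groups)


def group_of(i: int, n: int, groups: int) -> int:
    """Return the group label of element i without building the label array."""
    n_div, n_mod = _check(n, groups)
    if not 0 <= i < n:
        raise IndexError(f"element index {i} out of range for n = {n}")
    # The first n_mod groups hold n_div + 1 elements each
    n_big = n_mod * (n_div + 1)
    if i < n_big:
        return i // (n_div + 1)
    # The remaining groups hold n_div elements each (n_div >= 1 since groups <= n)
    return n_mod + (i - n_big) // n_div


def group_bounds(g: int, n: int, groups: int) -> Tuple[int, int]:
    """Return the half-open [start, stop) element range covered by group g."""
    n_div, n_mod = _check(n, groups)
    if not 0 <= g < groups:
        raise IndexError(f"group index {g} out of range for groups = {groups}")
    if g < n_mod:
        start, size = g * (n_div + 1), n_div + 1
    else:
        start, size = n_mod * (n_div + 1) + (g - n_mod) * n_div, n_div
    return start, start + size
```

For example, `[group_of(i, 7, 3) for i in range(7)]` gives `[0, 0, 0, 1, 1, 2, 2]`, matching `dividx(7, 3)`, and `group_bounds(1, 7, 3)` is `(3, 5)`, the same range that `np.array_split(np.arange(7), 3)[1]` covers.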