@eric-czech
Created July 23, 2020 19:23
Function to partition n elements into g groups (like np.array_split, but returning a group label for each element instead of materializing the split arrays)
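For comparison, the same labels can be obtained the materializing way. You split an index array with np.array_split and record which chunk each element landed in. This is a minimal sketch for illustration only. The name `array_split_labels` is hypothetical and not part of the gist. Its output for 7 elements in 3 groups matches the `dividx(7, 3)` docstring example.

```python
import numpy as np


def array_split_labels(n, groups):
    """Hypothetical baseline: group label for each of n elements, derived by
    materializing the chunks produced by np.array_split."""
    chunks = np.array_split(np.arange(n), groups)
    # Chunk i contributes len(chunk) copies of the label i
    return np.concatenate([np.full(len(chunk), i) for i, chunk in enumerate(chunks)])


print(array_split_labels(7, 3))  # [0 0 0 1 1 2 2]
```

This baseline allocates an index array and one sub-array per group before discarding them. `dividx` avoids that by building the labels directly from the group sizes with np.repeat.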
import numpy as np


def dividx(n, groups):
    """Create an index for groups that partition an array.

    The number of elements placed into a group will
    either equal `n//groups` or `n//groups + 1`, depending
    on how many of the latter are necessary to make
    the partitioning complete.

    Parameters
    ----------
    n : int
        Number of elements to partition
    groups : int
        Number of groups to partition over

    Returns
    -------
    ndarray
        Array of size `n` with integer values in [0, `groups`)

    Raises
    ------
    ValueError
        If `n` < 0, `groups` < 1, or `groups` > `n`

    Examples
    --------
    >>> dividx(7, 2)  # Divide 7 elements into 2 groups
    array([0, 0, 0, 0, 1, 1, 1])
    >>> dividx(7, 3)  # Divide 7 elements into 3 groups
    array([0, 0, 0, 1, 1, 2, 2])
    >>> dividx(7, 1)  # Divide 7 elements into 1 group
    array([0, 0, 0, 0, 0, 0, 0])
    >>> dividx(7, 7)  # Divide 7 elements into 7 groups
    array([0, 1, 2, 3, 4, 5, 6])
    """
    if n < 0:
        raise ValueError(f"Number of elements ({n}) cannot be negative")
    if groups < 1:
        # Without this check, groups == 0 would silently return an empty array
        raise ValueError(f"Number of groups ({groups}) must be at least 1")
    if groups > n:
        raise ValueError(
            f"Number of groups ({groups}) cannot be greater than number of elements ({n})"
        )
    # The first `n_mod` groups get one extra element so the sizes sum to `n`
    n_div, n_mod = divmod(n, groups)
    repeats = n_mod * [n_div + 1] + (groups - n_mod) * [n_div]
    # Repeat each group label according to its group size
    return np.repeat(np.arange(groups), repeats)
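`dividx` places the larger groups first: the first `n % groups` groups get `n // groups + 1` elements. Because of that, the label of any single index can also be computed arithmetically without building a length-`n` array at all. The sketch below assumes that same sizing rule, which is also the one np.array_split uses. `group_of` is a hypothetical helper and not part of the gist.

```python
import numpy as np


def group_of(i, n, groups):
    """Hypothetical helper: group label(s) for index/indices `i` under the
    same sizing rule as dividx, without a length-n label array.
    Assumes 0 <= i < n."""
    if not 1 <= groups <= n:
        raise ValueError(f"Need 1 <= groups <= n, got n={n}, groups={groups}")
    i = np.asarray(i)
    n_div, n_mod = divmod(n, groups)
    # The first n_mod groups hold n_div + 1 elements and cover indices [0, boundary)
    boundary = n_mod * (n_div + 1)
    # The remaining groups hold n_div elements each; n_div >= 1 since groups <= n
    return np.where(
        i < boundary,
        i // (n_div + 1),
        n_mod + (i - boundary) // n_div,
    )


# Same labels as the docstring example dividx(7, 3)
print(group_of(np.arange(7), 7, 3))  # [0 0 0 1 1 2 2]
print(group_of(5, 7, 3))  # 2

# Cross-check against the chunks produced by np.array_split
for n in range(1, 25):
    for g in range(1, n + 1):
        for label, chunk in enumerate(np.array_split(np.arange(n), g)):
            assert (group_of(chunk, n, g) == label).all()
```

This replaces the O(n) label array with O(1) work per queried index. That may be useful when only a few labels are needed or `n` is very large. When every label is needed, `dividx` is simpler.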