Skip to content

Instantly share code, notes, and snippets.

@thomasaarholt
Created January 12, 2021 11:20
Show Gist options
  • Save thomasaarholt/047f4455f6453367c49c944a20e94e13 to your computer and use it in GitHub Desktop.
Save thomasaarholt/047f4455f6453367c49c944a20e94e13 to your computer and use it in GitHub Desktop.
Roll each row of an ndarray independently, optionally filling the wrapped-around positions (`fill`, default np.nan)
def row_roll(arr, shifts, axis=1, fill=np.nan):
    """Roll every 1-D lane of ``arr`` along ``axis`` by its own amount.

    A "lane" is a 1-D slice of ``arr`` taken along ``axis``. For a 2-D array
    and the default ``axis=1``, the lanes are the rows.

    Parameters
    ----------
    arr : array_like
        Input array of rank >= 1.
    shifts : array_like of int
        Amount to roll each lane by. Its shape must broadcast against
        ``arr``'s shape with ``axis`` removed. For a 2-D ``arr`` with
        ``axis=1`` this is one shift per row, shape ``(arr.shape[0],)``; with
        ``axis=0`` it is one shift per column, shape ``(arr.shape[1],)``.
        Positive values move elements towards higher indices (a row moves
        right); negative values move them towards lower indices. Any
        magnitude is allowed, including ``|shift| >= arr.shape[axis]``.
    axis : int
        Axis along which elements are shifted.
    fill : scalar or False
        If ``False``, elements that fall off one end wrap around to the other,
        as with ``np.roll``. Otherwise every position that would have received
        a wrapped-around element is set to ``fill``. Integer and boolean
        arrays are converted to ``float`` first when ``fill`` is a float
        (such as the default ``np.nan``), so the value can be stored.

    Returns
    -------
    np.ndarray
        A new array with the same shape as ``arr``. The input is not modified.
    """
    arr = np.asarray(arr)
    shifts = np.asarray(shifts)
    # A float fill (e.g. NaN) cannot be stored in an integer or bool array.
    # Test the dtype *kind* so that every width and signedness is covered;
    # ``np.issubdtype(dtype, int)`` only matches the platform default int.
    if arr.dtype.kind in "biu" and isinstance(fill, (float, np.floating)):
        arr = arr.astype(float)

    # Move the target axis last so that each lane is ``arr[..., :]``.
    arr = np.swapaxes(arr, axis, -1)
    n = arr.shape[-1]
    # One shift per lane, broadcast along the lane's own positions.
    lane_shifts = shifts[..., np.newaxis]
    positions = np.arange(n)

    # Output position j takes input element (j - shift) mod n. The modulo
    # keeps every index in range for any shift magnitude. An empty lane
    # (n == 0) has nothing to gather, so skip the modulo-by-zero.
    source = positions - lane_shifts
    if n:
        source = source % n
    source = np.broadcast_to(source, arr.shape)
    result = np.take_along_axis(arr, source, axis=-1)  # always a new array

    if fill is not False:
        # A right shift by s wraps elements into the first s slots; a left
        # shift by |s| wraps them into the last |s| slots. The thresholds are
        # deliberately not reduced mod n, so |s| >= n blanks the whole lane.
        start = np.where(lane_shifts > 0, lane_shifts, 0)
        stop = np.where(lane_shifts < 0, n + lane_shifts, n)
        wrapped = (positions < start) | (positions >= stop)
        result[np.broadcast_to(wrapped, result.shape)] = fill

    return np.swapaxes(result, -1, axis)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment