Last active
March 22, 2018 16:38
-
-
Save dojeda/2250ad463a925a38d7179ec913167ebc to your computer and use it in GitHub Desktop.
Epoching of 2D array on its last dimension using a view (not a copy)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" Small proof of concept of an epoching function using NumPy strides
License: BSD-3-Clause
Copyright: David Ojeda <[email protected]>, 2018
"""
import numpy as np | |
from numpy.lib import stride_tricks | |
def epoch(a, size, interval, axis=-1, writeable=True):
    """ Create a view of `a` as (possibly overlapping) epochs.

    The intended use-case for this function is to epoch an array representing
    a multi-channel signal with shape `(n_channels, n_samples)` (samples on
    the last axis, the default ``axis=-1``) in order to create several smaller
    views as arrays of shape `(n_channels, size)`, without copying the input
    array. Use `axis` when the samples live on another dimension.

    This function uses a new stride definition in order to produce a view of
    `a` that has shape `(num_epochs, ..., size, ...)`. Dimensions other than
    the one represented by `axis` do not change.

    Parameters
    ----------
    a: array_like
        Input array.
    size: int
        Number of elements (i.e. samples) in each epoch. Must be >= 1.
    interval: int
        Number of elements (i.e. samples) to move for the next epoch.
        Must be >= 1.
    axis: int
        Axis of the samples on `a`. For example, if `a` has a shape of
        `(num_observation, num_samples, num_channels)`, then use `axis=1`.
    writeable: bool
        If False, the returned view is read-only. Epochs may overlap, so
        writing into the view writes into `a` and can change several epochs
        at once; pass False unless you really need to write through it.

    Returns
    -------
    ndarray
        Epoched view of `a`. Epochs are in the first dimension. Only complete
        epochs are produced; when `size` is larger than the number of samples
        the first dimension has length 0.

    Raises
    ------
    ValueError
        If `size` or `interval` is smaller than 1.

    Examples
    --------
    >>> x1 = np.arange(4 * 10).reshape(4, 10)  # example: 4 channels, 10 samples
    >>> x1
    array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
           [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
           [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
           [30, 31, 32, 33, 34, 35, 36, 37, 38, 39]])
    >>> epoch(x1, 5, 2)  # epochs of 5 samples every 2 samples
    array([[[ 0,  1,  2,  3,  4],
            [10, 11, 12, 13, 14],
            [20, 21, 22, 23, 24],
            [30, 31, 32, 33, 34]],
    <BLANKLINE>
           [[ 2,  3,  4,  5,  6],
            [12, 13, 14, 15, 16],
            [22, 23, 24, 25, 26],
            [32, 33, 34, 35, 36]],
    <BLANKLINE>
           [[ 4,  5,  6,  7,  8],
            [14, 15, 16, 17, 18],
            [24, 25, 26, 27, 28],
            [34, 35, 36, 37, 38]]])
    >>> x2 = np.arange(2 * 5 * 3).reshape(2, 5, 3)  # example: 2 observations, 5 samples, 3 channels
    >>> x2
    array([[[ 0,  1,  2],
            [ 3,  4,  5],
            [ 6,  7,  8],
            [ 9, 10, 11],
            [12, 13, 14]],
    <BLANKLINE>
           [[15, 16, 17],
            [18, 19, 20],
            [21, 22, 23],
            [24, 25, 26],
            [27, 28, 29]]])
    >>> epoch(x2, 3, 2, axis=1)  # epochs of 3 samples every 2 samples
    array([[[[ 0,  1,  2],
             [ 3,  4,  5],
             [ 6,  7,  8]],
    <BLANKLINE>
            [[15, 16, 17],
             [18, 19, 20],
             [21, 22, 23]]],
    <BLANKLINE>
    <BLANKLINE>
           [[[ 6,  7,  8],
             [ 9, 10, 11],
             [12, 13, 14]],
    <BLANKLINE>
            [[21, 22, 23],
             [24, 25, 26],
             [27, 28, 29]]]])
    >>> epoch(np.arange(3), 10, 2).shape  # not enough samples: no epochs
    (0, 10)
    """
    # as_strided performs no bounds checking: a non-positive size/interval
    # could yield a view that reads memory outside of `a`, so reject them.
    if size < 1:
        raise ValueError('size must be >= 1, got {!r}'.format(size))
    if interval < 1:
        raise ValueError('interval must be >= 1, got {!r}'.format(interval))

    a = np.asarray(a)
    n_samples = a.shape[axis]
    # Clamp at zero: when `size` exceeds the number of samples there is no
    # complete epoch, and a negative dimension would make as_strided fail.
    n_epochs = max(0, (n_samples - size) // interval + 1)

    new_shape = list(a.shape)
    new_shape[axis] = size
    new_shape = (n_epochs,) + tuple(new_shape)
    # Stepping to the next epoch skips `interval` samples along `axis`; all
    # original dimensions keep their own strides (works for non-contiguous
    # inputs too, since we start from a.strides).
    new_strides = (a.strides[axis] * interval,) + a.strides
    return stride_tricks.as_strided(a, new_shape, new_strides,
                                    writeable=writeable)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment