Created September 14, 2017 21:10 by npyoung (gist 8c811725a37a85a1cf51cde4079ce794)
Triggered selection from a NumPy array
import numpy as np


def pst(x, triggers, window, axis=-1):
    """Extract windows along an axis, aligned to a list of trigger indices.

    Args:
        x: Array from which windows should be extracted.
        triggers: Indices along `axis` to which the extracted windows are aligned.
        window: (start, stop) tuple of sample offsets relative to each trigger.
            The half-open range [trigger + start, trigger + stop) is extracted,
            so (-10, 20) takes the 10 samples before the trigger and the 20
            samples starting at it. Samples outside the array are filled with NaN.
        axis: Axis along which to extract windows.

    Returns:
        An array in which the given axis is replaced by two axes of shape
        (window_length, n_triggers), where window_length = stop - start.
        E.g. if the input array has shape (50, 60, 70), there are 5 triggers,
        the window is 3 samples long, and axis=1, then the output has shape
        (50, 3, 5, 70). This is useful for, say, plotting a peri-stimulus
        time histogram (PSTH) of your data.
    """
    # Work on a float copy (so NaN padding is valid) with the target axis last.
    x = np.moveaxis(np.asarray(x, dtype=float), axis, -1)
    winlen = window[1] - window[0]
    result = np.empty(x.shape[:-1] + (winlen, len(triggers)))
    # Pad only the last axis with NaN so windows near either edge stay in bounds.
    shift = max(0, -window[0])
    padding = [(0, 0)] * (x.ndim - 1) + [(shift, max(0, window[1]))]
    x = np.pad(x, padding, mode='constant', constant_values=np.nan)
    for idx, trigger in enumerate(triggers):
        # Copy the window for this trigger into column idx of the result.
        result[..., :, idx] = x[..., shift + trigger + window[0]:shift + trigger + window[1]]
    # Move the (window, trigger) axes back to where the original axis was.
    axabs = axis % x.ndim
    return np.moveaxis(result, (-2, -1), (axabs, axabs + 1))
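The function copies one window per trigger in a Python loop. A loop-free sketch of the same extraction is possible with NumPy advanced indexing: build a (window_length, n_triggers) grid of indices into the NaN-padded last axis and gather everything at once. `pst_vectorized` is a hypothetical name, not part of the original gist, and the sketch assumes the same (start, stop) offset semantics that the code above uses for `window`.

```python
import numpy as np


def pst_vectorized(x, triggers, window, axis=-1):
    """Hypothetical loop-free variant of pst() with the same (start, stop) offsets."""
    # Float copy with the target axis last, so it can be NaN-padded.
    x = np.moveaxis(np.asarray(x, dtype=float), axis, -1)
    start, stop = window
    before, after = max(0, -start), max(0, stop)
    pad = [(0, 0)] * (x.ndim - 1) + [(before, after)]
    x = np.pad(x, pad, mode="constant", constant_values=np.nan)

    # Index grid: one row per offset in the window, one column per trigger.
    offsets = np.arange(start, stop)[:, None]          # (winlen, 1)
    trig = np.asarray(triggers, dtype=int)[None, :]    # (1, n_triggers)
    idx = before + trig + offsets                      # (winlen, n_triggers)

    # Advanced indexing replaces the last axis with the index grid's shape.
    result = x[..., idx]                               # (..., winlen, n_triggers)
    axabs = axis % x.ndim
    return np.moveaxis(result, (-2, -1), (axabs, axabs + 1))


# Usage: a 1-D signal with triggers near both edges.
signal = np.arange(10)
out = pst_vectorized(signal, [0, 5, 9], (-2, 3))
print(out.shape)  # (5, 3): 5 samples per window, 3 triggers

# Trigger-aligned average; nanmean ignores the NaN padding at the edges.
avg = np.nanmean(out, axis=-1)

# Shape check for the multi-dimensional case described in the docstring.
print(pst_vectorized(np.zeros((50, 60, 70)), [5, 10, 20, 30, 40], (-1, 2), axis=1).shape)
# (50, 3, 5, 70)
```

The gather materializes an index array the size of one output slice, which is usually small next to the result itself; the loop version avoids that array but pays Python overhead per trigger.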