Last active
August 19, 2021 19:20
-
-
Save spezold/c90b310de7f3245feb19a84f35ed3dc5 to your computer and use it in GitHub Desktop.
Return both the values inside and the values outside of a rolling window over a 1D PyTorch tensor, each as a 2D tensor
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Tuple | |
import torch | |
from torch import Tensor | |
# The straightforward solution
def rolling_window_inside_and_outside(t: Tensor, size: int, stride: int = 1) -> Tuple[Tensor, Tensor]:
    """
    Given a 1D tensor, provide both the values inside the rolling window and outside the rolling window for each window
    position with the given window size and stride.

    The values outside the window are taken cyclically: they start right after the window and wrap around to the
    beginning of ``t``. Both results are views into a single padded copy of ``t``.

    :param t: 1D tensor to be rolled over
    :param size: window size, 0 <= size <= len(t)
    :param stride: step size for rolling, must be positive
    :return: values inside the rolling window, values outside the rolling window (n-th window position in n-th row);
        their shapes are (n, size) and (n, len(t) - size) with n = (len(t) - size) // stride + 1
    :raises ValueError: if ``t`` is not 1D, ``size`` is out of range, or ``stride`` is not positive
    """
    # Explicit checks instead of asserts: asserts vanish under ``python -O``, after which e.g. a square 2D input
    # would silently produce garbage (unfold along dim 0, then split along the column axis)
    if t.ndim != 1:
        raise ValueError(f"Expected a 1D tensor, got {t.ndim}D")
    len_t = len(t)
    if not 0 <= size <= len_t:
        raise ValueError(f"Window size must be in [0, {len_t}], got {size}")
    if stride < 1:
        raise ValueError(f"Stride must be positive, got {stride}")
    # Pad with the first len_t - size values, so that the len_t values starting at any valid window position are
    # the inner values followed by the (wrapped-around) outer values
    padded = torch.cat((t, t[:len_t - size]))
    # Unfold completely: the n-th row holds the len_t values starting at the n-th window position
    unfolded = padded.unfold(0, len_t, stride)
    # Split into inside and outside part
    inside, outside = unfolded.split((size, len_t - size), dim=1)
    return inside, outside
# The tedious solution: by unfolding the inner and outer part separately, we might perhaps save memory
def rolling_window_inside_and_outside_2(t: Tensor, size: int, stride: int = 1) -> Tuple[Tensor, Tensor]:
    """
    Given a 1D tensor, provide both the values inside the rolling window and outside the rolling window for each window
    position with the given window size and stride.

    The values outside the window are taken cyclically: they start right after the window and wrap around to the
    beginning of ``t``.

    :param t: 1D tensor to be rolled over
    :param size: window size, 0 <= size <= len(t)
    :param stride: step size for rolling, must be positive
    :return: values inside the rolling window, values outside the rolling window (n-th window position in n-th row);
        their shapes are (n, size) and (n, len(t) - size) with n = (len(t) - size) // stride + 1
    :raises ValueError: if ``t`` is not 1D, ``size`` is out of range, or ``stride`` is not positive
    """
    # Explicit checks instead of asserts: asserts vanish under ``python -O``, after which invalid input would
    # silently produce garbage or fail with an opaque error deep inside torch
    if t.ndim != 1:
        raise ValueError(f"Expected a 1D tensor, got {t.ndim}D")
    len_t = len(t)
    if not 0 <= size <= len_t:
        raise ValueError(f"Window size must be in [0, {len_t}], got {size}")
    if stride < 1:
        raise ValueError(f"Stride must be positive, got {stride}")
    # Values inside window are straightforward: we just need to unfold
    inside = t.unfold(0, size, stride)
    # Values outside window need to be repeated or cropped, depending on window size
    o = t.roll(-size)  # Outer values start at value after window
    size_o = len_t - size  # Outer window size is length minus inner window size
    # The last window position is at most len_t - size == size_o, so the outer windows never reach beyond index
    # size_o + size_o - 1: the necessary length of values to unfold is twice the outer window size
    len_o = 2 * size_o
    # Bring to required length (same as o = torch.cat((o, o[:len_o - len_t])) if (len_o > len_t) else o[:len_o])
    o = torch.cat((o[:len_o], o[:max(len_o - len_t, 0)]))
    outside = o.unfold(0, size_o, stride)
    return inside, outside
if __name__ == "__main__":
    # Demo and self-check: print both parts for several lengths, window sizes and strides
    for t in [torch.arange(7), torch.arange(8)]:
        for size in [0, 3, 4, 5, len(t)]:
            for stride in [1, 2, 3]:
                i, o = rolling_window_inside_and_outside(t, size=size, stride=stride)
                print()
                print()
                print(f"Length {len(t)}, window size {size}, stride {stride}:")
                print()
                print("Inside:")
                print(i)
                print()
                print("Outside:")
                print(o)
                # Concatenating the inner and outer values for each position should again give all values of t;
                # compare sorted values (multisets) rather than sets so that lost or duplicated values are caught
                expected = sorted(t.tolist())
                for row in torch.cat((i, o), dim=1):
                    assert sorted(row.tolist()) == expected
                # The separately-unfolding variant must produce exactly the same result
                i_2, o_2 = rolling_window_inside_and_outside_2(t, size=size, stride=stride)
                assert torch.equal(i, i_2) and torch.equal(o, o_2)
# This prints for example:
#
# Length 8, window size 5, stride 2:
#
# Inside:
# tensor([[0, 1, 2, 3, 4],
#         [2, 3, 4, 5, 6]])
#
# Outside:
# tensor([[5, 6, 7],
#         [7, 0, 1]])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment