(Batched) Sample Entropy in PyTorch for measuring the complexity of time series (see https://en.wikipedia.org/wiki/Sample_entropy)
import torch
from typing import Optional


def sample_entropy(
    x: torch.Tensor,
    m: int = 2,
    r: Optional[float] = None,
    stride: int = 1,
    subsample: int = 1,
):
    """Returns the (batched) sample entropy of the given time series.

    Sample entropy is a measure of the complexity of sequences that can be
    related to predictability. Sample entropy (SE) is defined as the negative
    logarithm of the following ratio:

        SE(X, m, r) = -ln(C(X, m+1, r) / C(X, m, r))

    where C(X, m, r) is the number of pairs of embedding vectors of length m
    in sequence X whose Chebyshev distance is less than r.

    Note that `0 <= SE <= -ln(2/[(T-m-1)(T-m)])`, where T is the sequence
    length.

    Based on
    Richman, J. S., & Moorman, J. R. (2000). Physiological time-series
    analysis using approximate entropy and sample entropy.

    Params
    ------
    x: (B,T) tensor
        Batched time series
    m: int
        Embedding length
    r: float
        Distance threshold; if None, computed as `0.2 * std(x)`
    stride: int
        Step between embedding vectors
    subsample: int
        Keep only every `subsample`-th embedding vector, reducing the
        number of pairwise comparisons

    Returns
    -------
    SE: (B,) tensor
        Sample entropy for each sequence
    """
    x = torch.atleast_2d(x)
    _, T = x.shape
    if r is None:
        r = torch.std(x) * 0.2

    def _num_close(elen: int):
        # Sliding-window embedding vectors of length `elen`
        unf = x.unfold(1, elen, stride)  # (B,N,elen)
        if subsample > 1:
            unf = unf[:, ::subsample, :]
        N = unf.shape[1]
        # Pairwise Chebyshev (max-norm) distances between embedding vectors
        d = torch.cdist(unf, unf, p=float("inf"))  # (B,N,N)
        idx = torch.triu_indices(N, N, 1)  # pairwise distances excl. diagonal
        C = (d[:, idx[0], idx[1]] < r).sum(-1)  # (B,)
        return C

    A = _num_close(m + 1)
    B = _num_close(m)
    # No similar pairs found at one of the lengths; fall back to the upper
    # bound instead of dividing by zero or taking log(0).
    mask = torch.logical_or(A == 0, B == 0)
    A[mask] = 2
    B[mask] = (T - m - 1) * (T - m)
    return -torch.log(A / B)
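For reference, here is a minimal usage sketch (it assumes the file above is saved as `complexity.py`, which is also what the tests below assume); exact values depend on the random draw, but white noise should score high and a pure tone low:

import math

import torch

import complexity

# White noise: hard to predict, so sample entropy is high (typically > 2).
noise = torch.rand(4, 1024)
print(complexity.sample_entropy(noise))

# Pure tone: highly regular, so sample entropy is close to zero.
tone = torch.sin(torch.linspace(0, 10 * math.pi, 2 ** 12))
print(complexity.sample_entropy(tone))

The tests that ship with the gist exercise exactly these regimes: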
import math

import torch

import complexity


def test_sample_entropy():
    # Uniform random noise: high complexity
    x = torch.rand(10, 1024)
    se = complexity.sample_entropy(x)
    assert se.mean() >= 2.0

    # Straight line: perfectly predictable, entropy near zero
    x = torch.arange(2 ** 12).float()
    se = complexity.sample_entropy(x).mean()
    assert abs(se) < 1e-3

    # Sine wave: highly regular
    x = torch.sin(torch.linspace(0, 10 * math.pi, 2 ** 12))
    se = complexity.sample_entropy(x).mean()
    assert se < 0.2
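The pairwise distance matrix grows quadratically with the number of embedding vectors, so long sequences get expensive. The `stride` and `subsample` parameters trade estimation accuracy for speed; a rough sketch (the speed-up factor and the closeness of the two estimates are expectations, not guarantees):

import torch

import complexity

x = torch.rand(2, 4096)
se_full = complexity.sample_entropy(x)
# Keeping every 4th embedding vector cuts the number of pairwise
# distances by roughly a factor of 16.
se_fast = complexity.sample_entropy(x, subsample=4)
# Both estimates should be of similar magnitude; the subsampled one
# is noisier since it counts far fewer pairs.
print(se_full, se_fast)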