cumdiv() and cumdif(): reciprocals of the torch cumulative functions cumsum() and cumprod()
"""@xvdp | |
reciprocals for torch.cumsum and torch.cumprod | |
I noticed that torch has cumsum and cumprod but not their reciprocals | |
even thought cumdif and cumdiv have meanings in analysis and probability | |
and are routinely used. | |
Are these interpretations correct? | |
> cumsum can be thought of as a discrete integral | |
> cumdif as discrete derivative | |
> cumprod is useful as the joint probability of a sequence of events | |
> cumdiv then is the marginal probability along that sequence | |
cumprod and cumdiv could also be expressed as exp(cumsum(log(x))) and exp(cumdif(log(x))) | |
see test_explog_interpretation() | |
""" | |
from typing import Optional, Tuple
import torch
from torch.types import _dtype
import torch.nn.functional as F
from torch import Tensor
# pylint: disable=no-member
# pylint: disable=invalid-name
# pylint: disable=suppressed-message
def cumdiv(x: Tensor,
           dim: int = 0,
           *,
           keepsize: bool = True,
           dtype: Optional[_dtype] = None,
           **kwargs) -> Tensor:
    """ inverse of cumprod; with keepsize=True the output is the same size as the input
    similar to exp(cumdif(log(x)))
    can be thought of as undoing the chain rule of probability, i.e. marginalizing probabilities
    Args
        x           Tensor
        axis | dim  int [0]
        keepsize    bool [True]
            True:  x.shape == out.shape; in this case cumdiv is the reciprocal of cumprod
            False: results in a common use of cumdiv: x[1:]/x[:-1], reducing size by 1 along dim
    """
    denom, front_slice = _cuminv(x, dim, keepsize=keepsize, prod=1, **kwargs)
    out = x[front_slice].div(denom)
    return out if dtype is None else out.to(dtype)
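
# minimal usage sketch (values are illustrative, not part of the original gist):
#   x = torch.tensor([2., 3., 4.])
#   p = torch.cumprod(x, 0)           # tensor([ 2.,  6., 24.])
#   cumdiv(p, 0)                      # tensor([2., 3., 4.]), recovers x
#   cumdiv(p, 0, keepsize=False)      # tensor([3., 4.]), i.e. p[1:] / p[:-1]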
def cumdif(x: Tensor,
           dim: int = 0,
           *,
           keepsize: bool = True,
           dtype: Optional[_dtype] = None,
           **kwargs) -> Tensor:
    """ inverse of cumsum; with keepsize=True the output is the same size as the input
    can be thought of as a discrete derivative
    Args
        x           Tensor
        axis | dim  int [0]
        keepsize    bool [True]
            True:  x.shape == out.shape; in this case cumdif is the reciprocal of cumsum
            False: results in a common use of cumdif: x[1:] - x[:-1], reducing size by 1 along dim
    """
    prev, front_slice = _cuminv(x, dim, keepsize=keepsize, prod=0, **kwargs)
    out = x[front_slice].sub(prev)
    return out if dtype is None else out.to(dtype)
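
# minimal usage sketch (values are illustrative, not part of the original gist):
#   x = torch.tensor([2., 3., 4.])
#   s = torch.cumsum(x, 0)            # tensor([2., 5., 9.])
#   cumdif(s, 0)                      # tensor([2., 3., 4.]), recovers x
#   cumdif(s, 0, keepsize=False)      # tensor([3., 4.]), i.e. s[1:] - s[:-1]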
def _cuminv(x: Tensor,
            dim: int = 0,
            *,
            keepsize: bool = True,
            prod: int = 1,
            **kwargs) -> Tuple[Tensor, tuple]:
    """ shared base for the inverse functions cumdif and cumdiv
    returns (other, front_slice): the shifted tensor to divide or subtract by,
    and the slice of x to apply it to
    prod is the identity element used to pad the front: 1 for cumdiv, 0 for cumdif
    """
    # accept 'axis' as an alias for 'dim', matching the torch.cumsum/cumprod overloads
    axis = kwargs.get('axis')
    dim = dim if axis is None else axis
    dim = dim % x.ndim  # normalize negative dims so the pad index below stays in range
    _back_slice = [slice(None)] * x.ndim
    _back_slice[dim] = slice(0, -1)
    front_slice = [slice(None)] * x.ndim
    other = x[tuple(_back_slice)]
    if keepsize:
        # F.pad lists pads last dim first: (left_{n-1}, right_{n-1}, ..., left_0, right_0)
        _pads = [0] * (x.ndim * 2)
        _pads[(x.ndim - 1 - dim) * 2] = 1  # one identity element at the front of dim
        other = F.pad(other, _pads, value=prod)
    else:
        front_slice[dim] = slice(1, None)
    return other, tuple(front_slice)
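
# worked example of the pad bookkeeping above (illustrative, not from the original gist):
# for x.ndim == 3 and dim == 1, _pads == [0, 0, 1, 0, 0, 0]; since F.pad reads pads
# for the last dim first, entry (3 - 1 - 1) * 2 == 2 pads one identity element
# (1 for cumdiv, 0 for cumdif) at the front of dim 1.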
###
# tests that validate, for 1, 2, and 3 dimensions, that these functions
# are reciprocals to the torch functions
#
def test_all():
    test_cumdiv()
    x = test_cumdif()
    test_axis_dim(x)
    test_explog_interpretation(x)
def test_cumdiv(device=None):
    """ test that cumdiv is the reciprocal of cumprod
    """
    x = torch.linspace(0.1, 1, 10).to(device=device)
    assert torch.allclose(cumdiv(torch.cumprod(x, 0), 0), x)
    x = torch.stack((x, x.flip(0)))
    for i in range(x.ndim):
        assert torch.allclose(cumdiv(torch.cumprod(x, i), i), x)
    x = torch.stack((x, x * 2))
    for i in range(x.ndim):
        assert torch.allclose(cumdiv(torch.cumprod(x, i), i), x)
    # check keepsize=False; ~ x[1:] / x[:-1]
    for i in range(x.ndim):
        y = torch.cumprod(x, i)
        z = cumdiv(y, i, keepsize=False)
        _slice = tuple([slice(None)] * i + [slice(1, None)])
        assert torch.allclose(z, x[_slice])
    return x
def test_cumdif(device=None):
    """ test that cumdif is the reciprocal of cumsum
    """
    x = torch.linspace(0.1, 1, 10).to(device=device)
    assert torch.allclose(cumdif(torch.cumsum(x, 0), 0), x)
    x = torch.stack((x, x.flip(0)))
    for i in range(x.ndim):
        assert torch.allclose(cumdif(torch.cumsum(x, i), i), x)
    x = torch.stack((x, x * 2))
    for i in range(x.ndim):
        assert torch.allclose(cumdif(torch.cumsum(x, i), i), x)
    # check keepsize=False; ~ x[1:] - x[:-1]
    for i in range(x.ndim):
        y = torch.cumsum(x, i)
        z = cumdif(y, i, keepsize=False)
        _slice = tuple([slice(None)] * i + [slice(1, None)])
        assert torch.allclose(z, x[_slice])
    return x
def test_axis_dim(x):
    """ test the overload hack matching the torch cumsum/cumprod overloads:
    axis == dim; in this implementation axis overrides dim if both are present
    """
    for i in range(x.ndim):
        assert torch.allclose(cumdif(x, dim=i), cumdif(x, axis=i))
def test_explog_interpretation(x):
    """ prod(x) == exp(sum(log(x))) and cumdiv(x) == exp(cumdif(log(x))) """
    for i in range(x.ndim):
        assert torch.allclose(torch.cumprod(x, i), torch.exp(torch.cumsum(torch.log(x), i)))
        assert torch.allclose(cumdiv(x, i), torch.exp(cumdif(torch.log(x), i)))
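
# not part of the original gist: a convenience entry point to run all tests as a script
if __name__ == "__main__":
    test_all()
    print("all tests passed")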