1D optionally normalized, optionally weighted, optionally centered cross-correlation in PyTorch (+ SciPy fallback), with API like F.conv1d
import numpy as np

try:
    import torch
    import torch.nn.functional as F

    HAVE_TORCH = True
except ImportError:
    HAVE_TORCH = False


def normxcorr1d(
    template,
    x,
    weights=None,
    centered=True,
    normalized=True,
    padding="same",
    conv_engine="torch",
):
"""normxcorr1d: Normalized cross-correlation, optionally weighted | |
The API is like torch's F.conv1d, except I have accidentally | |
changed the position of input/weights -- template acts like weights, | |
and x acts like input. | |
Returns the cross-correlation of `template` and `x` at spatial lags | |
determined by `mode`. Useful for estimating the location of `template` | |
within `x`. | |
This might not be the most efficient implementation -- ideas welcome. | |
It uses a direct convolutional translation of the formula | |
corr = (E[XY] - EX EY) / sqrt(var X * var Y) | |
This also supports weights! In that case, the usual adaptation of | |
the above formula is made to the weighted case -- and all of the | |
normalizations are done per block in the same way. | |
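
    Concretely, with patch weights w (all ones when `weights` is None),
    the weighted moments used below are

        E_w[X] = sum(w * x) / sum(w)
        var_w X = E_w[X^2] - (E_w[X])^2

    which is what the conv1d calls against `weights`, `wt`, and `wt2`
    compute, each normalized by the per-patch weight mass N = sum(w).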

    Arguments
    ---------
    template : tensor, shape (num_templates, length)
        The reference template signal
    x : tensor, 1d shape (length,) or 2d shape (num_inputs, length)
        The signal in which to find `template`
    weights : tensor, shape (length,)
        Will use weighted means, variances, covariances if supplied.
    centered : bool
        If true, means will be subtracted (per weighted patch).
    normalized : bool
        If true, normalize by the variance (per weighted patch).
    padding : int or str, default "same"
        How far to look? The default "same" checks lags of up to half
        the length in each direction.
    conv_engine : string, one of "torch", "numpy"
        What library to use for computing cross-correlations.
        If numpy, falls back to the scipy correlate function.

    Returns
    -------
    corr : tensor, shape (num_inputs, num_templates, length_out)
    """
    if conv_engine == "torch":
        assert HAVE_TORCH
        conv1d = F.conv1d
        npx = torch
    elif conv_engine == "numpy":
        conv1d = scipy_conv1d
        npx = np
    else:
        raise ValueError(f"Unknown conv_engine {conv_engine}")

    x = npx.atleast_2d(x)
    num_templates, length = template.shape
    num_inputs, length_ = x.shape
    assert length == length_
    # generalize over weighted / unweighted case
    device_kw = {} if conv_engine == "numpy" else dict(device=x.device)
    ones = npx.ones((1, 1, length), dtype=x.dtype, **device_kw)
    no_weights = weights is None
    if no_weights:
        weights = ones
        wt = template[:, None, :]
        wt2 = npx.square(template[:, None, :])
    else:
        assert weights.shape == (length,)
        weights = weights[None, None]
        wt = template[:, None, :] * weights
        wt2 = npx.square(template)[:, None, :] * weights
    # conv1d valid rule:
    # (B,1,L),(O,1,L)->(B,O,L)

    # compute expectations
    # how many points in each window? seems necessary to normalize
    # for numerical stability.
    N = conv1d(ones, weights, padding=padding)
    if centered:
        Et = conv1d(ones, wt, padding=padding) / N
        Ex = conv1d(x[:, None, :], weights, padding=padding) / N

    # compute (weighted) covariance
    # important: the formula E[XY] - EX EY is well-suited here,
    # because the means are naturally subtracted correctly
    # patch-wise. you couldn't pre-subtract them!
    cov = conv1d(x[:, None, :], wt, padding=padding) / N
    if centered:
        cov -= Ex * Et
    # compute variances for denominator, using var X = E[X^2] - (EX)^2
    if normalized:
        var_template = conv1d(ones, wt2, padding=padding) / N
        var_x = conv1d(npx.square(x)[:, None, :], weights, padding=padding) / N
        if centered:
            var_template -= npx.square(Et)
            var_x -= npx.square(Ex)

    # now find the final normxcorr
    corr = cov  # renaming for clarity
    if normalized:
        corr /= npx.sqrt(var_x * var_template)
        # get rid of NaNs in zero-variance areas
        corr[~npx.isfinite(corr)] = 0

    return corr


def scipy_conv1d(input, weights, padding="valid"):
    """SciPy translation of torch F.conv1d"""
    from scipy.signal import correlate

    n, c_in, length = input.shape
    c_out, in_by_groups, kernel_size = weights.shape
    assert in_by_groups == c_in == 1

    if padding == "same":
        mode = "same"
        length_out = length
    elif padding == "valid":
        mode = "valid"
        # correlate's "valid" output has length - kernel_size + 1 samples,
        # for both even and odd kernel sizes
        length_out = length - kernel_size + 1
    elif isinstance(padding, int):
        mode = "valid"
        input = np.pad(input, [*[(0, 0)] * (input.ndim - 1), (padding, padding)])
        length_out = length - (kernel_size - 1) + 2 * padding
    else:
        raise ValueError(f"Unknown padding {padding}")

    output = np.zeros((n, c_out, length_out), dtype=input.dtype)
    for m in range(n):
        for c in range(c_out):
            output[m, c] = correlate(input[m, 0], weights[c, 0], mode=mode)
    return output
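

if __name__ == "__main__":
    # Usage sketch with toy data (illustrative, not part of the gist's API):
    # template and x must have the same length, and with padding="same" the
    # zero lag sits at index length // 2, so the argmax of the correlation
    # recovers the shift between the two signals.
    L = 101
    shift = 7
    base = np.exp(-0.5 * ((np.arange(L) - 50.0) / 5.0) ** 2).astype(np.float32)
    shifted = np.roll(base, shift)

    # SciPy fallback engine
    corr = normxcorr1d(base[None, :], shifted[None, :], conv_engine="numpy")
    print("numpy engine shift estimate:", int(corr[0, 0].argmax()) - L // 2)

    # torch engine, if available -- should agree with the numpy result
    if HAVE_TORCH:
        corr_t = normxcorr1d(
            torch.as_tensor(base)[None, :],
            torch.as_tensor(shifted)[None, :],
            conv_engine="torch",
        )
        print("torch engine shift estimate:", int(corr_t[0, 0].argmax()) - L // 2)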