Last active
June 24, 2021 16:53
-
-
Save cwindolf/15d4d0d8c744c8de030474cb7e1cbf75 to your computer and use it in GitHub Desktop.
Optionally FFT-based normalized cross-correlation in Python / NumPy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Translation of the Octave implementation (GPL), which is
# copyright 2014 Benjamin Eltzner. For that code, and license, see:
# hg.code.sf.net/p/octave/image/file/tip/inst/normxcorr2.m
import numpy as np | |
from scipy.signal import correlate, convolve | |
def normxcorr(template, x, mode="full", method="auto", assume_centered=False):
    """Normalized cross-correlation of `template` against `x`.

    Returns the cross-correlation of `template` and `x` at the spatial lags
    selected by `mode`. Each entry is divided by the root sum of squares of
    the zero-mean template and of the locally mean-subtracted, template-sized
    window of `x` that it overlaps. Useful for estimating the location of
    `template` within `x`.

    Parameters
    ----------
    template, x : array_like
        Inputs with the same number of dimensions. Integer and boolean
        inputs are promoted to float64 before any arithmetic.
    mode : {"full", "valid", "same"}
    method : {"auto", "direct", "fft"}
        Passed through to scipy.signal.correlate and scipy.signal.convolve;
        see their documentation. In "full" and "same" modes, windows that
        extend past the edge of `x` see the zero padding those functions use.
    assume_centered : bool
        If True, skip subtracting the input means, which avoids a copy. The
        normalization is only correct when `template` already has zero mean.

    Returns
    -------
    corr : np.ndarray
        Array with the same number of dimensions as `template` and `x`,
        holding the normalized cross-correlations for the lags selected by
        `mode`. Lags whose window of `x` has zero variance, or where the
        template is all zeros, are reported as 0, as is any other non-finite
        value.

    Raises
    ------
    ValueError
        If `template` and `x` have different numbers of dimensions.
    """
    template = np.asarray(template)
    x = np.asarray(x)
    if template.ndim != x.ndim:
        raise ValueError(
            "template and x must have the same number of dimensions, got "
            f"{template.ndim} and {x.ndim}"
        )

    # Promote integer/bool inputs to float64. Otherwise squaring a small
    # integer dtype (e.g. a uint8 image) silently wraps around, and the
    # in-place division below fails because a float result cannot be stored
    # in an integer array. Float and complex dtypes pass through unchanged.
    if not np.issubdtype(template.dtype, np.inexact):
        template = template.astype(np.float64)
    if not np.issubdtype(x.dtype, np.inexact):
        x = x.astype(np.float64)

    if not assume_centered:
        template = template - template.mean()
        x = x - x.mean()

    # Numerator: raw cross-correlation of x with the zero-mean template.
    corr = correlate(x, template, mode=mode, method=method)

    # Sums of x and x**2 over every template-sized window, aligned with
    # `corr` (convolving with an all-ones kernel is a box sum).
    window = np.ones_like(template)
    win_sumsq = convolve(np.square(x), window, mode=mode, method=method)
    win_sum = convolve(x, window, mode=mode, method=method)
    # template.size * (local variance of x). Roundoff, especially with the
    # FFT method, can make it slightly negative, so clamp it at zero.
    win_ssd = win_sumsq - np.square(win_sum) / template.size
    win_ssd[win_ssd < 0] = 0

    # Zero-variance windows (or an all-zero template) give 0/0 or c/0.
    # Silence those warnings; such lags are reported as 0 just below.
    with np.errstate(divide="ignore", invalid="ignore"):
        corr /= np.sqrt(win_ssd * np.square(template).sum())
    corr[~np.isfinite(corr)] = 0
    return corr
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from scipy.signal import correlation_lags | |
from normxcorr import normxcorr | |
rg = np.random.default_rng(0)

# -- 2D test example with pad mode SAME
# Odd-sized 2D test image, so the zero-lag entry sits exactly at the center.
D = 101
a = rg.standard_normal((D, D))
a_lags0 = correlation_lags(a.shape[0], a.shape[0], mode="same")
a_lags1 = correlation_lags(a.shape[1], a.shape[1], mode="same")
c_aa = normxcorr(a, a, mode="same")

# mode SAME preserves the input shape
assert c_aa.shape == a.shape
# correlations are finite (no NaN / inf)
assert np.isfinite(c_aa).all()
# |correlation| <= 1; allow a little slack for floating-point roundoff,
# which can push the zero-lag value just above 1
assert (np.abs(c_aa) <= 1 + 1e-10).all()
# autocorrelation at zero lag is 1
assert np.isclose(c_aa[D // 2, D // 2], 1.0)
# the maximum occurs at zero lag (with high probability for random data)
l0, l1 = np.unravel_index(c_aa.argmax(), c_aa.shape)
assert a_lags0[l0] == 0
assert a_lags1[l1] == 0
# -- 2D test example with pad mode FULL
# another odd-sized 2D test image
D = 101
a = rg.standard_normal((D, D))
a_lags0 = correlation_lags(a.shape[0], a.shape[0], mode="full")
a_lags1 = correlation_lags(a.shape[1], a.shape[1], mode="full")
c_aa = normxcorr(a, a, mode="full")

# mode FULL covers every lag: 2 * D - 1 entries along each axis
# (valid for any D, not only odd D)
c_aa_D = 2 * D - 1
assert c_aa.shape == (c_aa_D, c_aa_D)
# correlations are finite (no NaN / inf)
assert np.isfinite(c_aa).all()
# |correlation| <= 1; allow a little slack for floating-point roundoff
assert (np.abs(c_aa) <= 1 + 1e-10).all()
# autocorrelation at zero lag (the center of the FULL output) is 1
assert np.isclose(c_aa[c_aa_D // 2, c_aa_D // 2], 1.0)
# the maximum occurs at zero lag (with high probability for random data)
l0, l1 = np.unravel_index(c_aa.argmax(), c_aa.shape)
assert a_lags0[l0] == 0
assert a_lags1[l1] == 0
# -- FFT and direct methods must agree for every padding mode
# A small image keeps the direct (non-FFT) computation cheap.
a = rg.standard_normal((11, 11))
for pad_mode in ("same", "full", "valid"):
    c_fft = normxcorr(a, a, mode=pad_mode, method="fft")
    c_direct = normxcorr(a, a, mode=pad_mode, method="direct")
    assert np.allclose(c_fft, c_direct)
# -- 1D test example with pad mode SAME
# Odd-length 1D test signal, so the zero-lag entry sits exactly at the center.
D = 101
a = rg.standard_normal(D)
a_lags = correlation_lags(a.shape[0], a.shape[0], mode="same")
c_aa = normxcorr(a, a, mode="same")

# mode SAME preserves the input shape
assert c_aa.shape == a.shape
# correlations are finite (no NaN / inf)
assert np.isfinite(c_aa).all()
# |correlation| <= 1; allow a little slack for floating-point roundoff
assert (np.abs(c_aa) <= 1 + 1e-10).all()
# autocorrelation at zero lag is 1
assert np.isclose(c_aa[D // 2], 1.0)
# the maximum occurs at zero lag (with high probability for random data)
assert a_lags[c_aa.argmax()] == 0
# -- 4D test example with pad mode SAME
# Odd-sized 4D test image, so the zero-lag entry sits exactly at the center.
D = 11
a = rg.standard_normal([D] * 4)
a_lags = [correlation_lags(n, n, mode="same") for n in a.shape]
c_aa = normxcorr(a, a, mode="same")

# mode SAME preserves the input shape
assert c_aa.shape == a.shape
# correlations are finite (no NaN / inf)
assert np.isfinite(c_aa).all()
# |correlation| <= 1; allow a little slack for floating-point roundoff
assert (np.abs(c_aa) <= 1 + 1e-10).all()
# autocorrelation at zero lag (center along every axis) is 1
assert np.isclose(c_aa[(D // 2,) * a.ndim], 1.0)
# the maximum occurs at zero lag (with high probability for random data)
inds = np.unravel_index(c_aa.argmax(), c_aa.shape)
assert all(lags[ind] == 0 for lags, ind in zip(a_lags, inds))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment