Created
January 27, 2021 12:27
-
-
Save thomasaarholt/3d4b7b27a63bbb51f2138414e882381f to your computer and use it in GitHub Desktop.
2D Convolution Using FFT and SciPy for Even- and Odd-Sized Arrays
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from scipy.signal import convolve2d, gaussian | |
from scipy.misc import face | |
import matplotlib.pyplot as plt | |
def convolve2d_fft(arr1, arr2):
    """Circularly convolve two real 2-D arrays using the FFT.

    The result is intended to match
    ``scipy.signal.convolve2d(arr1, arr2, mode='same', boundary='wrap')``
    for both even- and odd-sized inputs.

    Parameters
    ----------
    arr1 : array_like, shape (s0, s1)
        Real-valued input image.
    arr2 : array_like, shape (k0, k1)
        Real-valued kernel with ``k0 <= s0`` and ``k1 <= s1``. A kernel
        smaller than ``arr1`` is zero-padded to ``arr1``'s shape.

    Returns
    -------
    numpy.ndarray, shape (s0, s1)
        The convolution, with the same shape as ``arr1``.

    Raises
    ------
    ValueError
        If either input is not 2-D, or the kernel is larger than ``arr1``
        along any axis.
    """
    arr1 = np.asarray(arr1)
    arr2 = np.asarray(arr2)
    if arr1.ndim != 2 or arr2.ndim != 2:
        raise ValueError("convolve2d_fft expects two 2-D arrays")
    if any(k > s for k, s in zip(arr2.shape, arr1.shape)):
        raise ValueError(
            f"kernel shape {arr2.shape} exceeds image shape {arr1.shape}")

    shape = arr1.shape
    # s=shape zero-pads a smaller kernel so the two spectra line up;
    # for a kernel of the same shape as arr1 it changes nothing.
    spectrum = np.fft.rfft2(arr1) * np.fft.rfft2(arr2, s=shape)
    # s=shape is needed on the inverse too: rfft2 discards whether the last
    # axis had odd length, so odd sizes would otherwise come back one short.
    conv = np.fft.irfft2(spectrum, s=shape)

    # The FFT gives a plain circular convolution with its origin at index 0.
    # convolve2d's 'same' output instead starts (k - 1) // 2 samples into the
    # 'full' output (k = kernel length per axis), so roll each axis back by
    # that offset. Parenthesised explicitly: unary minus binds tighter
    # than //.
    shift = tuple(-((k - 1) // 2) for k in arr2.shape)
    return np.roll(conv, shift, axis=(0, 1))
def gaussian_kernel(shape, std, normalised=False):
    """Return a 2-D Gaussian kernel of the given shape, centred in the array.

    The kernel is the outer product of two 1-D Gaussian windows
    (``scipy.signal.windows.gaussian``) sharing the same standard deviation.
    Without normalisation the peak is 1 when both sizes are odd; for even
    sizes the centre falls between samples.

    Parameters
    ----------
    shape : tuple of int
        ``(rows, cols)`` of the kernel; it need not be square.
    std : float
        Standard deviation of the Gaussian, in samples. Must be positive.
    normalised : bool, optional
        If True, scale the kernel so its elements sum to exactly 1.

    Returns
    -------
    numpy.ndarray of float with shape ``shape``.

    Raises
    ------
    ValueError
        If ``std`` is not positive.
    """
    # Local import: scipy.signal.gaussian was removed in SciPy 1.13; the
    # windows submodule is the supported home of the window functions.
    from scipy.signal.windows import gaussian as gaussian_window

    if std <= 0:
        raise ValueError(f"std must be positive, got {std!r}")
    rows, cols = shape
    kernel = np.outer(gaussian_window(rows, std), gaussian_window(cols, std))
    if normalised:
        # Divide by the discrete sum rather than the continuous integral
        # 2*pi*std**2: for small std (e.g. 0.5) the sampled kernel's sum
        # differs from the integral by a few percent, so the old factor did
        # not actually give unit volume.
        kernel /= kernel.sum()
    return kernel
# Demo: compare scipy's wrap-around convolve2d with convolve2d_fft on an
# even-sized image (top row) and an odd-sized image (bottom row).
fig, AX = plt.subplots(ncols=3, nrows=2)
(ax1, ax2, ax3, ax4, ax5, ax6) = AX.flatten()

### Making an even-shaped array
A = face(gray=True)[::16, ::16]
# Trim at most one row/column so both dimensions are guaranteed even.
# Slicing off [:-1, :-1] here would instead make a 48 x 64 image odd.
A = A[:A.shape[0] // 2 * 2, :A.shape[1] // 2 * 2]
B = gaussian_kernel(A.shape, 0.5, True)
C = convolve2d(A, B, mode='same', boundary='wrap')
D_even = convolve2d_fft(A, B)
ax1.imshow(C)
ax1.set(title='Scipy')
ax2.imshow(D_even)
ax2.set(title='FFT')
# imshow autoscales, so a colorbar is needed to see how small the
# difference actually is (it should be at floating-point round-off level).
im = ax3.imshow(C - D_even)
fig.colorbar(im, ax=ax3)
ax3.set(title='Even-sized\ndifference')
print(f"even {A.shape}: max |scipy - fft| = {np.abs(C - D_even).max():.3e}")

### Making an odd-shaped array
# Dropping one row and column from the even array makes both sizes odd.
A = A[:-1, :-1]
B = gaussian_kernel(A.shape, 0.5, True)
C = convolve2d(A, B, mode='same', boundary='wrap')
D_odd = convolve2d_fft(A, B)
ax4.imshow(C)
ax4.set(title='Scipy')
ax5.imshow(D_odd)
ax5.set(title='FFT')
im = ax6.imshow(C - D_odd)
fig.colorbar(im, ax=ax6)
ax6.set(title='Odd-sized\ndifference')
print(f"odd {A.shape}: max |scipy - fft| = {np.abs(C - D_odd).max():.3e}")

fig.tight_layout()
# Without this, running the file as a script never displays the figure.
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment