https://github.com/chenhsiu/remagic/blob/master/convolution.ipynb
The overlap-add and overlap-save methods show how the FFT can be used to accelerate convolution. FFT-based convolution is generally much faster than direct convolution, which evaluates the definition as a sum of pair-wise products. But how much performance can we actually gain from FFT-accelerated convolution? Let's run some experiments on 2D convolution and look at the results.
Convolution Theorem
First of all, let's verify the convolution theorem numerically with Python code.
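For reference, the discrete form of the theorem that the following cells check can be written as below. This is a sketch assuming length-$N$ sequences $x$ and $h$ with DFTs $X$ and $H$; the symbols $y$ and $Y$ for the result are introduced here only for illustration.

```latex
y[n] = \sum_{m=0}^{N-1} x[m]\, h[(n-m) \bmod N]
\quad\Longleftrightarrow\quad
Y[k] = X[k]\, H[k], \qquad k = 0, \dots, N-1
```

The convolution on the left is circular, which is why the cells wrap the signal around with np.hstack before calling np.convolve.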
From time domain
In [32]:
import numpy as np
from scipy.fftpack import fft, ifft, fftn, ifftn
x = [0, 0, 3, -1, 0]
h = [1, 1, 2, 1, 1]
X = fft(x)
H = fft(h)
r1 = np.real(ifft(X*H))
r2 = np.convolve(np.hstack((x[1:],x)), h, mode='valid')
print('%s == %s ? %s' % (r1, r2, 'Yes' if np.allclose(r1, r2) else 'No'))
From frequency domain
In [31]:
# frequency domain
X = np.array([1.+1.j, 2.+3.j, 1.+2.j, 3.+2.j])
H = np.array([2.+3.j, 1.+1.j, 3.+3.j, 4.+5.j])
Y = X * H
r1 = ifft(Y)
print('r1 = %s' % r1)
x = ifft(X)
h = ifft(H)
x = np.hstack((x[1:], x)) # np.convolve is linear by default; prepending a wrapped copy makes it circular
r2 = np.convolve(x,h, mode='valid')
print('r2 = %s' % r2)
print('r1 == r2 ? %s' % ('Yes' if np.allclose(r1, r2) else 'No'))
Time domain with different size signal
In [16]:
x = [7, 2, 3, -1, 0, -3, 5, 6]
h = [1, 2, -1]
X = fft(x)
H = fft(h, len(x))
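r1 = np.real(ifft(X * H)) # for real inputs the imaginary part is only rounding noise
print('r1 = %s' % r1)
r2 = np.convolve(np.hstack((x[1:],x)),h,mode='valid') # circular convolution via wrap-around
r2 = r2[len(r2) - len(x):] # keep the last len(x) samples, which line up with r1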
print('r2 = %s' % r2)
print('r1 == r2 ? %s' % ('Yes' if np.allclose(r1, r2) else 'No'))
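The cells above all compare circular convolutions. For linear convolution, which is what the fast convolution routines discussed next compute, a common approach is to zero-pad both inputs to length len(x) + len(h) - 1 before transforming, so that the wrap-around falls on zeros. A minimal sketch, assuming real 1D inputs; the helper name fft_linear_convolve is made up for illustration:

```python
import numpy as np

def fft_linear_convolve(x, h):
    """Linear convolution of two 1D real sequences via zero-padded FFTs."""
    x = np.asarray(x, dtype=float)
    h = np.asarray(h, dtype=float)
    n = len(x) + len(h) - 1          # length of the full linear convolution
    X = np.fft.rfft(x, n)            # rfft zero-pads each input to length n
    H = np.fft.rfft(h, n)
    return np.fft.irfft(X * H, n)    # back to the time domain, length n

x = [7, 2, 3, -1, 0, -3, 5, 6]
h = [1, 2, -1]
r_fft = fft_linear_convolve(x, h)
r_direct = np.convolve(x, h)         # mode='full' is the default
print('r_fft == r_direct ? %s' % ('Yes' if np.allclose(r_fft, r_direct) else 'No'))
```

Because the padded length covers the whole linear result, no wrap-around trick is needed on the np.convolve side.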
Fast Convolution with FFT
Now let's see how the FFT can speed up convolution.
The convolve2d function in scipy.signal uses direct pair-wise multiplication (see its source code). scipy.signal also provides fftconvolve, which uses the FFT to compute the convolution (see its source code). From its documentation:
This is generally much faster than convolve for large arrays (n > ~500), but can be slower when only a few output values are needed, and can only output float arrays (int or object array inputs will be cast to float).
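Since the crossover point depends on the array sizes, it can help to ask SciPy which approach it would pick. The sketch below uses scipy.signal.choose_conv_method and scipy.signal.convolve with method='auto'; the random test arrays and their sizes are arbitrary assumptions, and the chosen method may differ between SciPy versions and machines.

```python
import numpy as np
from scipy import signal

rng = np.random.default_rng(0)
image = rng.standard_normal((128, 128))       # arbitrary stand-in image
small_kernel = rng.standard_normal((3, 3))
large_kernel = rng.standard_normal((32, 32))

# choose_conv_method returns the string 'direct' or 'fft'
choices = {}
for name, kernel in [('3x3', small_kernel), ('32x32', large_kernel)]:
    choices[name] = signal.choose_conv_method(image, kernel, mode='same')
    print('%s kernel -> %s' % (name, choices[name]))

# signal.convolve with method='auto' makes this choice internally
out_auto = signal.convolve(image, large_kernel, mode='same', method='auto')
out_fft = signal.fftconvolve(image, large_kernel, mode='same')
print('auto == fft ? %s' % ('Yes' if np.allclose(out_auto, out_fft) else 'No'))
```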
For overlapadd2, we found a 2D overlap-add FFT implementation on GitHub. Its description reads:
Fast two-dimensional linear convolution via the overlap-add method. The overlap-add method is well-suited to convolving a very large array, Amat, with a much smaller filter array, Hmat by breaking the large convolution into many smaller L-sized sub-convolutions, and evaluating these using the FFT. The computational savings over the straightforward two-dimensional convolution via, say, scipy.signal.convolve2d, can be substantial for large Amat and/or Hmat.
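The idea in that description can be sketched in one dimension. This is not the overlapadd2 code itself, only a minimal illustration under stated assumptions: the function name overlap_add_1d and the default block length L=64 are made up here, and the inputs are assumed to be real 1D sequences.

```python
import numpy as np

def overlap_add_1d(x, h, L=64):
    """Linear convolution of a long signal x with a short filter h,
    computed block by block (block length L) with FFTs."""
    x = np.asarray(x, dtype=float)
    h = np.asarray(h, dtype=float)
    M = len(h)
    nfft = L + M - 1                       # linear convolution length of one block
    H = np.fft.rfft(h, nfft)               # filter spectrum, computed once
    y = np.zeros(len(x) + M - 1)
    for start in range(0, len(x), L):
        block = x[start:start + L]         # the last block may be shorter than L
        yb = np.fft.irfft(np.fft.rfft(block, nfft) * H, nfft)
        n_valid = len(block) + M - 1       # meaningful part of this block's result
        y[start:start + n_valid] += yb[:n_valid]   # overlapping tails add up
    return y

rng = np.random.default_rng(0)
x = rng.standard_normal(1000)
h = rng.standard_normal(17)
y_ola = overlap_add_1d(x, h, L=64)
print('overlap-add == np.convolve ? %s' % ('Yes' if np.allclose(y_ola, np.convolve(x, h)) else 'No'))
```

Each block's result is M - 1 samples longer than the block, so consecutive results overlap by M - 1 samples and are summed, which is where the method gets its name.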
The performance comparison on my notebook computer is shown below:
method | convolve2d | fftconvolve | overlapadd2 |
---|---|---|---|
time | 3040 ms | 30.1 ms | 94.8 ms |
We can see that FFT-based convolution is much faster than direct convolution here, with speedups of roughly 30X (overlapadd2) to 100X (fftconvolve). The result surprises me a bit, because fftconvolve is still faster than overlapadd2. The overlapadd2 approach looks ideal on paper, but the user still needs to tune the block size L to get the best performance. Perhaps overlapadd2 only shows a real benefit when the input matrix is too big to fit into memory and has to be split into sub-convolutions.
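A comparison of this kind can also be scripted outside IPython with timeit. The sketch below is an assumption-laden stand-in for the experiment, not a reproduction of it: it uses a small random array instead of the ascent image, an odd-sized Gaussian kernel, and scipy.signal.oaconvolve (SciPy's own overlap-add routine, available in SciPy 1.4 and later) in place of overlapadd2. Absolute timings will differ from machine to machine.

```python
import timeit
import numpy as np
from scipy import signal

rng = np.random.default_rng(0)
A = rng.standard_normal((128, 128))            # stand-in image, smaller than ascent()
g = signal.windows.gaussian(15, 2)
H = np.outer(g, g)                             # 15x15 separable Gaussian kernel

methods = {
    'convolve2d': lambda: signal.convolve2d(A, H, mode='same'),
    'fftconvolve': lambda: signal.fftconvolve(A, H, mode='same'),
    'oaconvolve': lambda: signal.oaconvolve(A, H, mode='same'),
}

reference = methods['convolve2d']()
timings = {}
agrees = {}
for name, fn in methods.items():
    agrees[name] = np.allclose(fn(), reference)              # same result within rounding
    timings[name] = min(timeit.repeat(fn, number=1, repeat=3))  # best of 3 runs, in seconds
    print('%-12s %8.2f ms  matches convolve2d: %s' % (name, timings[name] * 1e3, agrees[name]))
```

Taking the minimum over a few repeats reduces the influence of background load on the measurement.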
One thing to note: when we use FFT convolution for image processing, a dark border appears around the image because of the zero-padding beyond its boundaries. The convolve2d function supports other boundary conditions, but is far slower.
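One possible workaround is to pad the image explicitly before the FFT convolution and keep only the fully overlapped part. The sketch below assumes an odd-sized kernel and symmetric (mirror) boundaries, and uses a random positive array as a stand-in image; the variable names are illustrative.

```python
import numpy as np
from scipy import signal

rng = np.random.default_rng(0)
A = rng.random((64, 64)) + 1.0      # strictly positive stand-in image
g = signal.windows.gaussian(9, 2)
H = np.outer(g, g)
H /= H.sum()                        # normalise so flat regions keep their level
p = H.shape[0] // 2                 # half-width of the (odd-sized) kernel

# Zero-padded FFT convolution: values near the border are pulled toward 0.
B_zero = signal.fftconvolve(A, H, mode='same')

# Mirror the image at its edges first, then keep only the fully overlapped part.
A_pad = np.pad(A, p, mode='symmetric')
B_symm = signal.fftconvolve(A_pad, H, mode='valid')

# Direct convolution with symmetric boundaries, for comparison.
B_ref = signal.convolve2d(A, H, mode='same', boundary='symm')

print('same shape as input ? %s' % (B_symm.shape == A.shape))
print('matches convolve2d(boundary="symm") ? %s' % np.allclose(B_symm, B_ref))
print('corner darker with zero-padding ? %s' % (B_zero[0, 0] < B_symm[0, 0]))
```

The padding adds only 2p rows and columns, so the cost stays close to that of the plain fftconvolve call.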
Reference
There is an article that benchmarks 2D convolution across various convolution libraries:
In [75]:
# before you run, eval the cell containing overlapadd2 at the end
from scipy import misc # newer SciPy releases provide this image as scipy.datasets.ascent()
import scipy.signal as sp
import matplotlib.pyplot as plt
A = misc.ascent()
A = A.astype(float)
print(A.shape)
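# 64x64 separable Gaussian kernel (std = 8); the window function lives in scipy.signal.windows
H = np.outer(sp.windows.gaussian(64, 8), sp.windows.gaussian(64, 8))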
print('==> using convolve2d')
%time B1 = sp.convolve2d(A, H, mode='same')
print('==> using fftconvolve')
%time B2 = sp.fftconvolve(A, H, mode='same')
print('==> using overlapadd2')
%time B3 = overlapadd2(A, H)
fig, (ax_orig, ax_conv, ax_fft2conv, ax_ovadd2) = plt.subplots(1, 4, figsize = (12, 8))
ax_orig.imshow(A, cmap='gray')
ax_orig.set_title('Original')
ax_orig.set_axis_off()
ax_conv.imshow(B1, cmap='gray')
ax_conv.set_title('convolve2d')
ax_conv.set_axis_off()
ax_fft2conv.imshow(B2, cmap='gray')
ax_fft2conv.set_title('fftconvolve')
ax_fft2conv.set_axis_off()
ax_ovadd2.imshow(B3, cmap='gray')
ax_ovadd2.set_title('overlapadd2')
ax_ovadd2.set_axis_off()
fig.show()