Fourier 1d Conv is constant speed as kernel support increases
"""@xvdp | |
Fourier 1D convs are constant speed with support size increase. HW optimized Sliding W are faster for small supports <20. | |
Vec (conv) kernel = ifft(fft(Vec) * fft(kernel)) | |
I used these to RIR (Room Impusle Response) to audio augmentation. | |
I filed issue to pytorch https://github.com/pytorch/pytorch/issues/79222, I just noticed I had not gisted it. | |
""" | |
from typing import Optional
import time
import torch
from torch import Tensor
import torch.nn.functional as F
# pylint: disable=no-member
# pylint: disable=suppressed-message

def fftconv1d(x: Tensor, weight: Tensor,
              bias: Optional[Tensor] = None,
              padding: int = 0,
              groups: int = 1) -> Tensor:
    """FFT-based equivalent of F.conv1d (stride=1, dilation=1).
    Args
        x:       Tensor (batch_size, in_channels, size)
        weight:  Tensor (out_channels, in_channels//groups, kernel_size)
        bias:    Tensor [None] (out_channels,)
        padding: int [0]
        groups:  int [1] in_channels and out_channels must be divisible by groups
    adapted from https://towardsdatascience.com/fourier-convolutions-in-pytorch-4cbd23c70005
    faster than the sliding-window implementation for large kernel supports
    """
    assert x.ndim == 3, "x expected shape: (N, C, L)"
    assert weight.ndim == 3, "weight expected shape: (out_channels, in_channels//groups, kernel_size)"
    _out, _in, _ = weight.shape
    if bias is not None:
        assert bias.ndim == 1 and len(bias) == _out, "bias must be a vector sized as out_channels"
    assert not x.shape[1] % groups, f"in_channels must be divisible by groups {x.shape[1], groups}"
    assert not _out % groups, f"out_channels must be divisible by groups {_out, groups}"
    assert x.shape[1] == groups*_in, f"Given groups={groups} and weight {tuple(weight.shape)}, \
        expected input {tuple(x.shape)} to have {groups*_in} channels"

    out = F.pad(x, [padding, padding])
    _size = out.shape[-1]                     # padded signal length
    _pad = _size - weight.shape[-1]           # kernel zero-pad; also valid output length - 1
    x_rfft = torch.fft.rfftn(out, dim=-1)
    # zero-pad the kernel to the padded signal length before transforming
    w_rfft = torch.fft.rfftn(F.pad(weight, (0, _pad)), dim=-1)
    # conjugate the kernel spectrum: F.conv1d computes cross-correlation, not convolution
    w_rfft.imag *= -1
    if groups == 1:
        # spectral product summed over input channels -> (N, out_channels, freq)
        x_rfft = torch.einsum("ab..., cb... -> ac...", x_rfft, w_rfft)
    else:
        _o = _out//groups
        x_rfft = torch.cat([torch.einsum("ab..., cb... -> ac...",
                                         x_rfft[:, _in*g:_in*(g+1)],
                                         w_rfft[_o*g:_o*(g+1)])
                            for g in range(groups)], dim=1)
    # inverse transform at the padded length (s= keeps odd lengths correct), then crop to valid output
    out = torch.fft.irfftn(x_rfft, s=[_size], dim=-1)[..., :_pad + 1].contiguous()
    if bias is not None:
        out = out + bias.view(1, -1, 1)
    return out
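

# A minimal single-channel sketch (added for illustration, not part of the original gist) of the
# identity quoted in the module docstring: the cross-correlation that F.conv1d computes equals
# irfft(rfft(x) * conj(rfft(kernel zero-padded to len(x)))). The helper name, shapes, and tolerance
# below are assumptions chosen only for this demonstration.
def _fft_corr_sketch(n: int = 64, k: int = 9) -> None:
    x = torch.randn(1, 1, n)
    w = torch.randn(1, 1, k)
    direct = F.conv1d(x, w)                              # valid correlation, length n - k + 1
    x_f = torch.fft.rfft(x, dim=-1)
    w_f = torch.fft.rfft(F.pad(w, (0, n - k)), dim=-1)   # zero-pad kernel to the signal length
    spectral = torch.fft.irfft(x_f * w_f.conj(), n=n, dim=-1)[..., :n - k + 1]
    assert torch.allclose(direct, spectral, atol=1e-4)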
def _testconv(cuda=True, grad=True, pad=None, out_channels=4, in_channels=2,
              batch_size=20, size=4096, ksize=1000, groups=1):
    if pad is None:
        pad = ksize//2
    signal = torch.randn(batch_size, in_channels, size)
    if grad:
        signal.requires_grad = True
    kernel = torch.randn(out_channels, in_channels//groups, ksize)
    bias = torch.randn(out_channels)
    print(f"\n signal: {tuple(signal.shape)}, kernel: {tuple(kernel.shape)}")
    if cuda:
        signal = signal.to(device="cuda")
        kernel = kernel.to(device="cuda")
        bias = bias.to(device="cuda")
        torch.cuda.synchronize()    # finish transfers before timing
    _start = time.time()
    y0 = F.conv1d(signal, kernel, bias=bias, padding=pad, groups=groups)
    if cuda:
        torch.cuda.synchronize()
    _fconv = time.time()
    y2 = fftconv1d(signal, kernel, bias=bias, padding=pad, groups=groups)
    if cuda:
        torch.cuda.synchronize()
    _fftconv = time.time()
    _test = f'test: cuda:{cuda}, grad:{grad}, pad:{pad}, out:{out_channels}, in:{in_channels}, groups:{groups}'
    print(_test)
    _nntime = 1000*(_fconv - _start)     # ms
    _fftime = 1000*(_fftconv - _fconv)   # ms
    _nn = _ff = ""
    if _nntime < _fftime:
        _nn = "\t\t\tnn.Conv1d is faster"
    elif _fftime < _nntime:
        _ff = "\t\t\tFFT faster"
    print(f" nn.Conv1d() time {_nntime:.1f} ms {_nn}")
    print(f" fftconv1d  time {_fftime:.1f} ms {_ff}")
    assert torch.allclose(y0, y2, rtol=1e-3, atol=1e-3), _test
def test_conv_opt():
    cuda = [True, False]
    grad = [True, False]
    padding = [0, None, 100]
    groups = [1, 2]
    out_channels = [4, 2]
    in_channels = [2, 8]
    batch_size = 20
    size = [4096, 14400]
    ksize = [9, 1000]
    for p in padding:
        for r in grad:
            for c in cuda:
                for g in groups:
                    for i in in_channels:
                        for o in out_channels:
                            for k in ksize:
                                for s in size:
                                    _testconv(cuda=c, grad=r, pad=p, out_channels=o, groups=g,
                                              in_channels=i, batch_size=batch_size, size=s, ksize=k)
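

# Assumed entry point (not in the original gist): run the full parameter sweep directly,
# e.g. `python fftconv1d.py`; the cuda=True cases require a CUDA device.
if __name__ == "__main__":
    test_conv_opt()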