Compact Bilinear Pooling in PyTorch using the new FFT support
# References:
# [1] Multimodal Compact Bilinear Pooling for Visual Question Answering and Visual Grounding, Fukui et al., https://arxiv.org/abs/1606.01847
# [2] Compact Bilinear Pooling, Gao et al., https://arxiv.org/abs/1511.06062
# [3] Fast and Scalable Polynomial Kernels via Explicit Feature Maps, Pham and Pagh, https://chbrown.github.io/kdd-2013-usb/kdd/p239.pdf
# [4] Fastfood — Approximating Kernel Expansions in Loglinear Time, Le et al., https://arxiv.org/abs/1408.3060
# [5] Original implementation in Caffe: https://github.com/gy20073/compact_bilinear_pooling

# TODO: migrate to the new native complex64 dtype
# TODO: change strided x coo matmul to torch.matmul(): M[sparse_coo] @ M[strided] -> M[strided]

import torch

class CompactBilinearPooling(torch.nn.Module):
    def __init__(self, in_channels1, in_channels2, out_channels, sum_pool = True):
        super().__init__()
        self.out_channels = out_channels
        self.sum_pool = sum_pool
        # Count Sketch projection matrix [2, 3]: each input channel is hashed to a random output bin (rand_h) with a random sign (rand_s)
        generate_tensor_sketch = lambda rand_h, rand_s, in_channels, out_channels: torch.sparse.FloatTensor(torch.stack([torch.arange(in_channels), rand_h]), rand_s, [in_channels, out_channels]).to_dense()
        self.tensor_sketch1 = torch.nn.Parameter(generate_tensor_sketch(torch.randint(out_channels, size = (in_channels1,)), 2 * torch.randint(2, size = (in_channels1,), dtype = torch.float32) - 1, in_channels1, out_channels), requires_grad = False)
        self.tensor_sketch2 = torch.nn.Parameter(generate_tensor_sketch(torch.randint(out_channels, size = (in_channels2,)), 2 * torch.randint(2, size = (in_channels2,), dtype = torch.float32) - 1, in_channels2, out_channels), requires_grad = False)

    def forward(self, x1, x2):
        # x1, x2: NCHW feature maps; project each onto its sketch, then compute their circular convolution via the FFT (convolution theorem)
        fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.tensor_sketch1), signal_ndim = 1)
        fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.tensor_sketch2), signal_ndim = 1)
        # torch.rfft does not yet support torch.complex64 outputs, so the complex product is computed manually on (real, imag) pairs
        fft_complex_product = torch.stack([fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1], fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]], dim = -1)
        cbp = torch.irfft(fft_complex_product, signal_ndim = 1, signal_sizes = (self.out_channels, )) * self.out_channels
        return cbp.sum(dim = [1, 2]) if self.sum_pool else cbp.permute(0, 3, 1, 2)
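A minimal usage sketch (my own illustration, not part of the gist), assuming two NCHW feature maps with matching spatial size and an older PyTorch version where torch.rfft is still available; the batch size, channel counts, and output dimension of 8000 are arbitrary illustrative choices:

import torch

cbp = CompactBilinearPooling(in_channels1 = 512, in_channels2 = 512, out_channels = 8000)
x1 = torch.randn(4, 512, 14, 14)
x2 = torch.randn(4, 512, 14, 14)
y = cbp(x1, x2)
print(y.shape)  # (4, 8000) with sum_pool = True; (4, 8000, 14, 14) with sum_pool = False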
Hello,
in torch.stack([torch.arange(in_features), rand_h]) the variable in_features is not defined. How can I fix it?
Thanks!
Thanks for noting this. Fixed! It should have been in_channels.
Some ways to improve the code: make use of the new PyTorch torch.fft module and native complex support, and figure out dense x sparse matmul (currently I'm materializing the sparse sketch as a dense matrix).
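A rough sketch of what that migration could look like (my own illustration, not the author's code), assuming the torch.fft module with native complex tensors (PyTorch 1.8+) and the same dense sketch matrices as in the class above:

import torch

def compact_bilinear_pooling_fft(x1, x2, tensor_sketch1, tensor_sketch2, out_channels, sum_pool = True):
    # Count Sketch projection of NCHW inputs, as in the class above: (B, H, W, out_channels)
    sketch1 = x1.permute(0, 2, 3, 1).matmul(tensor_sketch1)
    sketch2 = x2.permute(0, 2, 3, 1).matmul(tensor_sketch2)
    # circular convolution via the convolution theorem; torch.fft.rfft returns native complex tensors,
    # so the complex product is plain elementwise multiplication
    fft_product = torch.fft.rfft(sketch1, dim = -1) * torch.fft.rfft(sketch2, dim = -1)
    cbp = torch.fft.irfft(fft_product, n = out_channels, dim = -1) * out_channels
    return cbp.sum(dim = [1, 2]) if sum_pool else cbp.permute(0, 3, 1, 2)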
Thanks for your code. I have a question: in the other implementations, such as the Torch and TensorFlow versions, there is zero padding before feeding the tensor into the FFT, but I don't see any zero padding in this code.
Thanks very much!