Skip to content

Instantly share code, notes, and snippets.

View ducha-aiki's full-sized avatar

Dmytro Mishkin ducha-aiki

View GitHub Profile
@ducha-aiki
ducha-aiki / compact_bilinear_pooling.py
Created February 1, 2019 11:41 — forked from vadimkantorov/compact_bilinear_pooling.py
Compact Bilinear Pooling in PyTorch using the new FFT support
import torch
class CompactBilinearPooling(torch.nn.Module):
def __init__(self, input_dim1, input_dim2, output_dim, sum_pool=True):
    """Build the two fixed random sketch matrices used for compact bilinear pooling.

    Each sketch matrix has shape ``(input_dim, output_dim)``. It has exactly one
    non-zero entry per row: row ``i`` holds a random sign (+1 or -1) in a
    uniformly random column ``h[i]`` in ``[0, output_dim)``.

    Args:
        input_dim1: number of rows of ``sketch_matrix1`` (presumably the feature
            size of the first input; confirm against ``forward``).
        input_dim2: number of rows of ``sketch_matrix2`` (presumably the feature
            size of the second input; confirm against ``forward``).
        output_dim: number of columns of both sketch matrices; stored as
            ``self.output_dim``.
        sum_pool: stored as ``self.sum_pool``; not read in this method
            (NOTE(review): consumed elsewhere in the class, presumably ``forward``).
    """
    super(CompactBilinearPooling, self).__init__()
    self.output_dim = output_dim
    self.sum_pool = sum_pool

    def generate_sketch_matrix(rand_h, rand_s, input_dim, output_dim):
        # Row i gets value rand_s[i] in column rand_h[i]. Every row index is
        # distinct, so no entries collide. Direct indexed assignment therefore
        # yields exactly the dense matrix the old sparse-COO construction
        # produced, without the deprecated torch.sparse.FloatTensor constructor.
        sketch = torch.zeros(input_dim, output_dim, dtype=torch.float32)
        # .float() is required: index assignment needs matching dtypes, and
        # 2 * randint(...) - 1 is an integer tensor.
        sketch[torch.arange(input_dim), rand_h.long()] = rand_s.float()
        return sketch

    # Draw order (hashes, then signs; input 1, then input 2) matches the
    # original implementation, so a fixed global seed still yields the same sketches.
    rand_h1 = torch.randint(output_dim, size=(input_dim1,))
    rand_s1 = 2 * torch.randint(2, size=(input_dim1,)) - 1
    rand_h2 = torch.randint(output_dim, size=(input_dim2,))
    rand_s2 = 2 * torch.randint(2, size=(input_dim2,)) - 1

    # The sketches are fixed random projections and must not be trained.
    # Keeping them as Parameters preserves state_dict keys and device moves
    # via .to()/.cuda(). requires_grad=False stops optimizers from updating them.
    self.sketch_matrix1 = torch.nn.Parameter(
        generate_sketch_matrix(rand_h1, rand_s1, input_dim1, output_dim),
        requires_grad=False,
    )
    self.sketch_matrix2 = torch.nn.Parameter(
        generate_sketch_matrix(rand_h2, rand_s2, input_dim2, output_dim),
        requires_grad=False,
    )