Skip to content

Instantly share code, notes, and snippets.

@andiac
Last active December 17, 2020 05:24
Show Gist options
  • Select an option

  • Save andiac/cb5d6b6480353d605d85f7a845ccbcdc to your computer and use it in GitHub Desktop.

Select an option

Save andiac/cb5d6b6480353d605d85f7a845ccbcdc to your computer and use it in GitHub Desktop.
TextCNN PyTorch implementation: Conv1d vs. Conv2d
"""Show that TextCNN's Conv2d formulation equals an equivalent Conv1d.

A TextCNN can convolve the embedded sequence either as a 1-channel 2-D
"image" (Conv2d with a kernel spanning the full embedding width) or as a
multi-channel 1-D signal (Conv1d with emb_size input channels).  With the
same weights, both produce the same pre-pooling features up to float32
round-off.
"""
import torch

# Hyper-parameters for a toy setup.
batch_size = 3
seq_len = 20
vocab_size = 4
emb_size = 6
conv_length = 2       # kernel width along the sequence axis
num_conv_kernel = 30  # number of feature maps
channel_in = 1        # of course for text: a single "image" channel

# Random token ids, shape (batch_size, seq_len), e.g.
# tensor([[1, 3, 0, 2, 2, 3, 1, 2, 0, 3, 2, 1, 1, 2, 2, 0, 2, 0, 0, 1], ...])
x = torch.randint(0, vocab_size, (batch_size, seq_len))
emb = torch.nn.Embedding(vocab_size, emb_size)

# Conv2d view — weight: (num_conv_kernel, channel_in, conv_length, emb_size),
# bias: (num_conv_kernel,)
conv2d = torch.nn.Conv2d(channel_in, num_conv_kernel, (conv_length, emb_size))
# Conv1d view — weight: (num_conv_kernel, emb_size, conv_length),
# bias: (num_conv_kernel,)
conv1d = torch.nn.Conv1d(emb_size, num_conv_kernel, conv_length)

# Same initialization: copy conv2d's parameters into conv1d by dropping the
# singleton input-channel dim and swapping the (length, emb) axes.
state = conv1d.state_dict()
state["bias"] = conv2d.state_dict()["bias"]
state["weight"] = conv2d.state_dict()["weight"].squeeze(1).permute(0, 2, 1)
conv1d.load_state_dict(state)

# Part of TextCNN: compare the features *before* max pooling.  Both outputs
# have shape (batch_size, num_conv_kernel, seq_len - conv_length + 1).
# no_grad: this is a pure comparison, no autograd graph needed.
with torch.no_grad():
    tens_a = conv2d(emb(x).unsqueeze(1)).squeeze(3)
    tens_b = conv1d(emb(x).permute(0, 2, 1))

# True — equal to within float32 round-off.
print(torch.all((tens_a - tens_b).abs() < 1e-6))
# False — but not bit-for-bit identical (the two conv kernels accumulate
# partial sums in a different order).
print(torch.all((tens_a - tens_b).abs() < 1e-9))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment