Last active
December 17, 2020 05:24
-
-
Save andiac/cb5d6b6480353d605d85f7a845ccbcdc to your computer and use it in GitHub Desktop.
TextCNN PyTorch implementation: showing that Conv1d and Conv2d formulations produce identical features.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Demonstrate TextCNN's convolution equivalence: a Conv2d sliding a
(conv_length, emb_size) kernel over the (seq_len, emb_size) embedding
"image" computes exactly the same features as a Conv1d with emb_size
input channels and kernel size conv_length — up to float32 round-off."""
import torch

# Seed so the printed True/False demonstration is reproducible.
torch.manual_seed(0)

batch_size = 3
seq_len = 20
vocab_size = 4
emb_size = 6
conv_length = 2        # kernel extent along the sequence axis
num_conv_kernel = 30   # number of feature maps
channel_in = 1         # of course for text: a single "image" channel

# x: (batch_size, seq_len) integer token ids, e.g.
# tensor([[1, 3, 0, 2, 2, 3, 1, 2, 0, 3, 2, 1, 1, 2, 2, 0, 2, 0, 0, 1], ...])
x = torch.randint(0, vocab_size, (batch_size, seq_len))
emb = torch.nn.Embedding(vocab_size, emb_size)

# conv2d weight: (num_conv_kernel, channel_in, conv_length, emb_size); bias: (num_conv_kernel,)
conv2d = torch.nn.Conv2d(channel_in, num_conv_kernel, (conv_length, emb_size))
# conv1d weight: (num_conv_kernel, emb_size, conv_length); bias: (num_conv_kernel,)
conv1d = torch.nn.Conv1d(emb_size, num_conv_kernel, conv_length)

# Give both layers identical parameters: drop conv2d's singleton input
# channel, then swap the (length, emb) axes to conv1d's (emb, length) layout.
conv1dsd = conv1d.state_dict()
conv1dsd["bias"] = conv2d.state_dict()["bias"]
# (k, 1, L, E) -> squeeze(1) -> (k, L, E) -> permute -> (k, E, L)
conv1dsd["weight"] = conv2d.state_dict()["weight"].squeeze(1).permute(0, 2, 1)
conv1d.load_state_dict(conv1dsd)

# Part of TextCNN: compare the features before max pooling. Both results
# have shape (batch_size, num_conv_kernel, seq_len - conv_length + 1).
tens_a = conv2d(emb(x).unsqueeze(1)).squeeze(3)
tens_b = conv1d(emb(x).permute(0, 2, 1))

# Equal within float32 precision: close at atol=1e-6 ...
print(torch.allclose(tens_a, tens_b, rtol=0.0, atol=1e-6))  # True
# ... but not bit-identical, so an atol=1e-9 test fails.
print(torch.allclose(tens_a, tens_b, rtol=0.0, atol=1e-9))  # False
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment