Implementation of a sparse Bag of Words (BoW) in PyTorch
import torch


def create_bow(words, vocab_size, pad_id=None):
    """
    Create a bag of words matrix using a sparse COO tensor.

    Args:
        words (torch.LongTensor): tensor containing ids for words in
            your vocabulary. Shape of (batch_size, seq_len)
        vocab_size (int): size of the words vocabulary (including special
            symbols like <unk>, <pad>, etc.)
        pad_id (int): the id that corresponds to the pad token. The count
            related to pad_id will be made zero if pad_id is not None.
            pad_id should be a valid id, i.e., 0 <= pad_id < vocab_size.
            Default is None.

    Returns:
        torch.Tensor of shape (batch_size, vocab_size)
    """
    batch_size, seq_len = words.shape
    if pad_id is not None:
        assert 0 <= pad_id < vocab_size
        # zero out pad positions so they contribute nothing to the counts
        vals = torch.ne(words, pad_id).int()
    else:
        # dummy mask, all values equal to 1
        vals = torch.ones(batch_size, seq_len)
    # row indices: the batch id of each token, repeated seq_len times
    bids = torch.arange(batch_size, device=words.device)
    bids = bids.unsqueeze(-1).expand(-1, seq_len).flatten()
    # COO indices: one (batch id, word id) pair per token in the batch
    idxs = torch.stack((bids, words.flatten()), dim=0)
    vals = vals.to(words.device).float().flatten()
    size = torch.Size([batch_size, vocab_size])
    # torch.sparse_coo_tensor infers the device from idxs/vals, unlike the
    # deprecated torch.sparse.FloatTensor constructor, which is CPU-only.
    # Duplicate (batch id, word id) entries are summed when the tensor is
    # densified, which is exactly what turns repeated ids into counts.
    bow = torch.sparse_coo_tensor(idxs, vals, size)
    return bow.to_dense()
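As a quick sanity check, here is a minimal usage sketch. The vocabulary size, the toy id tensor, and the choice of 0 as the pad id are made-up values for illustration:

# hypothetical toy batch: 2 sequences of 4 word ids, where 0 is <pad>
words = torch.tensor([[1, 2, 2, 0],
                      [3, 3, 3, 0]])
bow = create_bow(words, vocab_size=5, pad_id=0)
print(bow)
# tensor([[0., 1., 2., 0., 0.],
#         [0., 0., 0., 3., 0.]])

Because sparse COO tensors sum duplicate indices on densification, repeated word ids become counts without any explicit loop over the batch.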