Created
February 6, 2019 23:09
-
-
Save mrdrozdov/e4e5cef4201fbee0fc7f13dc132564e2 to your computer and use it in GitHub Desktop.
kalpesh-intersect.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
I have two tensors, a "reference" tensor of size (8192,) and a "query"
tensor of size (10000, 500). I want a binary mask of size (10000, 8192)
with 1s in all indices (i, j) where reference[ j ] exists in query[ i , : ].
Also, all values are integers and upper bounded by another integer.

This script demonstrates the approach on small random tensors (sizes below).
Both methods assume every value lies in ``range(vocabsize)``. Method 2 also
assumes the values in ``reference`` are unique, which holds here because
they come from ``randperm``. With duplicates, only one position per value
would be marked.
"""
import torch

seed = 10
vocabsize = 300  # value isn't important, but it's more than refsize
refsize = (100,)
querysize = (10, 5)

# Seed the RNG so the demo is reproducible.
torch.manual_seed(seed)
reference = torch.randperm(vocabsize)[:refsize[0]].long()
# Each query row is the prefix of its own random permutation, so values are
# unique within a row. Rows are drawn one after another, in row order.
query = torch.stack(
    [torch.randperm(vocabsize)[:querysize[1]] for _ in range(querysize[0])]
).long()
print('reference:')
print(reference)
print('query:')
print(query)
print()

# Method 1: Get binary mask of shape (number of queries, size of query)
print('METHOD1')
# Dense membership table: vocab[v] == 1 iff v occurs in reference.
vocab = torch.zeros(vocabsize).long()
vocab[reference] = 1
print('vocab:')
print(vocab)
# Gathering with a 2-D index tensor returns a tensor of the index's shape,
# so querymask[i, k] == 1 iff query[i, k] occurs in reference.
querymask = vocab[query]
print('querymask:')
print(querymask)
print()

# Method 2: Get binary mask of shape (number of queries, size of reference)
print('METHOD2')
# Create mapping of integer -> position in reference: lookup[reference[j]] == j.
# Values absent from reference keep an out-of-range sentinel. Any accidental
# use of one as a column index then raises instead of marking a wrong column.
lookup = torch.full((vocabsize,), reference.shape[0] + 1, dtype=torch.long)
index = torch.arange(reference.shape[0])
lookup[reference] = index
print('lookup:')
print(lookup)
# Mask of query entries that occur in reference. A comparison (rather than
# .byte()) yields the mask dtype each torch version expects for indexing.
hit = querymask != 0
# Row index (i) of every hit, and the reference column (j) its value maps to.
itensor = torch.arange(query.shape[0]).view(-1, 1).expand(*query.shape)[hit]
print('itensor:')
print(itensor)
jtensor = lookup[query[hit]]
print('jtensor:')
print(jtensor)
bigquerymask = torch.zeros(query.shape[0], reference.shape[0])
bigquerymask[itensor, jtensor] = 1
print('bigquerymask:')
print(bigquerymask)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment