Skip to content

Instantly share code, notes, and snippets.

@mrdrozdov
Created February 6, 2019 23:09
Show Gist options
  • Save mrdrozdov/e4e5cef4201fbee0fc7f13dc132564e2 to your computer and use it in GitHub Desktop.
Save mrdrozdov/e4e5cef4201fbee0fc7f13dc132564e2 to your computer and use it in GitHub Desktop.
kalpesh-intersect.py
"""
I have two tensors, a "reference" tensor of size (8192,) and a "query"
tensor of size (10000, 500). I want a binary mask of size (10000, 8192)
with 1s in all indices (i, j) where reference[j] exists in query[i, :].
Also, all values are non-negative integers strictly less than a known
upper bound (`vocabsize` below), so they can be used directly as indices.
"""
import torch
seed = 10
vocabsize = 300  # value isn't important, but it must be >= refsize[0] and >= querysize[1]
refsize = (100,)
querysize = (10, 5)
# Use the `seed` constant rather than a duplicated literal so changing it takes effect.
torch.manual_seed(seed)

# Reference: refsize[0] distinct integers drawn from [0, vocabsize).
reference = torch.randperm(vocabsize)[:refsize[0]].long()
# Query: each row holds querysize[1] distinct integers drawn from [0, vocabsize).
query = torch.zeros(*querysize).long()
for i in range(query.shape[0]):
    query[i] = torch.randperm(vocabsize)[:querysize[1]]
print('reference:')
print(reference)
print('query:')
print(query)
print()

# Method 1: Get binary mask of shape (number of queries, size of query):
# querymask[i, k] == 1 iff query[i, k] appears anywhere in reference.
print('METHOD1')
# Indicator vector over every possible value: vocab[v] == 1 iff v is in reference.
vocab = torch.zeros(vocabsize).long()
vocab[reference] = 1
print('vocab:')
print(vocab)
# Advanced indexing with the 2-D query returns a tensor of query's shape,
# so no flatten/reshape round-trip is needed.
querymask = vocab[query]
print('querymask:')
print(querymask)
print()

# Method 2: Get binary mask of shape (number of queries, size of reference):
# bigquerymask[i, j] == 1 iff reference[j] appears in query[i, :].
print('METHOD2')
# Create mapping of integer -> position in reference, i.e. lookup[reference[j]] == j.
# Values absent from reference get an out-of-range sentinel; they are never used
# as an index below because only entries selected by querymask are looked up.
lookup = torch.full((vocabsize,), reference.shape[0] + 1, dtype=torch.long)
# The position of reference[j] is simply j. (reference.argsort() is NOT this
# mapping: it yields the index of the j-th smallest element, which only
# coincides with j when reference is already sorted.)
lookup[reference] = torch.arange(reference.shape[0])
print('lookup:')
print(lookup)
# Boolean mask of query entries that occur in reference (uint8 masks are deprecated).
present = querymask.bool()
# Row index i for every present query entry, in the same order as query[present].
itensor = torch.arange(query.shape[0]).view(-1, 1).expand(*query.shape)[present]
print('itensor:')
print(itensor)
# Column index j (position in reference) for every present query entry.
jtensor = lookup[query[present]]
print('jtensor:')
print(jtensor)
bigquerymask = torch.zeros(query.shape[0], reference.shape[0])
bigquerymask[itensor, jtensor] = 1
print('bigquerymask:')
print(bigquerymask)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment