Gist Lexie88rus/0d7cc42c329ca50ec96b52907d1d2b6e, created August 30, 2019 21:05
Function to convert a tensor representing a word into a tensor holding that word's index in the vocabulary
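```python
import torch

# Define a function to convert a tensor representing a word into its index in the vocabulary
def indexFromTensor(target):
    '''
    Return a LongTensor of shape (1,) holding the vocabulary index of the
    highest-scoring entry in `target` (e.g. a one-hot vector or model output).
    '''
    # topk(1) returns the largest value and its index along the last dimension
    _, top_i = target.topk(1)
    # Extract the index as a plain Python int (first row if target is batched)
    target_index = top_i[0].item()
    # Wrap the index in a one-element LongTensor, e.g. for use as a loss target
    return torch.tensor([target_index], dtype=torch.long)
```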
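A minimal usage sketch, assuming PyTorch is installed and that a word is represented as a one-hot row vector of shape (1, vocab_size). The helper `oneHotFromIndex` and the value of `vocab_size` are hypothetical names for illustration, not part of the original gist. The sketch recovers the index from the one-hot tensor and feeds it to `nn.CrossEntropyLoss`, which expects a LongTensor of class indices with one entry per batch row.

```python
import torch
import torch.nn as nn

def indexFromTensor(target):
    # Same logic as the gist: index of the largest entry, wrapped in a LongTensor of shape (1,)
    _, top_i = target.topk(1)
    return torch.tensor([top_i[0].item()], dtype=torch.long)

# Hypothetical helper (not part of the original gist): one-hot encode a vocabulary index
def oneHotFromIndex(index, vocab_size):
    one_hot = torch.zeros(1, vocab_size)
    one_hot[0, index] = 1.0
    return one_hot

vocab_size = 5                                  # assumed toy vocabulary size
word_tensor = oneHotFromIndex(3, vocab_size)    # shape (1, 5), represents word #3
target = indexFromTensor(word_tensor)           # LongTensor holding index 3

# A batch of one: logits of shape (1, vocab_size) and a target of shape (1,)
logits = torch.randn(1, vocab_size)             # stand-in for a model's output
loss = nn.CrossEntropyLoss()(logits, target)
print(target, loss.item())
```

For a batch of word tensors, `word_tensor.argmax(dim=1)` should return the same indices for every row at once, which may be preferable to calling the function row by row.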