@Felflare · Created February 10, 2020 02:37
Sample function to generate text from the XLNet model as implemented by Hugging Face.
from transformers import XLNetTokenizer, XLNetLMHeadModel
import torch
import torch.nn.functional as F
tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')
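# Inference detail (an assumption on my part, not shown in the original gist):
# eval mode disables dropout so repeated runs produce the same logits.
model.eval()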
# We show how to set up inputs to predict a masked token using bidirectional context.
encoded_text = tokenizer.encode("Quick brown fox jumped over the lazy <mask>.", add_special_tokens=True)
input_ids = torch.tensor(encoded_text).unsqueeze(0) # We will predict the masked token
print(f'Input sequence -- {encoded_text}')
# Input sequence -- [9928, 3442, 17, 13894, 4651, 95, 18, 17634, 6, 17, 9, 4, 3]
perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
perm_mask[:, :, -5] = 1.0 # Prevent every position from attending to the "<mask>" token (5th from last)
target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float) # Shape [1, 1, seq_length] => let's predict one token
target_mapping[0, 0, -5] = 1.0 # Define a target to predict, 5th to last token -> the "<mask>" token
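# Hedged aside (not in the original gist): instead of hard-coding the -5 offset,
# the mask position can be located via the tokenizer, which generalizes the same
# perm_mask/target_mapping setup to any input sentence.
mask_position = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
# perm_mask[:, :, mask_position] and target_mapping[0, 0, mask_position] would
# then replace the two hard-coded -5 indices above.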
outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
next_token_logits = outputs[0] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
print(f'Output shape is -- {next_token_logits.shape}')
# Output shape is -- torch.Size([1, 1, 32000])
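# Quick sanity check (my sketch, not part of the original gist): before any
# sampling, the single most likely token can be read straight off the logits.
greedy_id = next_token_logits[0, 0, :].argmax().item()
print(f'Greedy prediction -- {tokenizer.decode([greedy_id])}')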
temperature = 1.0
top_p = 0.9
filter_value = -float('Inf')
# Keep the logits for the single target position and apply temperature scaling
next_token_logits = outputs[0][0, -1, :] / temperature
# Nucleus (top-p) filtering: sort logits and accumulate probability mass
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the mask right so the first token above the threshold is also kept
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
next_token_logits[indices_to_remove] = filter_value
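# The block above is standard nucleus (top-p) filtering (Holtzman et al., 2019).
# The same logic wrapped as a reusable helper, a sketch of mine rather than part
# of the original gist (top_p_filter is a hypothetical name):
def top_p_filter(logits, top_p=0.9, filter_value=-float('Inf')):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    filtered = logits.clone()
    filtered[sorted_indices[sorted_indices_to_remove]] = filter_value
    return filtered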
candidate_tokens = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=5) # draw 5 candidate tokens (without replacement) from the filtered distribution
print(f'Candidate token ids -- {candidate_tokens}')
# Candidate token ids -- tensor([2288, 29227, 2934, 8002, 303])
print(f'Ids converted to tokens -- {tokenizer.decode(candidate_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)}')
# Ids converted to tokens -- dog pony giant pile local
next_token = candidate_tokens[0] # keep the first sampled candidate (sampled, so not necessarily the argmax)
print(f'First sampled token -- {next_token}')
# First sampled token -- 2288
output_seq = input_ids.clone()
output_seq[0, -5] = next_token # fill in the masked position
print(f'Unmasked sequence of Ids -- {output_seq[0].tolist()}')
# Unmasked sequence of Ids -- [9928, 3442, 17, 13894, 4651, 95, 18, 17634, 2288, 17, 9, 4, 3]
tokens = tokenizer.decode(output_seq[0].tolist(), skip_special_tokens=True)
print(f'Unmasked sequence -- {tokens}')
# Unmasked sequence -- Quick brown fox jumped over the lazy dog.
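# The single-token recipe extends to left-to-right generation: append a
# placeholder token, hide it with perm_mask, predict it, and repeat. A minimal
# sketch under those assumptions (sample_next_token is a hypothetical helper,
# reusing the top_p_filter sketch above):
def sample_next_token(model, input_ids, temperature=1.0, top_p=0.9):
    seq_len = input_ids.shape[1]
    perm_mask = torch.zeros((1, seq_len, seq_len), dtype=torch.float)
    perm_mask[:, :, -1] = 1.0  # no position may attend to the token being predicted
    target_mapping = torch.zeros((1, 1, seq_len), dtype=torch.float)
    target_mapping[0, 0, -1] = 1.0  # predict the last position
    logits = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)[0]
    filtered = top_p_filter(logits[0, -1, :] / temperature, top_p=top_p)
    return torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)
# Usage sketch: ids = torch.cat([output_seq, torch.zeros((1, 1), dtype=torch.long)], dim=1)
#               ids[0, -1] = sample_next_token(model, ids)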