### First, tokenize the input
import torch
tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased', do_basic_tokenize=False)
text_1 = "Who was Jim Henson ?"
text_2 = "Jim Henson was a puppeteer"
# Tokenized input
indexed_tokens = tokenizer.encode(text_1, text_2, add_special_tokens=True)
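# (Added sketch, not in the original gist) Sanity-check the encoding by converting the
# ids back to word pieces; the sequence should start with [CLS] and contain a [SEP]
# between and after the two sentences.
print(tokenizer.convert_ids_to_tokens(indexed_tokens))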
### Get the hidden states computed by `BertModel`
# Define sentence A and B indices associated with the 1st and 2nd sentences (see the BERT paper)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
# Convert inputs to PyTorch tensors
segments_tensors = torch.tensor([segments_ids])
tokens_tensor = torch.tensor([indexed_tokens])
model = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-cased')
with torch.no_grad():
    encoded_layers, _ = model(tokens_tensor, token_type_ids=segments_tensors)
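# (Added sketch, not in the original gist) The first output is the sequence of hidden
# states, one 768-dimensional vector per input token for bert-base-cased.
assert encoded_layers.shape == (1, tokens_tensor.shape[1], 768)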
### Predict masked tokens using `BertForMaskedLM`
# Mask a token that we will try to predict back with `BertForMaskedLM`
masked_index = 8
indexed_tokens[masked_index] = tokenizer.convert_tokens_to_ids(['[MASK]'])[0]
tokens_tensor = torch.tensor([indexed_tokens])
maskedLM_model = torch.hub.load('huggingface/pytorch-transformers', 'modelWithLMHead', 'bert-base-cased')
with torch.no_grad():
    predictions = maskedLM_model(tokens_tensor, token_type_ids=segments_tensors)
# Get the predicted token at the masked position
predicted_index = torch.argmax(predictions[0][0], dim=1)[masked_index].item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'Jim'
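# (Added sketch, not in the original gist) Beyond the single argmax, the top-k
# highest-scoring vocabulary ids at the masked position show the model's alternatives.
_, top_k_ids = torch.topk(predictions[0][0][masked_index], k=5)
print(tokenizer.convert_ids_to_tokens(top_k_ids.tolist()))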
### Question answering using `BertForQuestionAnswering`
questionAnswering_model = torch.hub.load('huggingface/pytorch-transformers', 'modelForQuestionAnswering', 'bert-base-cased')
# Predict the start and end position logits
with torch.no_grad():
    start_logits, end_logits = questionAnswering_model(tokens_tensor, token_type_ids=segments_tensors)
# Or get the total loss, i.e. the sum of the cross-entropy losses for the start and end
# token positions (set the model to train mode first if used for training)
start_positions, end_positions = torch.tensor([12]), torch.tensor([14])
qa_loss = questionAnswering_model(tokens_tensor, token_type_ids=segments_tensors, start_positions=start_positions, end_positions=end_positions)[0]
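# (Added sketch, not in the original gist) To turn the logits into an answer span, take
# the argmax of the start and end logits and decode the tokens in between. The QA head
# loaded above is not fine-tuned on SQuAD, so the extracted span is not expected to be
# meaningful here.
answer_start = torch.argmax(start_logits[0]).item()
answer_end = torch.argmax(end_logits[0]).item()
print(tokenizer.convert_ids_to_tokens(indexed_tokens[answer_start:answer_end + 1]))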
### Classify sequence using `BertForSequenceClassification`
seqClassification_model = torch.hub.load('huggingface/pytorch-transformers', 'modelForSequenceClassification', 'bert-base-cased', num_labels=2)
# Predict the sequence classification logits
with torch.no_grad():
    seq_classif_logits = seqClassification_model(tokens_tensor, token_type_ids=segments_tensors)[0]
# Or get the sequence classification loss (set the model to train mode first if used for training)
labels = torch.tensor([1])
seq_classif_loss = seqClassification_model(tokens_tensor, token_type_ids=segments_tensors, labels=labels)[0]
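# (Added sketch, not in the original gist) The classification head loaded above is
# randomly initialised, so these probabilities are only meaningful after fine-tuning;
# a softmax simply turns the two logits into class probabilities.
seq_classif_probs = torch.softmax(seq_classif_logits, dim=-1)
print(seq_classif_probs)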