Last active
July 23, 2020 12:53
-
-
Save ab3llini/dcbd4833dbd78aa6d2194b740b62f03a to your computer and use it in GitHub Desktop.
Inference
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
import os | |
from transformers import BertForSequenceClassification, AutoTokenizer | |
def extract_sentiment(model, tokenizer, text, device):
    """Classify the sentiment of ``text``, print a summary, and return the probabilities.

    The model is expected to produce exactly three class scores, ordered
    (negative, neutral, positive); any other number of classes raises.

    Args:
        model: Sequence-classification model already placed on ``device``.
            ``model(input_ids=...)[0]`` must be the raw class scores for the
            single input (shape (1, num_labels)).
        tokenizer: Tokenizer exposing ``encode(text, add_special_tokens=..., truncation=...)``.
        text: Raw input string to classify.
        device: torch device the input tensor is created on (must match the model's).

    Returns:
        Tuple ``(negative, neutral, positive)`` of Python floats in [0, 1]
        that sum to ~1.

    Raises:
        ValueError: if the model does not output exactly three classes.
    """
    # Truncate to the tokenizer's maximum length so long inputs do not overflow
    # the model's position embeddings (which would otherwise crash the forward pass).
    token_ids = tokenizer.encode(text, add_special_tokens=True, truncation=True)
    # Batch of one sequence, created directly on the target device.
    input_ids = torch.tensor([token_ids], dtype=torch.long, device=device)

    with torch.no_grad():
        output = model(input_ids=input_ids)

    # Softmax over the class dimension turns raw scores into probabilities;
    # [0] selects the single sequence in the batch.
    probs = nn.functional.softmax(output[0], dim=1)[0]
    if probs.numel() != 3:
        raise ValueError(
            'Expected 3 sentiment classes (negative, neutral, positive), '
            'got {}'.format(probs.numel())
        )
    negative, neutral, positive = (p.item() for p in probs)

    print(
        'Sentiment for "{}" : Negative: {:.3f}% | Neutral: {:.3f}% | Positive: {:.3f}%'.format(
            text,
            negative * 100.0,
            neutral * 100.0,
            positive * 100.0
        )
    )
    return negative, neutral, positive
if __name__ == '__main__':
    # Directory the training run wrote its checkpoint into; edit before running.
    model_chkpt_dir = 'path/to/your/checkpoint/dir'
    # The training procedure automatically saves the model with a huggingface
    # compatible interface, in the 'huggingface' subdirectory of the checkpoint.
    path = os.path.join(model_chkpt_dir, 'huggingface')
    model = BertForSequenceClassification.from_pretrained(path)
    tokenizer = AutoTokenizer.from_pretrained(path)

    # Prefer the GPU when one is present, but fall back to CPU so the script
    # also works on machines without CUDA (a hard-coded 'cuda' device fails there).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Load the model on the chosen device.
    model.to(device)
    # Evaluation mode switches off training-only behavior such as dropout.
    model.eval()

    # Interactive loop: classify each line typed by the user.
    # Ctrl-D (EOF) or Ctrl-C at the prompt ends the session without a traceback.
    while True:
        try:
            text = input('Model input < ')
        except (EOFError, KeyboardInterrupt):
            print()
            break
        extract_sentiment(model, tokenizer, text, device)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment