-
-
Save sergey-kras/524d4992bccd5ef45a859749abf4b86c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Загружаем обученную модель и токенизатор | |
model = BertForSequenceClassification.from_pretrained('./trained_model') | |
tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-uncased') | |
# Реальный текст для проверки | |
test_text = ["Your test text here"] | |
# Токенизация входного текста | |
inputs = tokenizer(test_text, truncation=True, padding=True, return_tensors="pt") | |
# Прогнозирование | |
model.eval() # устанавливаем модель в режим оценки (это важно для корректного прогнозирования) | |
with torch.no_grad(): # выключаем вычисление градиентов (это важно для экономии памяти) | |
outputs = model(**inputs) | |
# Получаем вероятности через softmax | |
probs = torch.nn.functional.softmax(outputs.logits, dim=-1) | |
# выбираем класс с наибольшей вероятностью: | |
_, predicted_class = torch.max(probs, dim=-1) | |
# преобразуем предсказанный класс в True/False | |
predicted_class = predicted_class.item() # это преобразует тензор в обычное целое число | |
predicted_class_bool = bool(predicted_class) # преобразуем 1 в True и 0 в False | |
print(predicted_class_bool) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment