Skip to content

Instantly share code, notes, and snippets.

@sergey-kras
Created July 5, 2023 12:02
Show Gist options
  • Save sergey-kras/524d4992bccd5ef45a859749abf4b86c to your computer and use it in GitHub Desktop.
Save sergey-kras/524d4992bccd5ef45a859749abf4b86c to your computer and use it in GitHub Desktop.
# Загружаем обученную модель и токенизатор
model = BertForSequenceClassification.from_pretrained('./trained_model')
tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-uncased')
# Реальный текст для проверки
test_text = ["Your test text here"]
# Токенизация входного текста
inputs = tokenizer(test_text, truncation=True, padding=True, return_tensors="pt")
# Прогнозирование
model.eval() # устанавливаем модель в режим оценки (это важно для корректного прогнозирования)
with torch.no_grad(): # выключаем вычисление градиентов (это важно для экономии памяти)
outputs = model(**inputs)
# Получаем вероятности через softmax
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
# выбираем класс с наибольшей вероятностью:
_, predicted_class = torch.max(probs, dim=-1)
# преобразуем предсказанный класс в True/False
predicted_class = predicted_class.item() # это преобразует тензор в обычное целое число
predicted_class_bool = bool(predicted_class) # преобразуем 1 в True и 0 в False
print(predicted_class_bool)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment