Skip to content

Instantly share code, notes, and snippets.

@jgoodie
Created May 29, 2024 04:04
Show Gist options
  • Save jgoodie/9ff4f55d6b331e56986dc482a7f099c1 to your computer and use it in GitHub Desktop.
Save jgoodie/9ff4f55d6b331e56986dc482a7f099c1 to your computer and use it in GitHub Desktop.
# Evaluate the trained classifier on the held-out test split and print its
# accuracy, confusion matrix, and per-class classification report.
# Relies on names defined outside this snippet: model, X_test, y_test, device,
# Accuracy / ConfusionMatrix (presumably torchmetrics -- confirm against the
# imports) and classification_report (presumably scikit-learn).
torch.manual_seed(101)

# Make predictions
model.eval()
with torch.inference_mode():  # evaluation only: no autograd tracking needed
    # NOTE(review): only the model *output* is moved to `device`; this assumes
    # X_test already lives on the same device as the model -- confirm.
    y_logits = model(X_test).to(device)
    # Per-class probabilities along dim 1, then the most likely class index.
    y_preds = torch.softmax(y_logits, dim=1).argmax(dim=1)

num_classes = model.output_features
accuracy = Accuracy(task="multiclass", num_classes=num_classes).to(device)
confusion = ConfusionMatrix(task="multiclass", num_classes=num_classes).to(device)

# torchmetrics-style metrics are called as metric(preds, target). Passing the
# arguments the other way round leaves accuracy unchanged but transposes the
# confusion matrix, so predictions go first here.
print("Accuracy: %.2f%%" % (accuracy(y_preds, y_test).item() * 100))
print(confusion(y_preds, y_test))
# classification_report uses the opposite convention: (y_true, y_pred), and
# needs host-side data, hence the .cpu() calls.
print(classification_report(y_test.cpu(), y_preds.cpu()))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment