GitHub Gist dvgodoy/50aa608946ce4a71c55b8f4d7d663100 (last active March 9, 2019 13:51)
```python
import matplotlib.pyplot as plt

# Create an instance of the extended version of BinaryClassificationMetrics,
# using a DataFrame and its probability and label columns, as output by the
# classifier. The extended class is not defined in this snippet, and
# `predictions` is the DataFrame produced by the classifier.
bcm = BinaryClassificationMetrics(predictions, scoreCol='probability', labelCol='Survived')

# We can still get the same metrics as the evaluator...
print("Area under ROC Curve: {:.4f}".format(bcm.areaUnderROC))
print("Area under PR Curve: {:.4f}".format(bcm.areaUnderPR))

# ...but now we can also PLOT both the ROC and PR curves!
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
bcm.plot_roc_curve(ax=axs[0])
bcm.plot_pr_curve(ax=axs[1])

# We can also get all metrics (FPR, recall and precision) by threshold,
# here restricted to the thresholds whose FPR is roughly 0.20
bcm.getMetricsByThreshold().filter('fpr between 0.19 and 0.21').toPandas()

# And get the confusion matrix for any threshold we want
bcm.print_confusion_matrix(0.415856)
```
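The `areaUnderROC` value is the area under the curve traced by (FPR, recall) pairs as the threshold sweeps over the scores. The sketch below is plain Python rather than Spark: the function names `roc_points` and `trapezoid_area` and the sample scores and labels are hypothetical, and it assumes that a score at or above the threshold counts as a positive prediction and that both classes are present. It approximates the area with the trapezoidal rule, so it is only an illustration and may not match the evaluator's figure exactly.

```python
from typing import List, Sequence, Tuple


def roc_points(scores: Sequence[float], labels: Sequence[int]) -> List[Tuple[float, float]]:
    """Hypothetical helper: (FPR, recall) pairs, starting at (0, 0).

    Every distinct score is tried as a threshold, from highest to lowest;
    a score >= threshold is treated as a positive prediction (assumption).
    """
    positives = sum(labels)
    negatives = len(labels) - positives
    points = [(0.0, 0.0)]
    for threshold in sorted(set(scores), reverse=True):
        tp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 1)
        fp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 0)
        points.append((fp / negatives, tp / positives))
    return points


def trapezoid_area(points: Sequence[Tuple[float, float]]) -> float:
    """Area under a curve given as (x, y) points sorted by x (trapezoidal rule)."""
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        area += (x1 - x0) * (y0 + y1) / 2
    return area


# Hypothetical survival probabilities and true labels
scores = [0.9, 0.8, 0.6, 0.42, 0.41, 0.3, 0.2, 0.1]
labels = [1, 1, 0, 1, 0, 1, 0, 0]
print("Area under ROC Curve: {:.4f}".format(trapezoid_area(roc_points(scores, labels))))
```

On this toy data the area is 0.8125, which equals the fraction of (positive, negative) pairs in which the positive example has the higher score.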
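`getMetricsByThreshold()` reports FPR, recall and precision for each threshold. A minimal Spark-free sketch of the same idea follows; the function `metrics_by_threshold` and the data are hypothetical, and it assumes that every distinct score is tried as a threshold, that a score at or above it is predicted positive, and that both classes are present. It returns plain tuples instead of a DataFrame, so the `between` filter becomes an ordinary comparison.

```python
from typing import List, Sequence, Tuple


def metrics_by_threshold(scores: Sequence[float], labels: Sequence[int]
                         ) -> List[Tuple[float, float, float, float]]:
    """Hypothetical helper: (threshold, fpr, recall, precision) per distinct score.

    A score >= threshold is predicted positive (assumption); both classes
    must be present, otherwise FPR or recall would divide by zero.
    """
    positives = sum(labels)
    negatives = len(labels) - positives
    rows = []
    for threshold in sorted(set(scores), reverse=True):
        tp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 1)
        fp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 0)
        # tp + fp >= 1 here, because the threshold is itself one of the scores
        rows.append((threshold, fp / negatives, tp / positives, tp / (tp + fp)))
    return rows


# Hypothetical survival probabilities and true labels
scores = [0.9, 0.8, 0.6, 0.42, 0.41, 0.3, 0.2, 0.1]
labels = [1, 1, 0, 1, 0, 1, 0, 0]

# Keep only thresholds with an FPR between 0.2 and 0.3
for threshold, fpr, recall, precision in metrics_by_threshold(scores, labels):
    if 0.2 <= fpr <= 0.3:
        print(f"threshold={threshold:.2f} fpr={fpr:.2f} "
              f"recall={recall:.2f} precision={precision:.2f}")
```

With only eight examples the FPR moves in steps of 0.25, which is why the filter window is wider than the one used on the real predictions.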
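The confusion matrix printed by `print_confusion_matrix` depends only on the chosen threshold. As a rough plain-Python sketch (the function `confusion_matrix_at` and the sample data are hypothetical, and it assumes a score at or above the threshold is classified as positive), the four counts can be tallied like this:

```python
from typing import Dict, Sequence


def confusion_matrix_at(scores: Sequence[float], labels: Sequence[int],
                        threshold: float) -> Dict[str, int]:
    """Hypothetical helper: TP/FP/TN/FN counts, predicting positive when score >= threshold."""
    counts = {"tp": 0, "fp": 0, "tn": 0, "fn": 0}
    for score, label in zip(scores, labels):
        predicted = 1 if score >= threshold else 0
        if predicted == 1 and label == 1:
            counts["tp"] += 1
        elif predicted == 1 and label == 0:
            counts["fp"] += 1
        elif predicted == 0 and label == 0:
            counts["tn"] += 1
        else:
            counts["fn"] += 1
    return counts


# Hypothetical survival probabilities and true labels
scores = [0.9, 0.8, 0.6, 0.42, 0.41, 0.3, 0.2, 0.1]
labels = [1, 1, 0, 1, 0, 1, 0, 0]
print(confusion_matrix_at(scores, labels, 0.415856))
# {'tp': 3, 'fp': 1, 'tn': 3, 'fn': 1}
```

Moving the threshold trades false positives for false negatives; whether a score exactly equal to the threshold counts as positive is a convention that should be checked against the class actually used.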