Monotonic constraints in H2O AutoML
# Example of monotonic constraints in H2O AutoML (using h2o v3.30.0.1)
# monotone constraints: http://docs.h2o.ai/h2o/latest-stable/h2o-docs/data-science/algo-params/monotone_constraints.html
# H2O AutoML: http://docs.h2o.ai/h2o/latest-stable/h2o-docs/automl.html

library(h2o)
h2o.init()

# Import the prostate dataset
file <- "http://s3.amazonaws.com/h2o-public-test-data/smalldata/prostate/prostate.csv.zip"
prostate <- h2o.importFile(file)

# Convert the response column (CAPSULE) to a factor
y <- "CAPSULE"
prostate[y] <- as.factor(prostate[y])

# Remove response and ID column from predictors
x <- setdiff(names(prostate), c(y, "ID"))

# Train constrained models (positive constraint on AGE)
# Only XGBoost and GBM models can be constrained (and the Stacked Ensembles created from these models)
# The Stacked Ensemble models in AutoML are a non-negative linear combination of the base models;
# a non-negative combination of constrained models is itself constrained, so the Stacked Ensembles will also be constrained
aml <- h2o.automl(y = y, x = x,
                  training_frame = prostate,
                  monotone_constraints = list(AGE = 1),
                  include_algos = c("XGBoost", "GBM", "StackedEnsemble"),
                  max_models = 10,
                  seed = 1)
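
# (Sanity-check sketch, not part of the original gist): pull one of the GBM base models
# off the leaderboard and confirm that the monotone constraint on AGE was passed through.
# This assumes the @parameters slot lists the non-default parameters set on the model;
# if it comes back empty, @allparameters can be inspected instead.
model_ids <- as.data.frame(aml@leaderboard$model_id)$model_id
gbm <- h2o.getModel(grep("^GBM", model_ids, value = TRUE)[1])
gbm@parameters$monotone_constraints  # expected to show the AGE = 1 constraint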

# Take a quick glance at the leaderboard
aml@leaderboard
# model_id auc logloss aucpr
# 1 StackedEnsemble_AllModels_AutoML_20200428_121659 0.7664615 0.5741385 0.6647128
# 2 GBM_3_AutoML_20200428_121659 0.7661455 0.6361831 0.6731820
# 3 StackedEnsemble_BestOfFamily_AutoML_20200428_121659 0.7655231 0.5782747 0.6693175
# 4 GBM_2_AutoML_20200428_121659 0.7631223 0.6182583 0.6575571
# 5 GBM_1_AutoML_20200428_121659 0.7583139 0.6051435 0.6603360
# 6 GBM_5_AutoML_20200428_121659 0.7524978 0.6816796 0.6454254
# mean_per_class_error rmse mse
# 1 0.2821399 0.4424876 0.1957953
# 2 0.2991276 0.4608900 0.2124196
# 3 0.2978307 0.4442835 0.1973879
# 4 0.2820679 0.4550833 0.2071008
# 5 0.2977023 0.4534866 0.2056501
# 6 0.2984797 0.4712888 0.2221132
#
# [12 rows x 7 columns]
# Generate & view the extended leaderboard | |
lb <- h2o.get_leaderboard(aml, extra_columns = "ALL") | |
print(lb, n = nrow(lb)) | |
# model_id auc logloss aucpr | |
# 1 StackedEnsemble_AllModels_AutoML_20200428_121659 0.7664615 0.5741385 0.6647128 | |
# 2 GBM_3_AutoML_20200428_121659 0.7661455 0.6361831 0.6731820 | |
# 3 StackedEnsemble_BestOfFamily_AutoML_20200428_121659 0.7655231 0.5782747 0.6693175 | |
# 4 GBM_2_AutoML_20200428_121659 0.7631223 0.6182583 0.6575571 | |
# 5 GBM_1_AutoML_20200428_121659 0.7583139 0.6051435 0.6603360 | |
# 6 GBM_5_AutoML_20200428_121659 0.7524978 0.6816796 0.6454254 | |
# 7 XGBoost_3_AutoML_20200428_121659 0.7418445 0.7586254 0.6445601 | |
# 8 GBM_4_AutoML_20200428_121659 0.7323141 0.6852133 0.6328249 | |
# 9 XGBoost_2_AutoML_20200428_121659 0.7308456 1.0740440 0.6103080 | |
# 10 XGBoost_grid__1_AutoML_20200428_121659_model_2 0.7298379 1.4073864 0.6079640 | |
# 11 XGBoost_1_AutoML_20200428_121659 0.7161614 1.1279028 0.6015938 | |
# 12 XGBoost_grid__1_AutoML_20200428_121659_model_1 0.7150672 1.9156239 0.5879757 | |
# mean_per_class_error rmse mse training_time_ms predict_time_per_row_ms | |
# 1 0.2821399 0.4424876 0.1957953 14176 0.064844 | |
# 2 0.2991276 0.4608900 0.2124196 1762 0.010598 | |
# 3 0.2978307 0.4442835 0.1973879 3300 0.012200 | |
# 4 0.2820679 0.4550833 0.2071008 1778 0.010180 | |
# 5 0.2977023 0.4534866 0.2056501 1795 0.009838 | |
# 6 0.2984797 0.4712888 0.2221132 3073 0.012378 | |
# 7 0.3310299 0.4811705 0.2315250 3925 0.001627 | |
# 8 0.3227088 0.4764338 0.2269892 2679 0.011789 | |
# 9 0.3141142 0.5118053 0.2619447 7121 0.002939 | |
# 10 0.3021652 0.5203579 0.2707723 9047 0.003178 | |
# 11 0.3257608 0.5210415 0.2714843 6256 0.002832 | |
# 12 0.3352193 0.5557881 0.3089004 14026 0.002360 | |
# | |
# [12 rows x 9 columns] | |
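
# (Optional sketch): the extended leaderboard can also be pulled into a plain data.frame,
# e.g. to compare the constrained GBM/XGBoost base models and the ensembles on AUC and
# training time. Column names used below match the extended leaderboard printed above.
lb_df <- as.data.frame(lb)
lb_df$algo <- ifelse(grepl("StackedEnsemble", lb_df$model_id), "StackedEnsemble",
                     ifelse(grepl("XGBoost", lb_df$model_id), "XGBoost", "GBM"))
aggregate(cbind(auc, training_time_ms) ~ algo, data = lb_df, FUN = mean)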

# Partial dependence plot of the leader model (a Stacked Ensemble)
# This shows that the mean prediction is increasing across AGE
# View plot here: https://slack-files.com/T0329MHH6-F01221MJW5V-b96bd64b87
h2o.partialPlot(aml@leader, prostate, cols = "AGE")
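
# (Programmatic check, sketch): re-compute the PDP without plotting and verify that the
# mean response is non-decreasing in AGE. The returned table's column name used below
# ("mean_response") is an assumption and may vary slightly across h2o versions.
pdp <- h2o.partialPlot(aml@leader, prostate, cols = "AGE", plot = FALSE)
all(diff(pdp$mean_response) >= 0)  # expected TRUE with the positive constraint on AGE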

# You can also plot partial dependence for individual rows (the first few rows here)
# View plot here: https://slack-files.com/T0329MHH6-F012P0M3LBE-b339948bd0
rows <- 4
par(mfrow = c(2, 2))
for (i in 1:rows) {
  h2o.partialPlot(aml@leader, prostate, cols = "AGE", row_index = i)
}
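
# (Direct spot-check, sketch): take one row, sweep AGE over a range while holding the
# other predictors fixed, and confirm the predicted probability of CAPSULE == 1 never
# decreases. "p1" is H2O's standard column name for the positive-class probability.
row1 <- as.data.frame(prostate[1, ])
sweep_df <- row1[rep(1, 31), ]
sweep_df$AGE <- 50:80
preds <- as.data.frame(h2o.predict(aml@leader, as.h2o(sweep_df)))
all(diff(preds$p1) >= 0)  # expected TRUE (up to tiny floating-point noise)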