# Gist by risenW (a7621ab98e3998f483e6f62f47f5b471), created July 10, 2019 12:23
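# Assumes german_cred, get_split_data, get_mae, get_acc and the four models
# (ada_reg, gb_reg, ada_cf, gb_cf) are defined elsewhere in the accompanying code.

# Get data for the regression task: predict age in years
X_train, X_val, y_train, y_val = get_split_data(german_cred, target_name='age_yrs')

# Train and fit the AdaBoost and Gradient Boosting regressors
ada_reg.fit(X_train, y_train)
gb_reg.fit(X_train, y_train)

# Check their performance with mean absolute error (lower is better)
print("MAE of AdaBoost is:", get_mae(ada_reg.predict(X_val), y_val))
print("MAE of Gradient Boosting is:", get_mae(gb_reg.predict(X_val), y_val))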
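
# A minimal, self-contained sketch of the same regression comparison, for readers
# who do not have german_cred or the get_split_data/get_mae helpers at hand.
# Assumptions (not from the original): scikit-learn is installed, synthetic data
# from make_regression stands in for the German credit data, both models use
# default hyperparameters, and compare_boosting_regressors is a hypothetical name.
from sklearn.datasets import make_regression
from sklearn.ensemble import AdaBoostRegressor, GradientBoostingRegressor
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split


def compare_boosting_regressors(random_state=0):
    """Fit AdaBoost and Gradient Boosting regressors; return validation MAE per model."""
    # Synthetic regression data (a stand-in, not the German credit dataset)
    X, y = make_regression(n_samples=500, n_features=10, noise=10.0,
                           random_state=random_state)
    X_tr, X_va, y_tr, y_va = train_test_split(X, y, test_size=0.2,
                                              random_state=random_state)
    maes = {}
    for name, model in [("AdaBoost", AdaBoostRegressor(random_state=random_state)),
                        ("Gradient Boosting",
                         GradientBoostingRegressor(random_state=random_state))]:
        model.fit(X_tr, y_tr)
        # mean_absolute_error expects (y_true, y_pred)
        maes[name] = mean_absolute_error(y_va, model.predict(X_va))
    return maes


for model_name, mae in compare_boosting_regressors().items():
    print(f"MAE of {model_name} (synthetic data): {mae:.2f}")

# Fixing random_state on the split and both models makes the comparison repeatable,
# so a difference in MAE reflects the models rather than a lucky split.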

# Get data for the classification task: predict bad credit
X_train, X_val, y_train, y_val = get_split_data(german_cred, target_name='bad_credit')

# Train and fit the AdaBoost and Gradient Boosting classifiers
ada_cf.fit(X_train, y_train)
gb_cf.fit(X_train, y_train)

# Check their performance with accuracy (higher is better)
print("Accuracy of AdaBoost is:", get_acc(ada_cf.predict(X_val), y_val))
print("Accuracy of Gradient Boosting is:", get_acc(gb_cf.predict(X_val), y_val))
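
# A minimal, self-contained sketch of the same classification comparison.
# Assumptions (not from the original): scikit-learn is installed, synthetic data
# from make_classification stands in for the bad_credit target, both models use
# default hyperparameters, and compare_boosting_classifiers is a hypothetical name.
from sklearn.datasets import make_classification
from sklearn.ensemble import AdaBoostClassifier, GradientBoostingClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split


def compare_boosting_classifiers(random_state=0):
    """Fit AdaBoost and Gradient Boosting classifiers; return validation accuracy per model."""
    # Synthetic binary classification data (a stand-in for bad_credit)
    X, y = make_classification(n_samples=500, n_features=10,
                               random_state=random_state)
    # stratify=y keeps the class balance the same in the train and validation sets
    X_tr, X_va, y_tr, y_va = train_test_split(X, y, test_size=0.2, stratify=y,
                                              random_state=random_state)
    accs = {}
    for name, model in [("AdaBoost", AdaBoostClassifier(random_state=random_state)),
                        ("Gradient Boosting",
                         GradientBoostingClassifier(random_state=random_state))]:
        model.fit(X_tr, y_tr)
        # accuracy_score expects (y_true, y_pred) and returns a fraction in [0, 1]
        accs[name] = accuracy_score(y_va, model.predict(X_va))
    return accs


for model_name, acc in compare_boosting_classifiers().items():
    print(f"Accuracy of {model_name} (synthetic data): {acc:.3f}")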