Created
May 31, 2022 05:29
-
-
Save Eligijus112/b849d0a048f1d6bf97b94a2a3c79e74c to your computer and use it in GitHub Desktop.
Fitting a regression tree
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Train/test splitting
from sklearn.model_selection import train_test_split
# scikit-learn's decision-tree regressor implementation
from sklearn.tree import DecisionTreeRegressor
# Accuracy metric
from sklearn.metrics import mean_absolute_error

# NOTE(review): `X` (features) and `y` (target) are not defined in this
# snippet; they are assumed to be built earlier by the surrounding notebook.

# Hold out 25% of the rows for testing; random_state fixes the shuffle so the
# split is reproducible across runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)

# Tree hyperparameters
hps = {
    'max_depth': 3,          # limit the tree's depth to 3
    'min_samples_split': 4,  # minimum samples a node needs to be split further
    # DecisionTreeRegressor randomly permutes the features at every split, so
    # equally good candidate splits can be chosen differently on each run.
    # Pin the seed (matching the split above) for reproducible results.
    'random_state': 0,
}

# Instantiate the (unfitted) tree with the hyperparameters above
tree = DecisionTreeRegressor(**hps)
# Fit on the training data only
tree.fit(X_train, y_train)

# Predict the held-out test set
y_pred = tree.predict(X_test)

# Mean absolute error on both sets, rounded to 2 decimals. Use the builtin
# round(): newer scikit-learn releases return a plain Python float from
# mean_absolute_error, which has no .round() method (NumPy scalars do).
mae_train = round(mean_absolute_error(y_train, tree.predict(X_train)), 2)
mae_test = round(mean_absolute_error(y_test, y_pred), 2)
print(f"Mean absolute error on training set: {mae_train}")
print(f"Mean absolute error on test set: {mae_test}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment