This script provides an example of using cross-validation to fine-tune parameters for learning a decision tree with scikit-learn.
A blog post about this code is available here; check it out!
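The core idea is to score a decision tree on held-out folds rather than on its own training data. Below is a minimal sketch of 10-fold cross-validation. It assumes the current sklearn.model_selection API, not the sklearn.cross_validation module that scikit-learn 0.16.1 used. It loads iris from scikit-learn instead of iris.csv. The tree settings and random seeds are hypothetical placeholders, not necessarily the ones the script uses.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold, cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Iris ships with scikit-learn, so this sketch needs no CSV download.
X, y = load_iris(return_X_y=True)

# Shuffled 10-fold split; the seed is arbitrary and only fixes the folds.
cv = KFold(n_splits=10, shuffle=True, random_state=0)

# Hypothetical tree settings for illustration.
clf = DecisionTreeClassifier(min_samples_split=20, random_state=99)

# One accuracy score per held-out fold, summarised like the script's output.
scores = cross_val_score(clf, X, y, cv=cv)
print("mean: {:.3f} (std: {:.3f})".format(scores.mean(), scores.std()))
```

Shuffling matters here because iris is stored sorted by class. Unshuffled contiguous folds would each hold out mostly a single species.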
- python -- developed with 2.7.6
- scikit-learn -- using version 0.16.1
- pandas -- using version 0.16.1
- numpy -- using version 1.9.2
In addition, to create the graphic of the tree you must have graphviz/dot installed.
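The graphic is produced in two stages. scikit-learn writes the fitted tree as a Graphviz .dot file, and the external dot program renders that file to an image. The sketch below illustrates this under several assumptions: it uses current scikit-learn, the file names dt.dot and dt.png are hypothetical, and dot must be on the PATH. If dot is missing, only the .dot file is written.

```python
import os
import subprocess
import tempfile

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

iris = load_iris()
clf = DecisionTreeClassifier(max_leaf_nodes=5, random_state=0)
clf.fit(iris.data, iris.target)

# Hypothetical output locations in a fresh temporary directory.
out_dir = tempfile.mkdtemp()
dot_path = os.path.join(out_dir, "dt.dot")
png_path = os.path.join(out_dir, "dt.png")

# Write the tree in Graphviz's dot language, labelling splits and classes.
export_graphviz(clf, out_file=dot_path,
                feature_names=iris.feature_names,
                class_names=list(iris.target_names))

# Render with the external dot program; report instead of failing if absent.
try:
    subprocess.check_call(["dot", "-Tpng", dot_path, "-o", png_path])
except (OSError, subprocess.CalledProcessError):
    print("graphviz/dot not available; only {} was written".format(dot_path))
```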
The script also serves as an example of how to use the functions it defines; see lines 232 onwards for how the data is obtained and how the functions are called. To run it:
$ python dt_cross_validation.py
The resulting output is:
-- get data:
-- iris.csv found locally
-- 10-fold cross-validation [using setup from previous post]
mean: 0.960 (std: 0.033)
-- Grid Parameter Search via 10-fold CV
GridSearchCV took 5.10 seconds for 288 candidate parameter settings.
Model with rank: 1
Mean validation score: 0.967 (std: 0.033)
Parameters: {'min_samples_split': 10, 'max_leaf_nodes': 5, 'criterion': 'gini', 'max_depth': None, 'min_samples_leaf': 1}
Model with rank: 2
Mean validation score: 0.967 (std: 0.033)
Parameters: {'min_samples_split': 20, 'max_leaf_nodes': 5, 'criterion': 'gini', 'max_depth': None, 'min_samples_leaf': 1}
Model with rank: 3
Mean validation score: 0.967 (std: 0.033)
Parameters: {'min_samples_split': 10, 'max_leaf_nodes': 5, 'criterion': 'gini', 'max_depth': 5, 'min_samples_leaf': 1}
-- Best Parameters:
parameter: min_samples_split setting: 10
parameter: max_leaf_nodes setting: 5
parameter: criterion setting: gini
parameter: max_depth setting: None
parameter: min_samples_leaf setting: 1
-- Testing best parameters [Grid]...
mean: 0.967 (std: 0.033)
-- get_code for best parameters [Grid]:
if ( PetalLength <= 2.45000004768 ) {
    return Iris-setosa ( 50 examples )
}
else {
    if ( PetalWidth <= 1.75 ) {
        if ( PetalLength <= 4.94999980927 ) {
            if ( PetalWidth <= 1.65000009537 ) {
                return Iris-versicolor ( 47 examples )
            }
            else {
                return Iris-virginica ( 1 examples )
            }
        }
        else {
            return Iris-versicolor ( 2 examples )
            return Iris-virginica ( 4 examples )
        }
    }
    else {
        return Iris-versicolor ( 1 examples )
        return Iris-virginica ( 45 examples )
    }
}
-- Random Parameter Search via 10-fold CV
RandomizedSearchCV took 1.55 seconds for 288 candidates parameter settings.
Model with rank: 1
Mean validation score: 0.967 (std: 0.033)
Parameters: {'min_samples_split': 14, 'max_leaf_nodes': 5, 'criterion': 'gini', 'max_depth': 9, 'min_samples_leaf': 1}
Model with rank: 2
Mean validation score: 0.960 (std: 0.042)
Parameters: {'min_samples_split': 1, 'max_leaf_nodes': 11, 'criterion': 'gini', 'max_depth': 11, 'min_samples_leaf': 4}
Model with rank: 3
Mean validation score: 0.960 (std: 0.042)
Parameters: {'min_samples_split': 11, 'max_leaf_nodes': 4, 'criterion': 'gini', 'max_depth': 16, 'min_samples_leaf': 5}
-- Best Parameters:
parameters: min_samples_split setting: 14
parameters: max_leaf_nodes setting: 5
parameters: criterion setting: gini
parameters: max_depth setting: 9
parameters: min_samples_leaf setting: 1
-- Testing best parameters [Random]...
mean: 0.967 (std: 0.033)
-- get_code for best parameters [Random]:
if ( PetalLength <= 2.45000004768 ) {
    return Iris-setosa ( 50 examples )
}
else {
    if ( PetalWidth <= 1.75 ) {
        if ( PetalLength <= 4.94999980927 ) {
            if ( PetalWidth <= 1.65000009537 ) {
                return Iris-versicolor ( 47 examples )
            }
            else {
                return Iris-virginica ( 1 examples )
            }
        }
        else {
            return Iris-versicolor ( 2 examples )
            return Iris-virginica ( 4 examples )
        }
    }
    else {
        return Iris-versicolor ( 1 examples )
        return Iris-virginica ( 45 examples )
    }
}
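The grid and random searches in the output above follow a common pattern. The sketch below fits each search with 10-fold CV, times it, prints the three best-ranked settings, and then lists the best parameters. It assumes current scikit-learn, where GridSearchCV and RandomizedSearchCV live in sklearn.model_selection (in 0.16.1 they were in sklearn.grid_search) and results are read from cv_results_.

The parameter grid is hypothetical. It was chosen so that 2 * 3 * 4 * 3 * 4 gives the 288 candidates reported above, and it is not necessarily the script's own grid. The random distributions are also hypothetical. They start min_samples_split at 2, because current scikit-learn rejects the value 1 that appears in the 0.16.1 output.

```python
from time import time

import numpy as np
from scipy.stats import randint
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)


def report(cv_results, n_top=3):
    """Print the n_top best-ranked settings from a fitted search's cv_results_."""
    ranks = cv_results["rank_test_score"]
    for i in np.argsort(ranks, kind="stable")[:n_top]:
        print("Model with rank: {}".format(ranks[i]))
        print("Mean validation score: {:.3f} (std: {:.3f})".format(
            cv_results["mean_test_score"][i], cv_results["std_test_score"][i]))
        print("Parameters: {}".format(cv_results["params"][i]))


def run_search(search, label):
    """Fit a search, time it, print the top models and the best parameters."""
    start = time()
    search.fit(X, y)
    n = len(search.cv_results_["params"])
    print("{} took {:.2f} seconds for {} candidate parameter settings.".format(
        label, time() - start, n))
    report(search.cv_results_)
    print("-- Best Parameters:")
    for name, value in search.best_params_.items():
        print("parameter: {} setting: {}".format(name, value))


# Hypothetical grid: 2 * 3 * 4 * 3 * 4 = 288 combinations.
param_grid = {"criterion": ["gini", "entropy"],
              "min_samples_split": [2, 10, 20],
              "max_depth": [None, 2, 5, 10],
              "min_samples_leaf": [1, 5, 10],
              "max_leaf_nodes": [None, 5, 10, 20]}
grid = GridSearchCV(DecisionTreeClassifier(random_state=0), param_grid, cv=10)
run_search(grid, "GridSearchCV")

# Hypothetical distributions; randint(a, b) draws integers in [a, b).
param_dist = {"criterion": ["gini", "entropy"],
              "min_samples_split": randint(2, 20),
              "max_depth": randint(1, 20),
              "min_samples_leaf": randint(1, 20),
              "max_leaf_nodes": randint(2, 20)}
rand = RandomizedSearchCV(DecisionTreeClassifier(random_state=0), param_dist,
                          n_iter=288, cv=10, random_state=0)
run_search(rand, "RandomizedSearchCV")
```

Here the random search samples the same number of settings as the grid, but from wider ranges. In that case its speed advantage comes from the sampled trees rather than from fewer fits.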
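The get_code listings above print the fitted tree as nested if/else rules. The following is one way to produce that kind of listing by walking the fitted tree's tree_ arrays (children_left, children_right, feature, threshold, value, n_node_samples). The function name tree_to_rules is hypothetical, and it is not necessarily how the script's get_code is written. The feature and class names are assumptions chosen to match the listing. Depending on the scikit-learn version, tree_.value may hold class fractions rather than counts. For that reason, the sketch rescales by each leaf's sample count, which assumes an unweighted fit.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Hypothetical names chosen to match the get_code listing.
FEATURES = ["SepalLength", "SepalWidth", "PetalLength", "PetalWidth"]
TARGETS = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]


def tree_to_rules(clf, feature_names, target_names, spacer="    "):
    """Return a fitted DecisionTreeClassifier as nested if/else pseudocode."""
    t = clf.tree_
    lines = []

    def recurse(node, depth):
        pad = spacer * depth
        if t.children_left[node] == t.children_right[node]:
            # Leaf: both child indices are -1. Rescale value to per-class
            # sample counts (assumes an unweighted fit).
            value = t.value[node][0]
            counts = value / value.sum() * t.n_node_samples[node]
            for cls in np.nonzero(counts)[0]:
                lines.append("{}return {} ( {} examples )".format(
                    pad, target_names[cls], int(round(counts[cls]))))
        else:
            # Internal node: samples with feature <= threshold go left.
            lines.append("{}if ( {} <= {} ) {{".format(
                pad, feature_names[t.feature[node]], t.threshold[node]))
            recurse(t.children_left[node], depth + 1)
            lines.append(pad + "}")
            lines.append(pad + "else {")
            recurse(t.children_right[node], depth + 1)
            lines.append(pad + "}")

    recurse(0, 0)
    return "\n".join(lines)


X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_leaf_nodes=5, random_state=0).fit(X, y)
rules = tree_to_rules(clf, FEATURES, TARGETS)
print(rules)
```

A leaf that still mixes classes prints one return line per class present. That is why some leaves in the listing above have two return statements.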